我们之前谈到了2元分类,但是有时候我们需要多元分类,这时候sigmoid函数就不再适用了。

假如我们需要三个分类,而输出层在激活函数之前得到的值为3.,4.,5. ,如果我们用sigmoid:

sess.run(tf.nn.sigmoid([3.,4.,5.]))
array([0.95257413, 0.98201376, 0.9933072 ], dtype=float32)

我们可以看到,输出结果并不能很好的分类。如果改用softmax:

sess.run(tf.nn.softmax([3.,4.,5.]))
array([0.09003057, 0.24472848, 0.66524094], dtype=float32)

状况则要好得多。

下面是原文对softmax的介绍:

Softmax 选项

请查看以下 Softmax 变体:

  • 完整 Softmax 是我们一直以来讨论的 Softmax;也就是说,Softmax 针对每个可能的类别计算概率。

  • 候选采样指 Softmax 针对所有正类别标签计算概率,但仅针对负类别标签的随机样本计算概率。例如,如果我们想要确定某个输入图片是小猎犬还是寻血猎犬图片,则不必针对每个非狗狗样本提供概率。

类别数量较少时,完整 Softmax 代价很小,但随着类别数量的增加,它的代价会变得极其高昂。候选采样可以提高处理具有大量类别的问题的效率。

一个标签与多个标签

Softmax 假设每个样本只是一个类别的成员。但是,一些样本可以同时是多个类别的成员。对于此类示例:

  • 您不能使用 Softmax。
  • 您必须依赖多个逻辑回归。

例如,假设您的样本是只包含一项内容(一块水果)的图片。Softmax 可以确定该内容是梨、橙子、苹果等的概率。如果您的样本是包含各种各样内容(几碗不同种类的水果)的图片,您必须改用多个逻辑回归。

下面是mnist,学神经网络肯定会在不同程度上接触mnist数据集,这里我们用dnn的框架来识别mnist图片(cnn效果更佳)。

import numpy as np
from MNIST_data import input_data
from tensorflow.data import Dataset
import tensorflow as tf
import pandas as pd
from tensorflow.contrib import layers
import matplotlib.pyplot as pltmnist = input_data.read_data_sets('./MNIST_data', one_hot=True)def random_data(xs, ys):df_xs = pd.DataFrame(xs)df_ys = pd.DataFrame(ys)df_concat = pd.concat([df_xs, df_ys], axis=1)df_concat = df_concat.reindex(np.random.permutation(df_concat.index))df_concat = df_concat.sort_index()df_features = df_concat.iloc[::, 0:784]df_targets = df_concat.iloc[::, 784::]return np.matrix(df_features), np.matrix(df_targets)def my_input_fn(features, labels, batch_size=1, num_epochs=1, shuffle=False):ds = Dataset.from_tensor_slices((features, labels))ds = ds.batch(batch_size).repeat(num_epochs)if shuffle:ds.shuffle(10000)features, labels = ds.make_one_shot_iterator().get_next()return features, labelsdef add_layer(inputs, inputs_size, outputs_size, activation_function=None):weights = tf.Variable(tf.random_normal([inputs_size, outputs_size], stddev=.1))tf.add_to_collection('loss', layers.l1_regularizer(0.001)(weights))biases = tf.Variable(tf.zeros([outputs_size]) + 0.1)wx_b = tf.matmul(inputs, weights) + biasesif activation_function is None:outputs = wx_belse:outputs = activation_function(wx_b)return weights, biases, outputsdef _loss(pred, ys):loss = -tf.reduce_sum(ys*tf.log(pred))loss_total = loss + tf.add_n(tf.get_collection('loss'))return loss_totaldef train_step(learning_rate, loss):train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)return train_stepdef accuracy(vx_pred, vy):correct_prediction = tf.equal(tf.argmax(vx_pred, 1), tf.argmax(vy, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))return accuracyxs = mnist.train.images
ys = mnist.train.labels
ys = ys.astype('float32')
xs, ys = random_data(xs, ys)
vx = mnist.validation.images
vy = mnist.validation.labels
vy = vy.astype('float32')
global_step = tf.Variable(0, trainable=False)
start_learning_rate = .001
lr = tf.train.exponential_decay(start_learning_rate, global_step, 100, 0.95, staircase=True)x_input, y_input =my_input_fn(xs, ys, batch_size=50, num_epochs=2)
vx_input, vy_input = my_input_fn(vx, vy, batch_size=5000, num_epochs=40)
w1, b1, l1 = add_layer(x_input, 784, 200, activation_function=tf.nn.tanh)
w2, b2, l2 = add_layer(l1, 200, 100, activation_function=tf.nn.tanh)
w3, b3, pred = add_layer(l2, 100, 10, activation_function=tf.nn.softmax)loss = _loss(pred, y_input)
train = train_step(lr, loss)sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)vl1 = tf.nn.tanh(tf.matmul(vx_input, w1) + b1)
vl2 = tf.nn.tanh(tf.matmul(vl1, w2) + b2)
vpred = tf.matmul(vl2, w3) + b3
v_accuracy = accuracy(vpred, vy)for i in range(2000):sess.run(train)if i % 50 == 0:print(sess.run(v_accuracy))img = sess.run(w1).copy()
plt.figure(1)for i in range(200):plt.subplot(10, 20, i+1)plt.imshow(img[:, i].reshape(28, 28), cmap='binary')plt.show()

导入后我们发现labels的dtype是float64我们先改成float32。

下一步我们对训练集打乱顺序,方法与之前一样,先把矩阵变成DataFrame,然后打乱index再变回矩阵。

这里我改用了动态的learning_rate(直接设置一个参数也可以)。

tf.train.exponential_decay的几个参数,首先是初始学习速率,其次是步数,下一个是每多少步更新多少速率,下一个是更新后的速率。

这里其实就是每隔100步,lr*0.95, staircase就是是否每隔100步再更新,默认为false,如果是false的话,每一步都会进行更新,但是整体速率是不变的。

loss与我们之前写的不太一样,这里是交叉熵函数。之所以没有了-ys*tf.log(1-pred)是因为  softmax中每一个参数调整均会影响其他参数,前面已经介绍过了,不信你可以试试 tf.nn.softmax([3,4,5]) 和 tf.nn.softmax([3,4,6])输出的结果有何不同。

这里我们输入的784个features有很多0,我希望可以让部分权重变为0,原因以前提到过,节省RAM并且降低噪点。如果忘记了可以往前翻一翻或者直接看官方教程的(正则化:稀疏性)。

argmax则是找到当前最大值所在的位置,这里举个例子:

>>>x = np.eye(5)
>>>print(sess.run(tf.argmax(x,1)))array([0, 1, 2, 3, 4])

tf.equal则是判断元素是否相等返回布尔值, tf.cast则是转化为其他类型,我们这里转化为tf.float32.

依旧举个例子:

>>>print(sess.run(tf.equal(3,5)))False>>>print(sess.run(tf.cast(tf.equal([3,4],[5,4]), tf.float32)))array([0., 1.], dtype=float32)

至于图片显示,对我这种古董电脑真的很吃力。因此我在训练完成后显示到了一个figure中,不过什么也看不清。

figure过多的话,非常占用内存,我们这里可以创建10个figure,每个显示20张图片,或者干脆将第一层的weights数减少,这样更方便观看变化。

最后的准确率应该在96%-97%左右。如果你愿意调参的话,最后应该会到98%以上,但是表现依旧不如cnn(不过dnn运行起来速度要快的多的多)。

官方cnn的教程应该还会用到mnist集,后续问题我们之后再谈。

最后说一下我数据集的位置。我的是在当前py文件目录下创建了一个文件夹叫 MNIST_data,并且把网上下载好的4个.gz的压缩文件放入其中,并且把tensorflow.examples.tutorials.mnist.input_data.py 复制到了该目录下,因此直接import

在远古版本的tensorflow中, 直接from tensorflow.examples.tutorials.mnist import input_data 进行操作会报错,具体报错什么记不清了。因此采用了这种方法。

【学习笔记】softmax回归与mnist编程相关推荐

  1. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记--使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  2. 深度学习基础--SOFTMAX回归(单层神经网络)

    深度学习基础–SOFTMAX回归(单层神经网络) 最近在阅读一本书籍–Dive-into-DL-Pytorch(动手学深度学习),链接:https://github.com/newmonkey/Div ...

  3. 百度飞桨2021李宏毅机器学习特训营学习笔记之回归及作业PM2.5预测

    百度飞桨2021李宏毅机器学习特训营学习笔记之回归及作业PM2.5预测 前言 回归 什么是回归(Regression)? 怎么做回归? 线性回归(Linear Regression) 训练集与验证集 ...

  4. PyTorch学习笔记(二)——回归

    PyTorch学习笔记(二)--回归 本文主要是用PyTorch来实现一个简单的回归任务. 编辑器:spyder 1.引入相应的包及生成伪数据 import torch import torch.nn ...

  5. Linux学习笔记(3)- 网络编程以及范例程序

    Linux学习笔记(3)- 网络编程以及范例程序 前言 网络介绍 IP地址的介绍 端口和端口号的介绍 通信流程 socket介绍 TCP介绍 python3编码转换 TCP客户端程序开发流程 多任务版 ...

  6. 阿里云“7天实践训练营”入门班第二期学习笔记 第五天 在线编程挑战

    阿里云"7天实践训练营"入门班第二期学习笔记 第五天 在线编程挑战 吾辈,完全不会编程 以下内容全程来自阿里云社区的大佬分析讲解 原题目 知识点:搜索.字符串.位运算 有一天Jer ...

  7. TensorFlow学习笔记(二)MNIST入门

    MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片: 它也包含每一张图片对应的标签,告诉我们这个是数字几.比如,上面这四张图片的标签分别是5,0,4,1. 在此教程中,我们将训练一个机器 ...

  8. python编程16章教程_Python学习笔记__16.2章 TCP编程

    # 这是学习廖雪峰老师python教程的学习笔记 Socket是网络编程的一个抽象概念.通常我们用一个Socket表示"打开了一个网络链接",而打开一个Socket需要知道目标计算 ...

  9. 开始利用CSDN做学习笔记,从windows 游戏编程大师技巧和3D游戏编程大师开始

    利用两个月的空余时间将windows 游戏编程大师技巧和3D游戏编程大师技巧看了一遍. 第一遍读的并不深入,代码也没有仔细研究.特别是3D下册,基本只是草草浏览了一遍而已.这一遍是为了对整体有个印象和 ...

最新文章

  1. html5类选择器选择权重,Python Html5和CSS3的新增功能:CSS权重与CSS3新增选择器
  2. HDFS之SequenceFile和MapFile
  3. 杨植麟:28 岁青年科学家,开挂人生的方法论
  4. 数据中心是虚拟现实的基石
  5. python1000个常用代码-比较了1000多个Python开源项目,精选出这34个
  6. Python爬虫(十二)_BeautifulSoup4 解析器
  7. vue_组件_监听组件事件
  8. apache ranger_Apache Ranger插件的美丽简洁
  9. 华为鸿蒙系统学习笔记5-华为方舟编译器正式开源及相关源码下载
  10. jQuery初识之安装与语法简介
  11. 用C#实现***程序
  12. 记一个tcp udp测试工具ethrc
  13. java list下标_Java list删除指定多个下标数据
  14. 努力是你最幸福的时候
  15. 互联网赚钱发展趋势,网赚案例精准分析
  16. 计算机网络隧道技术,隧道技术-高级计算机网络.ppt
  17. 写给认真学习却进步缓慢的大一学生
  18. MySQL修改自增字段的自增值
  19. 单片机的停车场计数系统c51_停车场车辆计数系统的设计.doc
  20. SimpleDateFormat日期格式转换及时间戳转换

热门文章

  1. python 培训课件
  2. 制作WinPE步骤教程
  3. 《人性的弱点》阅读摘录-2
  4. 奔驰c语言控制系统使用方法,奔驰C200L灯光使用方法,C200L灯光开关图解说明
  5. 经验分享-161分过N1的学习备考经验-送给准备12月考试的你
  6. ddr2是几代内存_DDR2内存简介及技术介绍:
  7. 判断一个字符串是否是一个有效的罗马数字
  8. 计算机怎么程序记事本,如何使用计算机的记事本
  9. el-upload 文件上传一次再次上传无反应
  10. c语言从字符串逐个输出汉字