残差神经网络Resnet(MNIST数据集tensorflow实现)
简述:
残差神经网络(ResNet)主要是用于搭建深度的网络结构模型
(一)优势:
与传统的神经网络相比残差神经网络具有更好的深度网络构建能力,能避免因为网络层次过深而造成的梯度弥散和梯度爆炸。
(二)残差模块:
通过在一个浅层网络基础上叠加y=x的层,可以让网络随深度增加而不退化。
残差学习的函数是F(x) = H(x) - x,这里如果F(x) =0,那么就是恒等映射。
resnet"short connections” 的在connection是在恒等映射的情况。
输入和输出的差别H(x)-x就是残差。
这个是通过shortcut connection 实现,通过shortcut将这个block的输入和输出进行一个
element-wise的加叠,这个简单的加法并不会给网络带来多大额外计算量,同时可以大大
增加模型的训练速度,提高训练效果,并且当模型的层数加深时,这个简单的结构能很好的解决退化问题。
实现公式:
a[l+2] 加上了 a[l]的残差块,即:残差网络中,直接将a[l]向后拷贝到神经网络的更深层,在ReLU非线性激活前面
加上a[l],a[l]的信息直接达到网络深层。使用残差块能够训练更深层的网络,构建一个ResNet网络就是通过将很多
这样的残差块堆积在一起,形成一个深度神经网络。
对于大型的网络,无论把残差块添加到神经网络的中间还是末端,都不会影响网络的表现。
可以提升网络效率。
残差神经网络与传统神经网络的模型区别:
可以看到普通直连的卷积神经网络最大区别在于,ResNet有一个shortcut结构,而传统卷积或多或少存在信息丢失的问题
。
在实际中关于成本考虑,既将两个3*3的卷积替换成1*1+3*3+1*1,如下图,新结构
中的中间3*3的卷积层在一个1*1降维,另一个1*1做还原,既保持精度又减少计算量。
如图结构:
tensorflow代码实现
定义残差
class ResNet:def __init__(self,X_input,kernel_size,in_filter,out_filters,stride):self.X=X_inputself.X_sortcut=X_inputself.stride=stridef1,f2,f3=out_filtersself.conv_1=tf.Variable(tf.truncated_normal(shape=[1,1,in_filter,f1],stddev=0.1,mean=0,dtype=tf.float32))self.conv_b1=tf.Variable(tf.zeros([f1]))self.conv_2=tf.Variable(tf.truncated_normal(shape=[kernel_size,kernel_size,f1,f2],stddev=0.1,mean=0,dtype=tf.float32))self.conv_b2=tf.Variable(tf.zeros([f2]))self.conv_3=tf.Variable(tf.truncated_normal(shape=[1,1,f2,f3],stddev=0.1,mean=0,dtype=tf.float32))self.conv_b3=tf.Variable(tf.zeros([f3]))self.b_conv_fin = tf.Variable(tf.zeros([f3]))def ResNetChoice(self,choice):if(choice):# fristy = tf.nn.relu(tf.nn.conv2d(self.X,self.conv_1,strides=[1,1,1,1],padding="SAME")+self.conv_b1)#secondy = tf.nn.relu(tf.nn.conv2d(y,self.conv_2,strides=[1,1,1,1],padding="SAME")+self.conv_b2)#thirdy = tf.nn.relu(tf.nn.conv2d(y,self.conv_3,strides=[1,1,1,1],padding="SAME")+self.conv_b3)#final steapadd=tf.add(y,self.X_sortcut)add_result = tf.nn.relu(add + self.b_conv_fin)return add_resultelse:# fristy = tf.nn.relu(tf.nn.conv2d(self.X, self.conv_1, strides=[1, self.stride,self.stride, 1], padding="SAME") + self.conv_b1)# secondy = tf.nn.relu(tf.nn.conv2d(y, self.conv_2, strides=[1, 1, 1, 1], padding="SAME") + self.conv_b2)# thirdy = tf.nn.relu(tf.nn.conv2d(y, self.conv_3, strides=[1, 1, 1, 1], padding="SAME") + self.conv_b3)# final steapadd = tf.nn.conv2d(self.X_sortcut,self.conv_3,strides=[1,1,1,1],padding="SAME")add = tf.add(y,add)add_result = tf.nn.relu(add+self.b_conv_fin)return add_result
网络结构
class Net:def __init__(self):# 输入x(数据输入为图片的格式)self.x=tf.placeholder(tf.float32,[None,28,28,1])# 输入y(标签)self.y=tf.placeholder(tf.float32,[None,10])# ----------------------------卷积初始化--------------------------------#卷积第一层self.conv1_w=tf.Variable(tf.random_normal([3,3,1,16],dtype=tf.float32,stddev=0.1))# 卷积第一层偏移self.convb1=tf.Variable(tf.zeros([16]))#--------------------------------全连接初始化---------------------------------------# 第一层全连接wself.W = tf.Variable(tf.random_normal([7*7*32,128],dtype=tf.float32,stddev=0.1))# 第一层全连接bself.B = tf.Variable(tf.zeros([128]))# 第二层全连接wself.W1 = tf.Variable(tf.random_normal([128,10],dtype=tf.float32,stddev=0.1))# 第二层全连接bself.B1 = tf.Variable(tf.zeros([10]))def forward(self):# ------------------------------卷积层-----------------------------------# 卷积第一层实现self.conv1=tf.nn.relu(tf.nn.conv2d(self.x,self.conv1_w,strides=[1,1,1,1],padding="SAME")+self.convb1)# 第一层池化self.pool1=tf.nn.max_pool(self.conv1,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")# 残差部分Rt = ResNet(self.pool1, 3, 16, [8, 8, 16], 1)x = Rt.ResNetChoice(True)Rt1 = ResNet(x, 3, 16, [8, 8, 16], 1)x = Rt1.ResNetChoice(True)Rt2 = ResNet(x, 3, 16, [8, 8, 16], 1)x = Rt2.ResNetChoice(True)Rt3 = ResNet(x, 3, 16, [8, 8, 16], 1)x = Rt3.ResNetChoice(True)URt =ResNet(x,3,16,[16,16,32],1)x = URt.ResNetChoice(False)self.pool2=tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")# 均值tf.nn.avg_pool# 归一化层tf.nn.batch_normalization# 形状处理self.flat = tf.reshape(self.pool2,[-1,7*7*32])# ---------------------------------全链接层-------------------------------------------self.y0 = tf.nn.relu(tf.matmul(self.flat,self.W)+self.B)self.yo = tf.nn.softmax(tf.matmul(self.y0,self.W1)+self.B1)def backword(self):self.cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.yo, labels=self.y))self.optimizer = tf.train.AdamOptimizer(0.003).minimize(self.cross_entropy)
训练:
if __name__ == '__main__':net = Net()net.forward()net.backword()init=tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init)for i in range(100000):xs,ys = mnist.train.next_batch(128)cg=xs.reshape([128,28,28,1])rs,loss,_=sess.run([net.acc,net.cross_entropy,net.optimizer],feed_dict={net.x:cg,net.y:ys})
残差神经网络在图像识别领域很有优势
残差神经网络Resnet(MNIST数据集tensorflow实现)相关推荐
- 深度学习——残差神经网络ResNet在分别在Keras和tensorflow框架下的应用案例
原文链接:https://blog.csdn.net/loveliuzz/article/details/79117397 一.残差神经网络--ResNet的综述 深度学习网络的深度对最后的分类和识别 ...
- 【Pytorch(七)】基于 PyTorch 实现残差神经网络 ResNet
基于 PyTorch 实现残差神经网络 ResNet 文章目录 基于 PyTorch 实现残差神经网络 ResNet 0. 概述 1. 数据集介绍 1.1 数据集准备 1.2 分析分类难度:CIFAR ...
- 二隐层的神经网络实现MNIST数据集分类
二隐层的神经网络实现MNIST数据集分类 传统的人工神经网络包含三部分,输入层.隐藏层和输出层.对于一个神经网络模型的确定需要考虑以下几个方面: 隐藏层的层数以及各层的神经元数量 各层激活函数的选择 ...
- [转载] 卷积神经网络做mnist数据集识别
参考链接: 卷积神经网络在mnist数据集上的应用 Python TensorFlow是一个非常强大的用来做大规模数值计算的库.其所擅长的任务之一就是实现以及训练深度神经网络. 在本教程中,我们将学到 ...
- 深度学习基础: BP神经网络训练MNIST数据集
BP 神经网络训练MNIST数据集 不用任何深度学习框架,一起写一个神经网络训练MNIST数据集 本文试图让您通过手写一个简单的demo来讨论 1. 导包 import numpy as np imp ...
- 基于Python实现的卷积神经网络分类MNIST数据集
卷积神经网络分类MNIST数据集 目录 人工智能第七次实验报告 1 卷积神经网络分类MNIST数据集 1 一 .问题背景 1 1.1 卷积和卷积核 1 1.2 卷积神经网络简介 2 1.3 卷积神经网 ...
- 基于Python实现的神经网络分类MNIST数据集
神经网络分类MNIST数据集 目录 神经网络分类MNIST数据集 1 一 .问题背景 1 1.1 神经网络简介 1 前馈神经网络模型: 1 1.2 MINST 数据说明 4 1.3 TensorFlo ...
- Python实现bp神经网络识别MNIST数据集
title: "Python实现bp神经网络识别MNIST数据集" date: 2018-06-18T14:01:49+08:00 tags: [""] cat ...
- 神经网络——实现MNIST数据集的手写数字识别
由于官网下载手写数字的数据集较慢,因此提供便捷下载地址如下 手写数字的数据集MNIST下载:https://download.csdn.net/download/gaoyu1253401563/108 ...
最新文章
- python编写用户输入的是q么代码_Python课 #01号作业
- mysql 集群 备份_mysql cluster集群备份还原
- 缓存、内存、硬盘、虚拟硬盘
- 3.15曝光“山寨”杀毒软件“杀毒三宗罪”
- C语言按行和列求平均成绩代码(指针,二维数组)
- hadoop 单机伪分布式安装步骤
- docker简介与搭建
- Java多线程实现多客户端的通信
- 九日登望仙台呈刘明府 [唐] 崔曙
- java mysql教程基于_基于JAVA和MYSQL数据库实现的学生信息管理系统
- vue v-for循环的用法
- Maven setting文件配置错误:Non-parseable settings..in comment after two dashes (--) next character must be
- React.js+i18next实现国际化
- WiFi功耗管理(一)(概述)
- 无法访问eclipse官网?镜像源可以帮你
- 上海大学计算机专硕与学硕,计算机学院2017届留沪研究生学习成绩综合评定工作通知...
- 简单句、并列句、复合句、got his wish、 owe you
- 国际移动设备识别码IMEI
- 用opencv及cuda编译好的dakrnet训练yolo4
- python黑科技-五彩斑斓的黑
热门文章
- 如何在新时代下的结对编程中将代码玩出花来
- Intel Optane(tm) Memory Pinning ,无法加载 DLL ´iaStorAfsServiceApi.dll´ : 找不到指定的模块。
- jQuery each( ) 遍历 与 $.each( ) 遍历【一篇文章轻松拿下】
- 油溶性CdTe/CdSe/ZnS量子点/CdSe量子点化学修饰在钛酸钠纳米管/CdTe量子点修饰的ZnO纳米棒/GaN发光二极管
- android闹钟 推迟功能,闹钟延迟76秒才播放音乐
- 计算机系学生应该怎么学java?
- Sound Event Detection: A Tutorial
- VC++判断进程是否以管理员权限运行(附源码)
- 5轴雕刻机同款运动系统。 USB运动控制 (五轴雕刻机系统)全部开源 不保留任何关键技术,PCB可直接生产
- Valheim英灵神殿服务器端口2456-2457-2458开启