全局平均池化能否完美代替全连接?

参考链接:https://www.cnblogs.com/hutao722/p/10008581.html

一.什么是全局平均池化?

   全局平均池化(GAP)通过池化操作把多维矩阵转化为特征向量,以顶替全连接(FC)。
优点
   ① 减少了FC中的大量参数,使得模型更加健壮,抗过拟合,当然,可能也会欠拟合。
   ② GAP在特征图与最终的分类间转换更加自然。
   GAP工作原理如下图所示:
   假设卷积层的最后输出是h × w × d 的三维特征图,具体大小为6 × 6 × 3,经过GAP转换后,变成了大小为 1 × 1 × 3 的输出值,也就是每一层 h × w 会被平均化成一个值。

二.GAP和FC的对比验证结果

  1 GAP在Keras中的定义

x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) #卷积层最后一层
x = layers.GlobalAveragePooling2D()(x) #GAP层
prediction = Dense(10, activation='softmax')(x) #输出层

再看看GAP的代码具体实现:

@tf_export('keras.layers.GlobalAveragePooling2D','keras.layers.GlobalAvgPool2D')
class GlobalAveragePooling2D(GlobalPooling2D):"""Global average pooling operation for spatial data.Arguments:data_format: A string,one of `channels_last` (default) or `channels_first`.The ordering of the dimensions in the inputs.`channels_last` corresponds to inputs with shape`(batch, height, width, channels)` while `channels_first`corresponds to inputs with shape`(batch, channels, height, width)`.It defaults to the `image_data_format` value found in yourKeras config file at `~/.keras/keras.json`.If you never set it, then it will be "channels_last".Input shape:- If `data_format='channels_last'`:4D tensor with shape:`(batch_size, rows, cols, channels)`- If `data_format='channels_first'`:4D tensor with shape:`(batch_size, channels, rows, cols)`Output shape:2D tensor with shape:`(batch_size, channels)`"""def call(self, inputs):if self.data_format == 'channels_last':return backend.mean(inputs, axis=[1, 2])else:return backend.mean(inputs, axis=[2, 3])

  实现很简单,对宽度和高度两个维度的特征数据进行平均化求值。如果是NHWC结构(数量、宽度、高度、通道数),则axis=[1, 2];反之如果是CNHW,则axis=[2, 3]。

  2. GAP VS GMP VS FC

  在验证GAP技术可行性前,我们需要准备训练和测试数据集。我在牛津大学网站上找到了17种不同花类的数据集,地址为:http://www.robots.ox.ac.uk/~vgg/data/flowers/17/index.html 。该数据集每种花有80张图片,共计1360张图片,我对花进行了分类处理,抽取了部分数据作为测试数据,这样最终训练和测试数据的数量比为7:1。
  我将数据集上传到我的百度网盘: https://pan.baidu.com/s/1YDA_VOBlJSQEijcCoGC60w ,大家可以下载使用。
  在Keras经典模型中,若支持迁移学习,不但有GAP,还有GMP,而默认是自己组建FC层,一个典型的实现为:

if include_top:# Classification blockx = layers.Flatten(name='flatten')(x)x = layers.Dense(4096, activation='relu', name='fc1')(x)x = layers.Dense(4096, activation='relu', name='fc2')(x)x = layers.Dense(classes, activation='softmax', name='predictions')(x)else:if pooling == 'avg':x = layers.GlobalAveragePooling2D()(x)elif pooling == 'max':x = layers.GlobalMaxPooling2D()(x)

  本文将在同一数据集条件下,比较GAP、GMP和FC层的优劣,选取测试模型为VGG19和InceptionV3两种模型的迁移学习版本。

  先看看在VGG19模型下,GAP、GMP和FC层在各自迭代50次后,验证准确度和损失度的比对。代码如下:

import keras
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.applications.vgg19 import VGG19from keras.layers import Dense, Flatten
from matplotlib import pyplot as plt
import numpy as np# 为保证公平起见,使用相同的随机种子
np.random.seed(7)
batch_size = 32
# 迭代50次
epochs = 50
# 依照模型规定,图片大小被设定为224
IMAGE_SIZE = 224
# 17种花的分类
NUM_CLASSES = 17
TRAIN_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/train'
TEST_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/test'
FLOWER_CLASSES = ['Bluebell', 'ButterCup', 'ColtsFoot', 'Cowslip', 'Crocus', 'Daffodil', 'Daisy','Dandelion', 'Fritillary', 'Iris', 'LilyValley', 'Pansy', 'Snowdrop', 'Sunflower','Tigerlily', 'tulip', 'WindFlower']def model(mode='fc'):if mode == 'fc':# FC层设定为含有512个参数的隐藏层base_model = VGG19(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), include_top=False, pooling='none')x = base_model.outputx = Flatten()(x)x = Dense(512, activation='relu')(x)prediction = Dense(NUM_CLASSES, activation='softmax')(x)elif mode == 'avg':# GAP层通过指定pooling='avg'来设定base_model = VGG19(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), include_top=False, pooling='avg')x = base_model.outputprediction = Dense(NUM_CLASSES, activation='softmax')(x)else:# GMP层通过指定pooling='max'来设定base_model = VGG19(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), include_top=False, pooling='max')x = base_model.outputprediction = Dense(NUM_CLASSES, activation='softmax')(x)model = Model(input=base_model.input, output=prediction)model.summary()opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6)model.compile(loss='categorical_crossentropy',optimizer=opt,metrics=['accuracy'])# 使用数据增强train_datagen = ImageDataGenerator()train_generator = train_datagen.flow_from_directory(directory=TRAIN_PATH,target_size=(IMAGE_SIZE, IMAGE_SIZE),classes=FLOWER_CLASSES)test_datagen = ImageDataGenerator()test_generator = test_datagen.flow_from_directory(directory=TEST_PATH,target_size=(IMAGE_SIZE, IMAGE_SIZE),classes=FLOWER_CLASSES)# 运行模型history = model.fit_generator(train_generator, epochs=epochs, validation_data=test_generator)return historyfc_history = model('fc')
avg_history = model('avg')
max_history = model('max')# 比较多种模型的精确度
plt.plot(fc_history.history['val_acc'])
plt.plot(avg_history.history['val_acc'])
plt.plot(max_history.history['val_acc'])
plt.title('Model accuracy')
plt.ylabel('Validation Accuracy')
plt.xlabel('Epoch')
plt.legend(['FC', 'AVG', 'MAX'], loc='lower right')
plt.grid(True)
plt.show()# 比较多种模型的损失率
plt.plot(fc_history.history['val_loss'])
plt.plot(avg_history.history['val_loss'])
plt.plot(max_history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['FC', 'AVG', 'MAX'], loc='upper right')
plt.grid(True)
plt.show()

  各自运行50次迭代后,我们看看准确度比较:

  再看看模型的损失变化:

  可以看到,GMP(MAX)完全GG。
  FC在1到40轮表现和GAP无太大差异,后期出现过拟合现象。唯一优势是前期学习速度快(考虑跟全连接的结构密切相关,更复杂,学的快),
  GAP并没有出现过拟合的现象,但是准确度只接近70%,可能是model的原因。
  我们再转向另一个模型InceptionV3,代码稍加改动如下:
下面展示一些 。

import keras
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.applications.inception_v3 import InceptionV3, preprocess_input
from keras.layers import Dense, Flatten
from matplotlib import pyplot as plt
import numpy as np# 为保证公平起见,使用相同的随机种子
np.random.seed(7)
batch_size = 32
# 迭代50次
epochs = 50
# 依照模型规定,图片大小被设定为224
IMAGE_SIZE = 224
# 17种花的分类
NUM_CLASSES = 17
TRAIN_PATH = '/home/hutao/Documents/tensorflow/images/17flowerclasses/train'
TEST_PATH = '/home/hutao/Documents/tensorflow/images/17flowerclasses/test'
FLOWER_CLASSES = ['Bluebell', 'ButterCup', 'ColtsFoot', 'Cowslip', 'Crocus', 'Daffodil', 'Daisy','Dandelion', 'Fritillary', 'Iris', 'LilyValley', 'Pansy', 'Snowdrop', 'Sunflower','Tigerlily', 'tulip', 'WindFlower']def model(mode='fc'):if mode == 'fc':# FC层设定为含有512个参数的隐藏层base_model = InceptionV3(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), include_top=False, pooling='none')x = base_model.outputx = Flatten()(x)x = Dense(512, activation='relu')(x)prediction = Dense(NUM_CLASSES, activation='softmax')(x)elif mode == 'avg':# GAP层通过指定pooling='avg'来设定base_model = InceptionV3(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), include_top=False, pooling='avg')x = base_model.outputprediction = Dense(NUM_CLASSES, activation='softmax')(x)else:# GMP层通过指定pooling='max'来设定base_model = InceptionV3(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), include_top=False, pooling='max')x = base_model.outputprediction = Dense(NUM_CLASSES, activation='softmax')(x)model = Model(input=base_model.input, output=prediction)model.summary()opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6)model.compile(loss='categorical_crossentropy',optimizer=opt,metrics=['accuracy'])# 使用数据增强train_datagen = ImageDataGenerator()train_generator = train_datagen.flow_from_directory(directory=TRAIN_PATH,target_size=(IMAGE_SIZE, IMAGE_SIZE),classes=FLOWER_CLASSES)test_datagen = ImageDataGenerator()test_generator = test_datagen.flow_from_directory(directory=TEST_PATH,target_size=(IMAGE_SIZE, IMAGE_SIZE),classes=FLOWER_CLASSES)# 运行模型history = model.fit_generator(train_generator, epochs=epochs, validation_data=test_generator)return historyfc_history = model('fc')
avg_history = model('avg')
max_history = model('max')# 比较多种模型的精确度
plt.plot(fc_history.history['val_acc'])
plt.plot(avg_history.history['val_acc'])
plt.plot(max_history.history['val_acc'])
plt.title('Model accuracy')
plt.ylabel('Validation Accuracy')
plt.xlabel('Epoch')
plt.legend(['FC', 'AVG', 'MAX'], loc='lower right')
plt.grid(True)
plt.show()# 比较多种模型的损失率
plt.plot(fc_history.history['val_loss'])
plt.plot(avg_history.history['val_loss'])
plt.plot(max_history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['FC', 'AVG', 'MAX'], loc='upper right')
plt.grid(True)
plt.show()

  先进行准确率的比较:

  再看看损失的变化过程:

  很明显,在InceptionV3模型下,FC、GAP和GMP都表现很好,但可以看出GAP的表现依旧最好,其准确度普遍在90%以上,而另两种的准确度在80~90%之间。

三.结论

  从以上结果来看,GAP是优于FC的,但也不能证明在所有的网络上都能够使用,比如在结构和深度都比较小的网络上,GAP代替FC作者猜想可能会导致欠拟合的现象,这需要更多的验证和尝试对比。

全局平均池化能否完美代替全连接?相关推荐

  1. Pytorch之经典神经网络CNN(七) —— GoogLeNet(InceptionV1)(Bottleneck)(全局平均池化GAP)(1*1卷积)(多尺度)(flower花卉数据集)

    2014年 Google提出的 是和VGG同年出现的,在ILSVRC(ImageNet) 2014中获得冠军,vgg屈居第二 GoogLeNet也称Inception V1.之所以叫GoogLeNet ...

  2. Lesson 16.1016.1116.1216.13 卷积层的参数量计算,1x1卷积核分组卷积与深度可分离卷积全连接层 nn.Sequential全局平均池化,NiN网络复现

    二 架构对参数量/计算量的影响 在自建架构的时候,除了模型效果之外,我们还需要关注模型整体的计算效率.深度学习模型天生就需要大量数据进行训练,因此每次训练中的参数量和计算量就格外关键,因此在设计卷积网 ...

  3. GAP(全局平均池化层)操作

    转载的文章链接: 为什么使用全局平均池化层? 关于 global average pooling https://blog.csdn.net/qq_23304241/article/details/8 ...

  4. Global Average Pooling全局平均池化的一点理解

    Traditional Pooling Methods 要想真正的理解Global Average Pooling,首先要了解深度网络中常见的pooling方式,以及全连接层. 众所周知CNN网络中常 ...

  5. 全局平均池化(Golbal Average Pooling)与Concatenate层

    转载:全剧平均池化 出处:Lin M, Chen Q, Yan S. Network in network[J]. arXiv preprint arXiv:1312.4400, 2013. 查看全文 ...

  6. 全局平均池化(Global Average Pooling)

    出处:Lin M, Chen Q, Yan S. Network in network[J]. arXiv preprint arXiv:1312.4400, 2013. 定义:将特征图所有像素值相加 ...

  7. CNN(卷积层convolutional layer,激励层activating layer,池化层pooling,全连接层fully connected)

    CNN产生的原因:当使用全连接的神经网络时,因为相邻两层之间的神经元都是有边相连的,当输入层的特征纬度非常高时(譬如图片),全连接网络需要被训练的参数就会非常多(参数太多,训练缓慢),CNN可以通过训 ...

  8. 池化层在全连接层之间吗,了解最大池化层之后的全连接层的尺寸

    In the diagram (architecture) below, how was the (fully-connected) dense layer of 4096 units derived ...

  9. cnn中关于平均池化和最大池化的理解

    cnn中关于平均池化和最大池化的理解 接触到pooling主要是在用于图像处理的卷积神经网络中,但随着深层神经网络的发展,pooling相关技术在其他领域,其他结构的神经网络中也越来越受关注. 一个典 ...

最新文章

  1. 【设计模式】抽象工厂模式 ( 简介 | 适用场景 | 优缺点 | 产品等级结构和产品族 | 代码示例 )
  2. C语言realtime stats实时统计(附完整源码)
  3. python相关背景及语言特点
  4. Python 常用函数 configparser模块
  5. 面试官:. NET5源码里用到了哪些设计模式?懵!
  6. angularJS constant和value
  7. php 常用编译参数,php编译参数,不用怕!!
  8. dict后缀_学习词根dict 成片记单词
  9. git解决冲突 删除本地_Git冲突:git pull时和本地改动冲突
  10. 数据挖掘十大经典算法(9) 朴素贝叶斯分类器 Naive Bayes
  11. MAMP Pro for Mac(PHP/MySQL开发环境)
  12. 网络安全职业_如何开始网络安全职业
  13. python获取outlook邮件内容_Python3读取Outlook邮件并写入MySQL
  14. python文件定位函数_C语言中文件定位函数总结
  15. cholesky求逆
  16. 赠书:深入理解 Spring Cloud 与实战
  17. docker 搭建本地 coredns 服务器
  18. java编程细节总结(一):等于号的作用
  19. /system32/ntoskrnl.exe丢失无法启动
  20. 如何在官网下载tomcat

热门文章

  1. 生物药公司“普米斯”获1.8亿元融资,华金资本、珠海高科创投、弘晖资本联合投资...
  2. python-selenium模块爬取动态网址实例---------【下载漫画码上面的漫画】
  3. web项目启动流程分析
  4. chardet判断中文编码
  5. 对scipy.cluster.vq中whiten()函数总结
  6. 详解离线语音和在线语音的区别
  7. 【梳理】离散数学 第2版 第9章 代数系统 9.1 二元运算及其性质
  8. iphone原彩显示对眼睛好吗_iPhone XS采用OLED屏,看久了觉得眼睛难受怎么办?
  9. 基于PyQt5和Pywinauto自动化测试客户端
  10. np.nditer函数