• 首发自公众号:RAIS

我们已经训练过几个神经网络了,识别手写数字,房价预测或者是区分猫和狗,那随之而来就有一个问题,这些训练出的网络怎么用,每个问题我都需要重新去训练网络吗?因为程序员都不太喜欢做重复的事情,因此答案肯定是已经有轮子了。

我们先来介绍一个数据集,ImageNet。这就不得不提一个大名鼎鼎的华裔 AI 科学家李飞飞。

2005 年左右,李飞飞结束了他的博士生涯,开始了他的学术研究不就她就意识到了一个问题,在此之前,人们都尽可能优化算法,认为无论数据如何,只要算法够好,就能做出更好的决策,李飞飞意识到了这个问题的局限性,恰巧她还是一个行动派,她要做出一个无比庞大的数据集,尽可能描述世界上一切物体的数据集,下载图片,给没一张图片做标注,简单而无聊,当然后来这项工作放到了亚马逊的众包平台上,全世界无数的人参与了这个伟大的项目,到此刻为止,已经有 14,197,122 张图片(一千四百万张),21841 个分类。在这个发展的过程中,人们也发现了这个数据集带来的成功远比预想的要多,甚至现在被认为最有前景的深度卷积神经网络的提出也与 ImageNet 不无关系。我忘记了谁这么说过:“就单单这一个数据集,就可以让李飞飞数据科学这个领域拥有一席之地”。暂且不说这么说是否准确,但这个数据集仍然在创造新的突破。(我曾经在台下听过李飞飞一次演讲,现在想想还觉得甚是激动,她真的充满热情)。

基于这个数据集,我们是不是可以训练出一些网络,一般情况下,大家就不用耗时再去训练网络了呢?答案是肯定的,并且在 Keras 就有个一些这样的模型,还是内置的,Keras 就是这么懂你,那就不用客气了,我们拿来用就好了,谢谢啦!

特征提取

我们之前用到的卷积神经网络都是分成了两部分,第一部分是由池化层和卷积层组成的卷积积,第二部分是由分类器,特征提取的含义就是第一部分不变,改变第二部分。

为什么可以这么做?我们之前解释过神经网络的运行原理,跟人脑的认识过程非常类似,还记得吗?我们还是看一看原来的图吧。

我们可以看出来,网络识别图像是有层次结构的,比如一开始的网络层是用来识别图像或者拼装线条的,这是通用且类似的,因此我们可以复用。而后面的分类器往往是根据具体的问题所决定的,比如识别猫或狗的眼睛就与识别桌子腿是不一样的,因此有越靠前越具有通用性的特点。Keras 中很多的内置模型都可以直接下载,如果你没有下载在使用的时候会自动下载:

https://github.com/fchollet/deep-learning-models/releases

我们举一个例子,用 VGG16 去识别猫或狗,这次的解释都比较简单且都是以前说明过的,因此放在代码注释中:

#!/usr/bin/env python3
​
import os
import time
​
import matplotlib.pyplot as plt
import numpy as np
from keras import layers
from keras import models
from keras import optimizers
from keras.applications import VGG16
from keras.preprocessing.image import ImageDataGenerator
​
​
def extract_features(directory, sample_count):# 图片转换区间datagen = ImageDataGenerator(rescale=1. / 255)batch_size = 20conv_base = VGG16(weights='imagenet',include_top=False,input_shape=(150, 150, 3))
​conv_base.summary()
​features = np.zeros(shape=(sample_count, 4, 4, 512))labels = np.zeros(shape=(sample_count))# 读出图片,处理成神经网络需要的数据格式,上一篇文章中有介绍generator = datagen.flow_from_directory(directory,target_size=(150, 150),batch_size=batch_size,class_mode='binary')i = 0for inputs_batch, labels_batch in generator:print(i, '/', len(generator))# 提取特征features_batch = conv_base.predict(inputs_batch)features[i * batch_size: (i + 1) * batch_size] = features_batchlabels[i * batch_size: (i + 1) * batch_size] = labels_batchi += 1if i * batch_size >= sample_count:break
​# 特征和标签return features, labels
​
​
def cat():base_dir = '/Users/renyuzhuo/Desktop/cat/dogs-vs-cats-small'train_dir = os.path.join(base_dir, 'train')validation_dir = os.path.join(base_dir, 'validation')
​# 提取出的特征train_features, train_labels = extract_features(train_dir, 2000)validation_features, validation_labels = extract_features(validation_dir, 1000)
​# 对特征进行变形展平train_features = np.reshape(train_features, (2000, 4 * 4 * 512))validation_features = np.reshape(validation_features, (1000, 4 * 4 * 512))
​# 定义密集连接分类器model = models.Sequential()model.add(layers.Dense(256, activation='relu', input_dim=4 * 4 * 512))model.add(layers.Dropout(0.5))model.add(layers.Dense(1, activation='sigmoid'))
​# 对模型进行配置model.compile(optimizer=optimizers.RMSprop(lr=2e-5),loss='binary_crossentropy',metrics=['acc'])
​# 对模型进行训练history = model.fit(train_features, train_labels,epochs=30,batch_size=20,validation_data=(validation_features, validation_labels))
​# 画图acc = history.history['acc']val_acc = history.history['val_acc']loss = history.history['loss']val_loss = history.history['val_loss']epochs = range(1, len(acc) + 1)plt.plot(epochs, acc, 'bo', label='Training acc')plt.plot(epochs, val_acc, 'b', label='Validation acc')plt.title('Training and validation accuracy')plt.legend()plt.show()plt.figure()plt.plot(epochs, loss, 'bo', label='Training loss')plt.plot(epochs, val_loss, 'b', label='Validation loss')plt.title('Training and validation loss')plt.legend()plt.show()
​
​
if __name__ == "__main__":time_start = time.time()cat()time_end = time.time()print('Time Used: ', time_end - time_start)​

有点巧合的是这里居然看不到太多的过拟合的痕迹,其实也是有可能会有过拟合的隐患的,那样就需要进行数据增强,与以前是一样的,只不过这里的区别就是用到了内置模型,模型的参数需要冻结,我们是不希望对已经训练好的模型进行更改的,具体关键代码写法如下:

conv_base.trainable = False
​
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))

以上就是模型复用的一种方法,我们对模型都是原封不动的拿来用,我们下一篇文章将介绍另外一种方法,对模型进行微调。

  • 首发自公众号:RAIS

AI:拿来主义——预训练网络(一)相关推荐

  1. 使用预训练网络训练的两种方式:Keras Applications、TensorFlow Hub

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) tensorflow 2.0 画出model网络模型的拓扑图 ...

  2. keras从入门到放弃(十六)内置预训练网络VGG

    什么是预训练网络 一个常用.高效的在小图像数据集上深度学习的方法就是利用预训练网络.一个预训练网络只是简单的储存了之前在大的数据集训练的结果,通常是大的图像分类任务.如果原始的数据集已经足够大,足够一 ...

  3. 预训练网络的特征提取方法(VGG16)

    预训练网络的特征提取方法 1.知识点 #想要将深度学习应用于小型图像数据集,一种常用且非常高效的方法是使用预训练网络 #预训练网络:一个保存好的网络,之前已经在大型数据集(通常是大规模图像分类任务)上 ...

  4. Pytorch:图像语义分割-FCN, U-Net, SegNet, 预训练网络

    Pytorch: 图像语义分割-FCN, U-Net, SegNet, 预训练网络 Copyright: Jingmin Wei, Pattern Recognition and Intelligen ...

  5. pytorch深度学习实战——预训练网络

    来源:<Pytorch深度学习实战>,2.1,一个识别图像主体的预训练网络 from torchvision import models from torchvision import t ...

  6. python 动物分类_《python深度学习》笔记---5.3-1、猫狗分类(使用预训练网络)

    <python深度学习>笔记---5.3-1.猫狗分类(使用预训练网络) 一.总结 一句话总结: [小型图像数据集]:想要将深度学习应用于小型图像数据集,一种常用且非常高效的方法是使用预训 ...

  7. 预训练网络的模型微调方法

    是什么 神经网络需要数据来训练,从数据中获得信息,进而转化成相应的权重.这些权重能够被提取出来,迁移到其他的神经网络中. 迁移学习:通过使用之前在大数据集上经过训练的预训练模型,我们可以直接使用相应的 ...

  8. 利用预训练网络打造自己的分类网络

    卷积神经网络的训练是耗时的,很多场合不可能每次都从随机初始化参数开始训练网络. 1.训练 pytorch中自带几种常用的深度学习网络预训练模型,如VGG.ResNet等.往往为了加快学习的进度,在训练 ...

  9. 经典论文解读 — 端到端的VL预训练网络SOHO

    来源:投稿 作者:摩卡 编辑:学姐 论文标题: Seeing Out of tHe bOx :End-to-End Pre-training for Visual-Language Represent ...

最新文章

  1. C#(WPF)去除事件中注册的事件处理方法!
  2. Pocket Hacking: NetHunter实战指南
  3. 字少事大|两张表格教你快速选择适合的MCU进行物联网开发
  4. vim多窗口使用技巧
  5. 通俗易懂,常用线程池执行的-流程图
  6. gulp webpack整合
  7. 【Python】Numpy包的安装使用
  8. 信息学奥赛一本通 1226:装箱问题 | OpenJudge NOI 4.6 19:装箱问题
  9. OpenCV--实现图像滑动窗口截取子图操作
  10. 查看静态库(.a文件)内容
  11. asp连接mysql数据库增删查_【ASP】ASP对Access数据库的连接、增删改查及ASP的基本语法...
  12. crm高速开发之OrganizationService
  13. 【前端成长-读书群】
  14. Python-snap7 安装和测试
  15. matlab 使用uci数据集,如何使用UCI数据集
  16. 恒生O32系统的前世今生
  17. 计算机网络-路由交换技术
  18. oracle查询平均每月数据,oracle 按每天,每周,每月,每季度,每年查询统计数据
  19. c语言程序如何首行缩进,什么叫代码缩进
  20. 【数据结构】无向图的遍历(广度搜索和深度搜索)

热门文章

  1. 2023年五一法定节假日是几天?如何提醒自己放假时间?
  2. js获取元素高度比较
  3. Nokia E52 使用技巧
  4. html中 readonly和disabled的区别
  5. Pytorch之经典神经网络CNN(三) —— AlexNet(CIFAR-10) (LRN)
  6. linux u盘无损分区,科技常识:linux如何无损调整分区大小
  7. 利用反向代理对IP地址的文根修改
  8. Spring Cloud 第六天
  9. Mycat 读写分离实战
  10. (转)Android状态栏微技巧,带你真正理解沉浸式模式