文章目录

  • 载入数据
  • 数据增强(Data Augmentation)
  • 载入模型,并使用迁移学习修改模型
    • 迁移学习(transfer learning)

载入数据

  • 从文件夹中载入图片
BATCH_SIZE = 32
IMG_SIZE = (160, 160)
directory = "dataset/"
train_dataset = image_dataset_from_directory(directory,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE,validation_split=0.2,subset='training',seed=42)
validation_dataset = image_dataset_from_directory(directory,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE,validation_split=0.2,subset='validation',seed=42)

  • 打印几张图片
class_names = train_dataset.class_namesplt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):for i in range(9):ax = plt.subplot(3, 3, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

数据增强(Data Augmentation)

  • 仅仅简单的旋转、平移
# UNQ_C1
# GRADED FUNCTION: data_augmenter
def data_augmenter():'''Create a Sequential model composed of 2 layersReturns:tf.keras.Sequential'''### START CODE HEREdata_augmentation = tf.keras.Sequential()data_augmentation.add(RandomFlip('horizontal'))data_augmentation.add(RandomRotation(0.2))### END CODE HEREreturn data_augmentation

载入模型,并使用迁移学习修改模型

  • 载入 Mobile_v2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,include_top=True,weights='imagenet')

迁移学习(transfer learning)

首先,我们要了解迁移学习的主要原理,它的根据人们使用神经网络的经验,进行总结之后发现:

  • 神经网络的前几层学习到的是较为低级的特征,例如图案的边,或者其他简单线条组成的图案,这些是很多图片都公有的特征
  • 越往后,神经网络学习到的特征就越高级,例如,学习猫的图片,最后几层的神经网络可能就是 猫耳图案的高级(数据集特有的特征)特征。
  • 基于上述的经验总结,我们可以将一个训练好的模型的浅层神经网络直接拿过来使用,然后更改最后几层的神经网络结果,最后我们只更新最后几层的参数,浅层的参数不变,调低学习率对模型进行微调之后就可以达到 他人训练好的模型为我所有的目的
  • 迁移学习适用于:类似任务(猫和狗之类的任务),算力不足(具有很大的模型但算力不能支持从头训练),还可以借用开源的模型…

# UNQ_C2
# GRADED FUNCTION
def alpaca_model(image_shape=IMG_SIZE, data_augmentation=data_augmenter()):''' Define a tf.keras model for binary classification out of the MobileNetV2 modelArguments:image_shape -- Image width and heightdata_augmentation -- data augmentation functionReturns:Returns:tf.keras.model'''input_shape = image_shape + (3,)### START CODE HEREbase_model = tf.keras.applications.MobileNetV2(input_shape=input_shape,include_top=False, # <== Important!!!!weights='imagenet') # From imageNet# freeze the base model by making it non trainablebase_model.trainable = False# create the input layer (Same as the imageNetv2 input size)inputs = tf.keras.Input(shape=input_shape) # apply data augmentation to the inputsx = data_augmentation(inputs)# data preprocessing using the same weights the model was trained onx = preprocess_input(x) # set training to False to avoid keeping track of statistics in the batch norm layerx = base_model(x, training=False) # add the new Binary classification layers# use global avg pooling to summarize the info in each channelx = tf.keras.layers.GlobalAveragePooling2D()(x) # include dropout with probability of 0.2 to avoid overfittingx = tf.keras.layers.Dropout(0.2)(x)# use a prediction layer with one neuron (as a binary classifier only needs one)outputs = tf.keras.layers.Dense(units=1)(x)### END CODE HEREmodel = tf.keras.Model(inputs, outputs)return model
  • Mobile_v2 的最后几层是用于10中分类任务的,我们设置include_top = False,舍弃他们,然后使用Function Api 添加新的分类器。
  • 设置 trainable = False,使得我们不更新前面的参数

但是我们知道,最后几层的神经网络学习的是高级特征,我们需要对最后几层神经网络进行微调,使得这几层神经网络学习到我们数据集想要的特征,因此我们对最后几层神经网络进行‘解冻’。

model2 = alpaca_model(IMG_SIZE, data_augmentation)
# UNQ_C3
base_model = model2.layers[4]
base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))# Fine-tune from this layer onwards
fine_tune_at = 120### START CODE HERE# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:layer.trainable = False# Define a BinaryCrossentropy loss function. Use from_logits=True
loss_function=tf.keras.losses.BinaryCrossentropy(from_logits=True)
# Define an Adam optimizer with a learning rate of 0.1 * base_learning_rate
optimizer = tf.keras.optimizers.Adam(learning_rate=0.1*base_learning_rate)
# Use accuracy as evaluation metric
metrics=['accuracy']### END CODE HEREmodel2.compile(loss=loss_function,optimizer = optimizer,metrics=metrics)
fine_tune_epochs = 5
total_epochs =  initial_epochs + fine_tune_epochshistory_fine = model2.fit(train_dataset,epochs=total_epochs,initial_epoch=history.epoch[-1],validation_data=validation_dataset)

MobileNet_v2 with transfer learning(修改Mobile_v2 模型)相关推荐

  1. 基于Keras Application和Densenet迁移学习(transfer learning)的乳腺癌图像分类模型(良性、恶性)

    基于Keras Application和Densenet迁移学习(transfer learning)的乳腺癌图像分类模型(良性.恶性) 概论: 美国癌症学会官方期刊发表<2018年全球癌症统计 ...

  2. 【论文笔记09】Differentially Private Hypothesis Transfer Learning 差分隐私迁移学习模型, ECMLPKDD 2018

    目录导引 系列传送 Differentially Private Hypothesis Transfer Learning 1 Abstract 2 Bg & Rw 3 Setting &am ...

  3. Domain adaptation:连接机器学习(Machine Learning)与迁移学习(Transfer Learning)

    domain adaptation(域适配)是一个连接机器学习(machine learning)与迁移学习(transfer learning)的新领域.这一问题的提出在于从原始问题(对应一个 so ...

  4. 深度学习不得不会的迁移学习Transfer Learning

    http://blog.itpub.net/29829936/viewspace-2641919/ 2019-04-18 10:04:53 目录 一.概述 二.什么是迁移学习? 2.1 模型的训练与预 ...

  5. AI入门:Transfer Learning(迁移学习)

    迁移学习是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中 Pokemon Dataset 通过网络上收集宝可梦的图片,制作图像分类数据集.我收集了5种 ...

  6. 迁移学习(transfer learning)与finetune的关系?【finetune只是transfer learning的一种手段】

    目录 1.迁移学习简介 2.为什么要迁移学习? 3.迁移学习的几种方式 1)Transfer Learning: 2)Extract Feature Vector: 3)Fine-tune: 4.三种 ...

  7. 【论文阅读笔记】High Quality Monocular Depth Estimation via Transfer Learning

    文章目录 High Quality Monocular Depth Estimation via Transfer Learning Abstract 1. Introduction 2. Relat ...

  8. 迁移学习(Transfer Learning)-- 概念理解

    迁移学习(Transfer Learning) 迁移学习概述 背景 随着越来越多的机器学习应用场景的出现,而现有表现比较好的监督学习需要大量的标注数据,标注数据是一项枯燥无味且花费巨大的任务,所以迁移 ...

  9. scJoint integrates atlas-scale single-cell RNA-seq and ATAC-seq data with transfer learning

    scJoint integrates atlas-scale single-cell RNA-seq and ATAC-seq data with transfer learning Nature B ...

最新文章

  1. LabVIEW图像分割算法(基础篇—6)
  2. php内加百度熊掌号,百度熊掌号接入网站页面改造详细步骤
  3. 倍福TwinCAT(贝福Beckhoff)常见问题(FAQ)-Switch Case语句是否会自动跳转到下一个
  4. html在不同浏览器器下颜色不同,CSS在不同浏览器下实现颜色渐变效果
  5. MySQL Hex函数使用详解
  6. 移动端click事件延迟300ms到底是怎么回事,该如何解决?
  7. 证明n次根号下n阶乘等价于n/e
  8. linux实验十shell程序设计,实验二Linux Shell编程.doc
  9. java图书馆抢座系统_java毕业设计_springboot框架的图书馆座位预约占座
  10. CAD绘图的规范要点
  11. 广东电信在线人工服务器,202.96.128.86广东电信DNS故障及解决方法
  12. 用python编程 商品打折怎么计算_折扣怎么算用计算公式
  13. ECS 7天实践训练营-day1
  14. 手把手教你如何向 Linux 内核提交代码
  15. springboot 2.X——短信网关使用初体验
  16. Android 微信分享视频缩略图不显示问题
  17. 互联网电影院5G让3D体验更流畅
  18. 【每日一问】工作日问题
  19. 严重性 代码 说明 项目 文件 行 禁止显示状态 警告 CS8032 无法从...创建分析器...的实例: 未能加载文件或程序集...或它的某一个依赖项。系统找不到指定的文件
  20. 王者荣耀注销服务器的流程,王者荣耀账号怎么永久注销 王者荣耀注销账号流程介绍...

热门文章

  1. ABLIC今日推出S-576Z系列IC
  2. 【22SR】Revisiting RCAN: Improved Training for Image Super-Resolution
  3. JDBC项目实践与源码解析(十一)
  4. ROS 2 Humble Hawksbill 之 f1tenth gym
  5. tenth week
  6. 解决Chrome浏览器中部分字体显示模糊的问题
  7. Neural NILM Deep Neural Networks Applied to Energy Disaggregation
  8. 详解拉卡拉支付赋能商户运营模式
  9. 百度、海澜之家、维密、极星汽车、高盛等公司高管变动
  10. ege管理系统_医院挂号管理系统-智能新型医院挂号管理系统下载v95.44官方PC版-CE安全网...