文章目录

  • TensorFlow中常见的CallBack
    • Tensorboard
    • Checkpoint
    • Earlystoping
    • CSVLogger
    • LearningRateScheduler
  • 定义CallBack类

TensorFlow中常见的CallBack

Tensorboard

model = build_model(dense_units=256)
model.compile(optimizer='sgd',loss='sparse_categorical_crossentropy', metrics=['accuracy'])logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir)model.fit(train_batches, epochs=10, validation_data=validation_batches, callbacks=[tensorboard_callback])

我们可以使用tensorboard画出我们想要的关于模型的图案,具体tensorboard的使用是一个很有趣的过程。

Checkpoint

model = build_model(dense_units=256)
model.compile(optimizer='sgd',loss='sparse_categorical_crossentropy', metrics=['accuracy'])model.fit(train_batches, epochs=5, validation_data=validation_batches, verbose=2,callbacks=[ModelCheckpoint('weights.{epoch:02d}-{val_loss:.2f}.h5', verbose=1),])

我们可以使用checkpoint按照一定的要求对模型进行保存(例如,按照一定的频率,时间)

Earlystoping

model = build_model(dense_units=256)
model.compile(optimizer='sgd',loss='sparse_categorical_crossentropy', metrics=['accuracy'])model.fit(train_batches, epochs=50, validation_data=validation_batches, verbose=2,callbacks=[EarlyStopping(patience=3,min_delta=0.05,baseline=0.8,mode='min',monitor='val_loss',restore_best_weights=True,verbose=1)])

当我们发现模型的方差在增大,val_loss上升,模型泛化能力变差,我们可以提前终止训练。

CSVLogger

model = build_model(dense_units=256)
model.compile(optimizer='sgd',loss='sparse_categorical_crossentropy', metrics=['accuracy'])csv_file = 'training.csv'model.fit(train_batches, epochs=5, validation_data=validation_batches, callbacks=[CSVLogger(csv_file)])

将训练中的信息按照CSV文件格式给出。

LearningRateScheduler

model = build_model(dense_units=256)
model.compile(optimizer='sgd',loss='sparse_categorical_crossentropy', metrics=['accuracy'])def step_decay(epoch):initial_lr = 0.01drop = 0.5epochs_drop = 1lr = initial_lr * math.pow(drop, math.floor((1+epoch)/epochs_drop))return lrmodel.fit(train_batches, epochs=5, validation_data=validation_batches, callbacks=[LearningRateScheduler(step_decay, verbose=1),TensorBoard(log_dir='./log_dir')])

在训练中动态修改学习率,使得模型能够更快收敛。

model = build_model(dense_units=256)
model.compile(optimizer='sgd',loss='sparse_categorical_crossentropy', metrics=['accuracy'])model.fit(train_batches, epochs=50, validation_data=validation_batches, callbacks=[ReduceLROnPlateau(monitor='val_loss', factor=0.2, verbose=1,patience=1, min_lr=0.001),TensorBoard(log_dir='./log_dir')])

和上面类似,只不过该方法只有当遇到瓶颈的时候才修改学习率

定义CallBack类

我们可以从Callback类继承,从而定义我们自己的类。

import tensorflow as tf
from tensorflow.python.keras.callbacks import Callbackclass MyCallback(Callback):def __init__(self, loss_threshold=0.01):super(MyCallback, self).__init__()self.loss_threshold = loss_thresholddef on_train_begin(self, logs=None):print("training begin")def on_epoch_end(self, epoch, logs=None):if logs['train_loss'] < self.loss_threshold:self.model.stop_training = Trueprint('loss is enough')

在Callback基类中定义了很多函数,我们都可以重载,例如 o n _ t r a i n i n g _ b e g i n , o n _ e p o c h _ e n d on\_training\_begin,on\_epoch\_end on_training_begin,on_epoch_end等等。

TensorFlow中常见的CallBack相关推荐

  1. tensorflow中常见的损失函数

    今天在构建一个卷积网络时看到书上例程里用的tf.nn.sparse_softmax_cross_entropy_with_logits()这个函数,打开Documentation看了没太明白,特地讲三 ...

  2. TensorFlow中的Keras用法和自定义模型和层

    Keras Keras 是一个用于构建和训练深度学习模型的高阶 API.它可用于快速设计原型.高级研究和生产,具有以下三个主要优势: 方便用户使用 Keras 具有针对常见用例做出优化的简单而一致的界 ...

  3. 如何使用TensorFlow中的Dataset API

    翻译 | AI科技大本营 参与 | zzq 审校 | reason_W 本文已更新至TensorFlow1.5版本 我们知道,在TensorFlow中可以使用feed-dict的方式输入数据信息,但是 ...

  4. TensorFlow中设置学习率的方式

    目录 1. 指数衰减 2. 分段常数衰减 3. 自然指数衰减 4. 多项式衰减 5. 倒数衰减 6. 余弦衰减 6.1 标准余弦衰减 6.2 重启余弦衰减 6.3 线性余弦噪声 6.4 噪声余弦衰减 ...

  5. 中tile函数_HelpGirlFriend 系列 --- tensorflow 中的张量运算思想

    GirlFriend 在复现论文的时候,我发现她不太会将通用数学公式转化为张量运算公式,导致 tensorflow 无法通过并行的方式优化其论文复现代码的运行速率. 这里对给 GirlFriend 讲 ...

  6. ML之模型文件:机器学习、深度学习中常见的模型文件(.h5、.keras)简介、h5模型文件下载集锦、使用方法之详细攻略

    ML之模型文件:机器学习.深度学习中常见的模型文件(.h5..keras)简介.h5模型文件下载集锦.使用方法之详细攻略 目录 ML/DL中常见的模型文件(.h5..keras)简介及其使用方法 一. ...

  7. TensorFlow 中文文档 介绍

    介绍 本章的目的是让你了解和运行 TensorFlow 在开始之前, 先看一段使用 Python API 撰写的 TensorFlow 示例代码, 对将要学习的内容有初步的印象. 这段很短的 Pyth ...

  8. 深度学习中常见的损失函数

    文章来源于AI的那些事儿,作者黄鸿波 2018年我出版了<TensorFlow进阶指南 基础.算法与应用>这本书,今天我把这本书中关于常见的损失函数这一节的内容公开出来,希望能对大家有所帮 ...

  9. tensorflow中batch normalization的用法

    转载网址:如果侵权,联系我删除 https://www.cnblogs.com/hrlnw/p/7227447.html https://www.cnblogs.com/eilearn/p/97806 ...

最新文章

  1. 人工智能和人类智能的类比
  2. 怎么给自己的python换源_windows/linux下如何更换Python的pip源
  3. 操作Checkbox标签
  4. 服务器收到消息加入数组,从聊天服务器发送到聊天客户端的数组更新
  5. PCM音频文件的制作
  6. SAP License:SAP自学SAP常见的问题二
  7. Repeater控件的嵌套使用
  8. 如何对酒店的固定资产进行日常管理?
  9. python自动生成文章原创_【Python】皮皮AI工具( AI文章伪原创工具)
  10. PPT宏编程——ChineseCounter
  11. 学生管理系统——数据库表设计
  12. 重装系统无法在计算机上运行,开机无法进入系统?重装系统开机不能进入系统怎么办...
  13. unbuntu20.04安装mysql5.7
  14. 问题 G: 结义兄弟
  15. 【数据库连接池】数据库连接池
  16. java 读取doc文件_如何在java中读取Doc或Docx文件?
  17. AGC012 - E: Camel and Oases
  18. 昆明世博园装mysql_昆明世博园太美丽了
  19. 百数在线表单如何实现表单套打?
  20. Excel——数据有效性+条件格式应用

热门文章

  1. 夏季国内10大避暑胜地指南|7、8、9月暑期最适合旅行目的地
  2. 利用java编写人机剪刀石头布的小游戏
  3. 【全志T113-S3_100ask】14-1 linux采集usb摄像头实现拍照(FFmpeg、fswebcam)
  4. HTML-Css文字排版--字体--段落
  5. 聊城大学matlab试题,聊城大学计算机学院11—12学年第2学期期末考试《编译原理》试题(闭卷B卷)...
  6. python deepcopy
  7. 【暗战】1999年杜琪峰指导上映的电影
  8. mysql闭包的概念_彻底搞懂JavaScript的闭包、防抖跟节流
  9. 用计算机打出下山这首歌,我要串词怎么引出下山这首歌?
  10. 2021年深圳杯A题火星探测器着陆控制方案