TensorFlow使用Keras Tuner自动调参

  • 数据集
  • 归一化
  • 图像分类模型
  • Hyperband
  • 运行超参数搜索(自动调参)
  • 获取最佳超参数
  • 使用最佳超参数构建和训练模型
  • 整体代码

代码地址:
https://github.com/lilihongjava/deep_learning/tree/master/TensorFlow2.0%E8%87%AA%E5%8A%A8%E8%B0%83%E5%8F%82

数据集

Zalando商品图片数据集,通过load_data函数读取data目录下 ‘train-labels-idx1-ubyte.gz’, ‘train-images-idx3-ubyte.gz’, ‘t10k-labels-idx1-ubyte.gz’, 't10k-images-idx3-ubyte.gz’文件

def load_data():path = "./data/"files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']paths = [path + each for each in files]with gzip.open(paths[0], 'rb') as lbpath:y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)  # uint8无符号整数(0 to 255),一个字节,一张图片256色with gzip.open(paths[1], 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)  # 图像尺寸(28*28)with gzip.open(paths[2], 'rb') as lbpath:y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)  # offset=8,前8不读with gzip.open(paths[3], 'rb') as imgpath:x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)return (x_train, y_train), (x_test, y_test)
(img_train, label_train), (img_test, label_test) = load_data()

归一化

 img_train = img_train.astype('float32') / 255.0img_test = img_test.astype('float32') / 255.0

图像分类模型

hypermodel
调整第一个Dense层中的层数,在32-512之间选择一个最佳值

 hp.Int('units', min_value=32, max_value=512, step=32)

调整优化器的学习速率,从0.01、0.001或0.0001中选择一个最佳值

hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])
def model_builder(hp):model = keras.Sequential()model.add(keras.layers.Flatten(input_shape=(28, 28)))  # 输入“压平”,即把多维的输入一维化# Tune the number of units in the first Dense layer# Choose an optimal value between 32-512hp_units = hp.Int('units', min_value=32, max_value=512, step=32)model.add(keras.layers.Dense(units=hp_units, activation='relu'))model.add(keras.layers.Dense(10))# Tune the learning rate for the optimizer# Choose an optimal value from 0.01, 0.001, or 0.0001hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])  # accuracy,用于判断模型效果的函数return model

Hyperband

使用Hyperband 算法搜索超参数
定义Hyperband,指定hypermodel,优化的目标,最大迭代次数,衰减系数,详细日志和checkpoints保存路径

    tuner = kt.Hyperband(model_builder,objective='val_accuracy',  # 优化的目标,验证集accuracymax_epochs=10,  # 最大迭代次数factor=3,directory='my_dir',  # my_dir/intro_to_kt目录包含超参数搜索期间运行的详细日志和checkpointsproject_name='intro_to_kt')

运行超参数搜索(自动调参)

ClearTrainingOutput为回调函数,在每个训练步骤结束时回调

 tuner.search(img_train, label_train, epochs=10, validation_data=(img_test, label_test),callbacks=[ClearTrainingOutput()])

获取最佳超参数

tuner.get_best_hyperparameters(num_trials=1)[0]

使用最佳超参数构建和训练模型

 model = tuner.hypermodel.build(best_hps)model.fit(img_train, label_train, epochs=10, validation_data=(img_test, label_test))

整体代码

if __name__ == '__main__':#  Zalando商品图片数据集(img_train, label_train), (img_test, label_test) = load_data()# 归一化img_train = img_train.astype('float32') / 255.0img_test = img_test.astype('float32') / 255.0# 使用 Hyperband 算法搜索超参数tuner = kt.Hyperband(model_builder,objective='val_accuracy',  # 优化的目标,验证集accuracymax_epochs=10,  # 最大迭代次数factor=3,directory='my_dir',  # my_dir/intro_to_kt目录包含超参数搜索期间运行的详细日志和checkpointsproject_name='intro_to_kt')tuner.search(img_train, label_train, epochs=10, validation_data=(img_test, label_test),callbacks=[ClearTrainingOutput()])# Get the optimal hyperparametersbest_hps = tuner.get_best_hyperparameters(num_trials=1)[0]print(f"""The hyperparameter search is complete. The optimal number of units in the first densely-connectedlayer is {best_hps.get('units')} and the optimal learning rate for the optimizeris {best_hps.get('learning_rate')}.""")# Build the model with the optimal hyperparameters and train it on the datamodel = tuner.hypermodel.build(best_hps)model.fit(img_train, label_train, epochs=10, validation_data=(img_test, label_test))

参考:https://www.tensorflow.org/tutorials/keras/keras_tuner

TensorFlow使用Keras Tuner自动调参相关推荐

  1. Keras Tuner自动调参工具使用入门教程

    主体是翻译的Keras Tuner的说明:https://keras-team.github.io/keras- tuner/documentation/tuners/ github地址:https: ...

  2. 【调参工具】微软自动调参工具—NNI

    参考链接: 微软自动调参工具-NNI-安装与使用教程(附错误解决) nni官方文档 总结一下步骤 1.pip安装nni pip install nni 2.配置search_space.json,co ...

  3. 使用Ray Tune自动调参

    文章目录 前言 一.Ray Tune是什么? 二.使用步骤 1.安装包 2.引入库 3.读入数据(与Ray Tune无关) 4.构建神经网络模型(与Ray Tune无关) 5.模型的训练和测试(与Ra ...

  4. 微软自动调参工具—NNI—安装与使用教程(附错误解决)

    简介 NNI是微软的开源自动调参的工具.人工调参实在是太麻烦了,最近试了下水,感觉还不错,能在帮你调参的同时,把可视化的工作一起给做了,简单明了.然后感觉很多博客写的并不是很明白,所以打算自己补充一下 ...

  5. 微软自动调参工具 NNI 使用事例教程

    第一步:安装 nni的安装通过pip命令就可以安装了.并且提供了example供参考学习. 系统配置要求:tensorflow,python >= 3.5 # 安装nnipython3 -m p ...

  6. NNI 自动调参使用。

    前言 NNI是由微软研究院,开发的深度学习开发工具. Neural Network Intelligence 是一个工具包,可以有效帮助用户设计并调优汲取学习模型的神经网络架构,以及超参数.具有易于使 ...

  7. PID自动调参simulink仿真

    PID自动调参----simulink仿真-----如何高效调参 设计PID控制器 系统识别APP识别传递函数 Simulink搭建仿真控制系统 使用Maltab自动调参工具PID Tuner调节PI ...

  8. Auto ML自动调参

    Auto ML自动调参 本文介绍Auto ML自动调参的算法介绍及操作流程. 操作步骤 登录PAI控制台. 单击左侧导航栏的实验并选择某个实验. 本文以雾霾天气预测实验为例. 在实验画布区,单击左上角 ...

  9. sklearn快速入门教程:(四)模型自动调参

    上个教程中我们已经看到在sklearn中调用机器学习模型其实非常简单.但要获得较好的预测效果则需要选取合适的超参数.在实际的项目中其实也有不少参数是由工程师借助其经验手动调整的,但在许多场景下这种方式 ...

最新文章

  1. 如何优化linux服务器,手把手教你如何优化linux服务器
  2. Shell字符串截取——获取oracle group名字
  3. Shell脚本个例二
  4. android摄像头代码,Android摄像头
  5. maven web项目不能创建src/main/java等文件夹的问题
  6. 关于windows cmd的一些便捷应用
  7. left join 效率_人力资源HR的人才测评工具,极大提高招聘效率
  8. C#操作 excel表格
  9. Python标准库shutil中rmtree()使用回调函数
  10. SpringBoot中如何优雅的使用拦截器
  11. 关于使用XLSTransformer.transformXLS导出Excel表格中遇到的问题
  12. 工作中如何进行接口测试
  13. 30行JS代码带你手写自动回复语音聊天机器人
  14. Power BI分解销售目标
  15. 解决光纤猫恢复出厂功能后的上网问题
  16. 利用球谐系数计算函数值及利用EGM球谐系数计算重力异常
  17. 加州欧文大学计算机工程,加州大学欧文分校计算机工程专业课程设置有哪些
  18. python写网络爬虫微博用户发布的视频
  19. python中经常使用的包扎材料_以下哪些是经常使用的包扎材料:
  20. 详解ELF重定向原理

热门文章

  1. Java 关于中文乱码问题的解决方案与经验【转载】
  2. 一位沪漂 11 年的程序员老兵,回老家了!
  3. 一年内经验前端面试题记录
  4. python中print()换行的问题
  5. day06_tomacat
  6. 【Redis】事务和锁机制
  7. 仪器数据自动化采集,助力提升实验室管理效率
  8. Hadoop MapReduce Job 相关参数设置 概念介绍与理解
  9. crontab环境变量问题
  10. 单点登录(SSO)、CAS介绍