我们知道,机器学习模型的效果好坏很大程度上取决于超参的选取。人肉调参需要依赖经验与直觉,且花费大量精力。PBT(Population based training)是DeepMind在论文《Population Based Training of Neural Networks》中提出的一种异步的自动超参数调节优化方法。以往的自动调节超参方法可分为两类:parallel search和sequential optimization。前者并行执行很多不同超参的优化任务,优点是可以并行利用计算资源更快找到最优解;后者需要利用之前的信息来进行下一步的超参优化,因此只能串行执行,但一般能得到更好的解。PBT完美地结合两种方法,兼具两者优点。它被应用于一些领域取得了不错的效果。如DeepMind的论文《Human-level performance in first-person multiplayer games with population-based deep reinforcement learning》将之用于第一人称多人游戏使AI达到人类水平。还有今年UC Berkeley的论文《Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules》中用PBT来自动学习data augmentation策略,在几个benchmark上达到了不错的精度。另外,最近自动驾驶公司Waymo也称将PBT应用于识别任务,与手工调参相比可以提高精度和加快训练速度。

PBT开局与parallel search类似,会并行训练一批随机初始化的模型。过程中它会周期性地将表现好的模型替换表现不好的模型(exploitation),同时再加上随机扰动(主要是为了exploration)。PBT与其它方法的一个重要不同是它在训练的过程中对超参进行调节,因此可以更快地发现超参和优异的schedule。论文《Population Based Training of Neural Networks》中的示意图非常清楚地示意了整个过程,及与其它方法的区别:

PBT是一种很通用的方法,可以用于很多场景,其一般套路如下:

  1. Step:对模型训练一步。至于一步是一次iteration还是一个epoch还是其它可以根据需要指定。
  2. Eval:在验证集上做评估。
  3. Ready: 选取群体中的一个模型来进行下面的exploit和explore操作(即perturbation)。这个模型一般是上次做过该操作后经过指定的时间(或迭代次数等)。
  4. Exploit: 将那些经过评估比较烂的模型用那些比较牛叉的模型替代。
  5. Explore: 对上一步产生的复制体模型加随机扰动,如加上随机值或重采样。

Ray中实现了PBT算法。Ray中关于PBT有三个example:一个是learning rate搜索pbt_example.py,另一个是强化学习算法PPO的超参数搜索pbt_ppo_example.py。还有一个是pbt_tune_cifar10_with_keras.py。我们来看下最简单的pbt_example.py。其中的PBTBenchmarkExample类继承自Trainable类,它是一个toy的模拟环境,假设在模型训练过程中最优的learning rate是变化的,是accuracy的函数。目标是找到learning rate的schedule。它的核心函数是_train(),这里会模拟最优的learning rate。

然后看主函数,首先通过ray.init()初始化ray,然后创建PopulationBasedTraining对象,接着通过run()函数开始超参搜索过程。

    pbt = PopulationBasedTraining(time_attr="training_iteration",metric="mean_accuracy",mode="max",perturbation_interval=20,hyperparam_mutations={# distribution for resampling"lr": lambda: random.uniform(0.0001, 0.02),# allow perturbations within this set of categorical values"some_other_factor": [1, 2],})run(PBTBenchmarkExample,name="pbt_test",scheduler=pbt,reuse_actors=True,verbose=False,**{"stop": {"training_iteration": 2000,},"num_samples": 4,"config": {"lr": 0.0001,# note: this parameter is perturbed but has no effect on# the model training in this example"some_other_factor": 1,},})

先看第一步,PopulationBasedTraining的实现在python/ray/tune/schedulers/pbt.py中。它继承自FIFOScheduler类。构造函数中几个主要参数:

  • time_attr: 用于定义训练时长的测度,要求单调递增,比如training_iteration
  • metric: 训练结果衡量目标。
  • mode: 上面metric属性是越高越好,还是越低越好。
  • perturbation_interval: 模型会以time_attr为间隔来进行perturbation。
  • hyperparam_mutations: 需要变异的超参。它是一个dict,对于每个key对应list或者function。如果没设这个,就需要在custom_explore_fn中指定。
  • quantile_fraction: 决定按多大比例将表现好的头部模型克隆到尾部模型。
  • resample_probability: 当对超参进行exploration时从原分布中重新采样的概率,否则会根据现有的值调整。
  • custom_explore_fn: 自定义的exploration函数。

第二步中run()函数实现在ray/python/ray/tune/tune.py中:

def run(run_or_experiment, name=None, ...):trial_executor = traial_executor or RayTrialExecutor(...)experiment = run_or_experimentif not isinstance(run_or_experiment, Experiment):if not isinstance(run_or_experiment, Experiment):experiment = Experiment(...)...runner = TrialRunner(search_alg=search_alg or BasicVariantGenerator(),scheduler=scheduler or FIFOScheduler(),local_checkpoint_dir=experiment.checkpoint_dir,remote_checkpoint_dir=experiment.remote_checkpoint_dir,sync_to_cloud=sync_to_cloud,checkpoint_period=global_checkpoint_period,resume=resume,launch_web_server=with_server,server_port=server_port,verbose=bool(verbose > 1),trial_executor=trial_executor)runner.add_experiment(experiment)...while not runner.is_finished():runner.step()...wait_for_sync()...return ExperimentAnalysis(runner.checkpoint_file, trials=trials)

第一个参数run_or_experiment是要训练的目标任务,参数scheduler就是上面创建的PopulationBasedTraining,负责超参搜索时的调度。

其中几个关键类关系如下图:

SearchAlgorithm的实现类BasicVariantGenerator会根据给定的Experiment产生参数变体。每个待训练的参数变体会创建相应的Trial对象。Trial有PENDING, RUNNING, PAUSED, TERMINATED, ERROR几种状态。它会开始于PENDING状态,开始训练后转为RUNNING状态,出错了就到ERROR状态,成功的话就是TERMINATED状态。训练中还可能被TrialScheduler暂停(转入PAUSED状态)并释放资源。

TrialRunner是最核心的数据结构,它管理一系列的Trial对象,并且执行一个事件循环,将这些任务通过TrialExecutor的实现类RayTrialExecutor提交到Ray cluster运行。RayTrialExecutor会负责资源的管理。这里通过Ray分布执行的主要是Trainable的实现类(上例中就是PBTBenchmarkExample)中的_train()函数。RayTrialExecutor对象中的_running维护了正在运行的Trial。在循环中,TrialRunner会通过TrialScheduler的实现类PopulationBasedTraining来进行调度。它的choose_trial_to_run()函数从trial_runner的queue中拿出状态为PENDING或者PAUSED的trial,并且选取离上次做perturbation最久的一个保证尽可能公平。

run函数主要做以下几步:

  1. 创建RayTrailExecutor对象(如果没有传入trial_executor的话)。
  2. 如果目标任务不是以Experiment对象形式给出,会按照给定的其它参数构建Experiment对象。
  3. 创建TrialRunner对象,它基于Ray来调度事件循环。
    1. 创建搜索算法对象(如果没给),默认为BasicVariantGenerator(实现在basic_variant.py)。它主要用于产生新的参数变体。
    2. 创建执行实验的调度器(如果没给),默认为FIFOScheduler。上例中给定了PopulationBasedTraining,所以这里就不需要创建了。
    3. 创建TrialRunner对象(实现在trial_runner.py)。并上面创建的Experiment对象通过add_experiment()函数加到TrialRunner对象中。
  4. 进入主循环,通过TrialRunneris_finished()函数判断是否结束。如果没有,就调用TrialRunnerstep()函数执行一步。step()函数的主要工作下面再细说。
  5. 收尾工作。如通过wait_for_sync()函数同步远端目标,记录没有正常结束的trial,返回分析信息。

其中比较关键的是step()函数,其主要流程如下:

当一个Trial训练结束返回结果时,TrialRunner会调用PopulationBasedTrainingon_trial_result()函数。这里就是PBT的精华了。结合文章开关的PBT一般套路,主要步骤如下:

  1. 如果离上次pertubation的时间还没到指定间隔,则返回让该Trial继续训练。
  2. 调用_quantiles()函数按设定的比例__quantile_fraction得到所有Trial中表现好的头部和表现不好的尾部。
  3. 如果当前trial是比较牛的那一批,那赶紧存成checkpoint,等着被其它trial克隆学习。
  4. 如果很不幸地,当前trial属于比较差的那一批,那就从牛的那批中随机挑一个(为trial_to_clone),然后调用_exploit()函数。该函数会调用explore()函数对trial_to_clone进行扰动,然后将它的参数设置和checkpoint设置到当前trial。这样,当前trial就“洗心革面”,重新出发了。
  5. 如果TrialRunner中有PENDING和PAUSED状态的trial,则请求暂停当前trial,让出资源。否则的话就继续训练着。

最后,总结下主要模块间的大体流程:

超参数自动优化方法PBT(Population Based Training)相关推荐

  1. 基于Python的随机森林(RF)回归与多种模型超参数自动优化方法

      本文详细介绍基于Python的随机森林(Random Forest)回归算法代码与模型超参数(包括决策树个数与最大深度.最小分离样本数.最小叶子节点样本数.最大分离特征数等等)自动优化代码.    ...

  2. 自动化机器学习(一)超参数自动优化技术

    文章目录 技术介绍 核心技术栈 项目选择 数据 基础模型 Hyperopt 实现 数据读取 使用lightgbm中的cv方法 定义参数空间 展示结果 贝叶斯优化 原理 使用lightgbm中的cv方法 ...

  3. 机器学习调参自动优化方法

    本文旨在介绍当前被大家广为所知的超参自动优化方法,像网格搜索.随机搜索.贝叶斯优化和Hyperband,并附有相关的样例代码供大家学习. 一.网格搜索(Grid Search) 网格搜索是暴力搜索,在 ...

  4. 超参数调整的方法介绍

    文章目录 超参数调整的方法介绍 常用的超参数调整方法 网格搜索(Grid Search) 如何进行网格搜索 小结 随机搜索(Random Search) 贝叶斯优化(Bayesian Optimiza ...

  5. 遗传算法优化rbf神经网络自校正控制的初值_【技术帖】轻量化设计中的NVH性能自动优化方法...

    摘 要:噪声.振动与声振粗糙度 (Noise,Vibration and Harshness,NVH)性能的自动优化是实现多学科联合优化的基础条件.以白车身模型的零件厚度作为设计变量,以针对动刚度性能 ...

  6. Python的数据分析中超参数调优方法:网格搜索

    [小白从小学Python.C.Java] [Python全国计算机等级考试] [Python数据分析考试必会题] ● 标题与摘要 Python的数据分析中 超参数调优方法:网格搜索 ● 选择题 以下说 ...

  7. 超参数及其优化办法:验证集

    一.超参数定义: 超参数是在开始学习过程之前设置值的参数,而不是通过训练得到的参数数据.和一般的参数比如权重.偏置之类的有差别. 通常情况下,需要对超参数进行优化,给学习机选择一组最优超参数,以提高学 ...

  8. 一文详解超参数调优方法

    ©PaperWeekly 原创 · 作者|王东伟 单位|Cubiz 研究方向|深度学习 本文介绍超参数(hyperparameter)的调优方法. 神经网络模型的参数可以分为两类: 模型参数,在训练中 ...

  9. python gridsearch_Python超参数自动搜索模块GridSearchCV上手

    1. 引言 当我们跑机器学习程序时,尤其是调节网络参数时,通常待调节的参数有很多,参数之间的组合更是繁复.依照注意力>时间>金钱的原则,人力手动调节注意力成本太高,非常不值得.For循环或 ...

  10. 贝叶斯优化原理及应用[附XGBoost、LightGBM超参数调优代码][scikit-optimize]

    近年来机器学习和深度学习算法被越来越广泛的应用于解决对未知数据的预测问题.由于超参数的选择对模型最终的效果可能有极大的影响,为了使模型达到更好的效果,通常会面临超参数调优问题.但如何选择合适的超参数并 ...

最新文章

  1. linux下rpm,yum学习
  2. python控制台输出颜色
  3. Modbus RTU 通信工具设计
  4. 教你如何更改xshell中的转发规则
  5. Erlang程序设计
  6. 【OpenCV 例程200篇】53. Scipy 实现图像二维卷积
  7. jquery getjson php,jquery中调用php json函数的方法分享
  8. vue 指令 v-on 事件修饰符-鼠标事件-什么是事件冒泡
  9. 多功能计算机使用说明,多功能分装机/多功能分装机
  10. registered php streams sqlsrv,tp5与SQL Server的爱恨情仇(1)
  11. java servlet 对象_java servlet的域对象
  12. debian 修改apache2 https 端口为11443
  13. PGP软件的安装及汉化
  14. Ali-tomcat之HSF框架Demo启动报错HSFServiceAddressNotFoundException
  15. Qt打包程序报错“应用程序无法正常启动(0xc000007b)”
  16. tomcat部署静态网页
  17. CapstoneCS5212|CapstoneCS5218|DP转VGA1080P方案设计| DP转HDMI4K 30Hz方案设计
  18. 神经网络架构搜索——可微分搜索(DARTS)
  19. windows虚拟桌面_在Windows中使用虚拟桌面的最佳免费程序
  20. 数字图像处理学习之路:图像变换(一)

热门文章

  1. WINCE 矩阵键盘 介绍
  2. 120_x轴与y轴平移【transform: translateX(n) translateY(n)】利用定位和变形使元素水平垂直居中
  3. 返回上一页,ajax读出来的数据丢失。
  4. Cloning into ‘vue-element-admin‘... fatal: unable to access ‘https://github.com/PanJiaChen/vue-eleme
  5. 松翰SN8P2511 SOP8单片机 可代烧录 提供单片机方案开发 单片机解密
  6. 强化学习DQN 入门小游戏 最简单的Pytorch代码
  7. “2022绿色智能制造创赢计划”全新集结:加入这个朋友圈,成为未来主角
  8. hypermedia_Hypermedia REST API简介
  9. 老男孩五篇重要文章:http://oldboy.blog.51cto.com/2561410/1184139
  10. C# Winform右下角弹窗方式