Code for VeLO 2: Training Versatile Learned Optimizers by Scaling Up

The previous post covered how to train a simple MLP; this post shows how to train ResNets with VeLO.

This post is based on https://colab.research.google.com/drive/1-ms12IypE-EdDSNjhFMdRdBbMnH94zpH#scrollTo=RQBACAPQZyB-, and demonstrates a learned optimizer from the VeLO family (see the minimal interface sketch after the list below) on:

  • a simple image-classification task
  • ResNets (resnet1, resnet18, resnet34)
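At its core, VeLO is used like any optax optimizer; the one difference is that its update rule also consumes the current loss value. Below is a minimal sketch of that interface on a toy quadratic. It is not part of the original notebook, and it assumes the learned_optimization package is installed and exposes prefab.optax_lopt, as in the previous post:

import jax
import jax.numpy as jnp
import optax
from learned_optimization.research.general_lopt import prefab

NUM_STEPS = 100  # VeLO wants the total step budget up front
opt = prefab.optax_lopt(NUM_STEPS)

params = {"w": jnp.ones(3)}
opt_state = opt.init(params)

def loss_fn(p):
  return jnp.sum(p["w"] ** 2)

for _ in range(NUM_STEPS):
  value, grads = jax.value_and_grad(loss_fn)(params)
  # Unlike plain optax optimizers, VeLO also needs the current loss value.
  updates, opt_state = opt.update(grads, opt_state, params,
                                  extra_args={"loss": value})
  params = optax.apply_updates(params, updates)

The ResNet pipeline below follows the same pattern, only routed through jaxopt's OptaxSolver.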
#@title imports, configuration, and model classes
from absl import app
from datetime import datetime
from functools import partial
from typing import Any, Callable, Sequence, Tuple

from flax import linen as nn
import jax
import jax.numpy as jnp
from jaxopt import loss
from jaxopt import OptaxSolver
from jaxopt import tree_util
import optax
import tensorflow_datasets as tfds
import tensorflow as tf

# Datasets that can be used
dataset_names = [
    "mnist", "kmnist", "emnist", "fashion_mnist", "cifar10", "cifar100"
]

L2REG = 1e-4
LEARNING_RATE = .2
EPOCHS = 10
MOMENTUM = .9
DATASET = 'cifar100' #@param [ "mnist", "kmnist", "emnist", "fashion_mnist", "cifar10", "cifar100"]
MODEL = 'resnet18' #@param ["resnet1", "resnet18", "resnet34"]
TRAIN_BATCH_SIZE = 256
TEST_BATCH_SIZE = 1024

# Load the dataset as an infinite iterator of numpy batches.
def load_dataset(split, *, is_training, batch_size):
  version = 3
  ds, ds_info = tfds.load(
      f"{DATASET}:{version}.*.*",
      as_supervised=True,  # remove useless keys
      split=split,
      with_info=True)
  ds = ds.cache().repeat()
  if is_training:
    ds = ds.shuffle(10 * batch_size, seed=0)
  ds = ds.batch(batch_size)
  return iter(tfds.as_numpy(ds)), ds_info


class ResNetBlock(nn.Module):
  """ResNet block."""
  filters: int
  conv: Any
  norm: Any
  act: Callable
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    residual = x
    y = self.conv(self.filters, (3, 3), self.strides)(x)
    y = self.norm()(y)
    y = self.act(y)
    y = self.conv(self.filters, (3, 3))(y)
    y = self.norm(scale_init=nn.initializers.zeros)(y)
    if residual.shape != y.shape:
      residual = self.conv(self.filters, (1, 1),
                           self.strides, name='conv_proj')(residual)
      residual = self.norm(name='norm_proj')(residual)
    return self.act(residual + y)


class ResNet(nn.Module):
  """ResNetV1."""
  stage_sizes: Sequence[int]
  block_cls: Any
  num_classes: int
  num_filters: int = 64
  dtype: Any = jnp.float32
  act: Callable = nn.relu

  @nn.compact
  def __call__(self, x, train: bool = True):
    conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
    norm = partial(nn.BatchNorm,
                   use_running_average=not train,
                   momentum=0.9,
                   epsilon=1e-5,
                   dtype=self.dtype)
    x = conv(self.num_filters, (7, 7), (2, 2),
             padding=[(3, 3), (3, 3)],
             name='conv_init')(x)
    x = norm(name='bn_init')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(self.stage_sizes):
      for j in range(block_size):
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = self.block_cls(self.num_filters * 2 ** i,
                           strides=strides,
                           conv=conv,
                           norm=norm,
                           act=self.act)(x)
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
    x = jnp.asarray(x, self.dtype)
    return x

# Note: ResNet defines no __init__ because flax.linen.Module is a dataclass;
# the annotated class attributes above automatically become constructor
# arguments, and the network itself is built inside the @nn.compact __call__.
ResNet1 = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock)
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)
ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock)
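To make the Flax plumbing less mysterious: since ResNet is a flax.linen.Module (a dataclass), its fields are set through the constructor, and net.init returns the variable collections rather than storing them on the object. A small sketch, not from the original notebook, assuming CIFAR-sized 32x32x3 inputs:

import jax
import jax.numpy as jnp

net = ResNet18(num_classes=100)
variables = net.init(jax.random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)), train=True)

# Two collections come back: "params" (trainable weights) and
# "batch_stats" (BatchNorm running statistics).
print(sorted(variables.keys()))
# Per-layer parameter shapes, e.g. the conv_init kernel is (7, 7, 3, 64).
print(jax.tree_util.tree_map(jnp.shape, variables["params"])["conv_init"])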
#@title training loop definition (run this cell to launch training)
import functools
from typing import Any
from typing import Callable
from typing import NamedTuple
from typing import Optional
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jaxopt._src import base
from jaxopt._src import tree_util
# VeLO is exposed through the prefab module of the learned_optimization
# package (see the previous post for installation); this import was implicit
# in the original notebook.
from learned_optimization.research.general_lopt import prefab


# This NamedTuple only stores the solver state between iterations.
class OptaxState(NamedTuple):
  """Named tuple containing state information."""
  iter_num: int           # number of iterations performed so far
  value: float            # last loss value
  error: float            # l2 norm of the gradient
  internal_state: NamedTuple
  aux: Any


# We need to reimplement jaxopt's OptaxSolver.update (as lopt_update) to
# properly pass in the loss value that VeLO expects.
def lopt_update(self,
                params: Any,
                state: NamedTuple,
                *args,
                **kwargs) -> base.OptStep:
  """Performs one iteration of the optax solver.

  Args:
    params: pytree containing the parameters (here, the ResNet weights).
    state: named tuple containing the solver state.
    *args: additional positional arguments to be passed to ``fun``.
    **kwargs: additional keyword arguments to be passed to ``fun``.
  Returns:
    (params, state)
  """
  if self.pre_update:
    params, state = self.pre_update(params, state, *args, **kwargs)

  (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)

  # Note: the only difference between this function and the baseline
  # OptaxSolver update is that `extra_args` is now passed.
  # If you would like to use a different optimizer, you will likely need to
  # remove these extra_args.
  delta, opt_state = self.opt.update(grad, state.internal_state, params,
                                     extra_args={"loss": value})
  params = self._apply_updates(params, delta)

  # Computes optimality error before update to re-use grad evaluation.
  new_state = OptaxState(iter_num=state.iter_num + 1,
                         error=tree_util.tree_l2_norm(grad),
                         value=value,
                         aux=aux,
                         internal_state=opt_state)
  return base.OptStep(params=params, state=new_state)


def train():
  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  # tf.config.experimental.set_visible_devices([], 'GPU')

  # Typical data loading and iterator setup.
  train_ds, ds_info = load_dataset("train", is_training=True,
                                   batch_size=TRAIN_BATCH_SIZE)
  test_ds, _ = load_dataset("test", is_training=False,
                            batch_size=TEST_BATCH_SIZE)
  input_shape = (1,) + ds_info.features["image"].shape
  num_classes = ds_info.features["label"].num_classes
  iter_per_epoch_train = ds_info.splits['train'].num_examples // TRAIN_BATCH_SIZE
  iter_per_epoch_test = ds_info.splits['test'].num_examples // TEST_BATCH_SIZE

  # Set up model.
  if MODEL == "resnet1":
    net = ResNet1(num_classes=num_classes)
  elif MODEL == "resnet18":
    net = ResNet18(num_classes=num_classes)
  elif MODEL == "resnet34":
    net = ResNet34(num_classes=num_classes)
  else:
    raise ValueError("Unknown model.")

  def predict(params, inputs, aux, train=False):
    x = inputs.astype(jnp.float32) / 255.
    all_params = {"params": params, "batch_stats": aux}
    if train:
      # Returns logits and net_state (which contains the key "batch_stats").
      return net.apply(all_params, x, train=True, mutable=["batch_stats"])
    else:
      # Returns logits only.
      return net.apply(all_params, x, train=False)

  logistic_loss = jax.vmap(loss.multiclass_logistic_loss)

  def loss_from_logits(params, l2reg, logits, labels):
    mean_loss = jnp.mean(logistic_loss(labels, logits))
    sqnorm = tree_util.tree_l2_norm(params, squared=True)
    return mean_loss + 0.5 * l2reg * sqnorm

  def accuracy_and_loss(params, l2reg, data, aux):
    inputs, labels = data
    logits = predict(params, inputs, aux)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    loss = loss_from_logits(params, l2reg, logits, labels)
    return accuracy, loss

  def loss_fun(params, l2reg, data, aux):
    inputs, labels = data
    logits, net_state = predict(params, inputs, aux, train=True)
    loss = loss_from_logits(params, l2reg, logits, labels)
    # batch_stats will be stored in state.aux
    return loss, net_state["batch_stats"]

  # The default optimizer used by jaxopt is commented out here:
  # opt = optax.sgd(learning_rate=LEARNING_RATE,
  #                 momentum=MOMENTUM,
  #                 nesterov=True)
  NUM_STEPS = EPOCHS * iter_per_epoch_train
  opt = prefab.optax_lopt(NUM_STEPS)

  # We need has_aux=True because loss_fun returns batch_stats.
  solver = OptaxSolver(opt=opt,
                       fun=jax.value_and_grad(loss_fun, has_aux=True),
                       maxiter=EPOCHS * iter_per_epoch_train,
                       has_aux=True,
                       value_and_grad=True)

  # Initialize parameters. init_vars holds two collections: "params" (the
  # trainable weights) and "batch_stats" (the BatchNorm running statistics).
  rng = jax.random.PRNGKey(0)
  init_vars = net.init(rng, jnp.zeros(input_shape), train=True)
  params = init_vars["params"]
  batch_stats = init_vars["batch_stats"]
  start = datetime.now().replace(microsecond=0)

  # Run training loop.
  state = solver.init_state(params, L2REG, next(test_ds), batch_stats)  # initialize the solver state
  # jax.jit compiles the update step with XLA so each step runs efficiently.
  jitted_update = jax.jit(functools.partial(lopt_update, self=solver))
  print(f'Iterations: {solver.maxiter}')
  for _ in range(solver.maxiter):  # maximum number of solver iterations
    train_minibatch = next(train_ds)

    if state.iter_num % iter_per_epoch_train == iter_per_epoch_train - 1:
      # Once per epoch evaluate the model on the train and test sets.
      test_acc, test_loss = 0., 0.
      # make a pass over the test set to compute test accuracy
      for _ in range(iter_per_epoch_test):
        tmp = accuracy_and_loss(params, L2REG, next(test_ds), batch_stats)
        test_acc += tmp[0] / iter_per_epoch_test
        test_loss += tmp[1] / iter_per_epoch_test

      train_acc, train_loss = 0., 0.
      # make a pass over the train set to compute train accuracy
      for _ in range(iter_per_epoch_train):
        tmp = accuracy_and_loss(params, L2REG, next(train_ds), batch_stats)
        train_acc += tmp[0] / iter_per_epoch_train
        train_loss += tmp[1] / iter_per_epoch_train

      train_acc = jax.device_get(train_acc)
      train_loss = jax.device_get(train_loss)
      test_acc = jax.device_get(test_acc)
      test_loss = jax.device_get(test_loss)
      # time elapsed without microseconds
      time_elapsed = (datetime.now().replace(microsecond=0) - start)

      print(f"[Epoch {(state.iter_num+1) // (iter_per_epoch_train+1)}/{EPOCHS}] "
            f"Train acc: {train_acc:.3f}, train loss: {train_loss:.3f}. "
            f"Test acc: {test_acc:.3f}, test loss: {test_loss:.3f}. "
            f"Time elapsed: {time_elapsed}")

    params, state = jitted_update(params=params,
                                  state=state,
                                  l2reg=L2REG,
                                  data=train_minibatch,
                                  aux=batch_stats)
    batch_stats = state.aux


train()

Now let's compare against the baseline values.
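The baseline curves appear to come from the stock jaxopt setup that is commented out in the training cell: SGD with Nesterov momentum using the LEARNING_RATE and MOMENTUM constants. This is an assumption based on those commented-out lines rather than something the notebook states; a sketch of that configuration inside train(), replacing the VeLO-specific lines, would be:

# Assumed baseline configuration (sketch only). Plain SGD does not need the
# loss value passed through extra_args, so the stock OptaxSolver.update works.
opt = optax.sgd(learning_rate=LEARNING_RATE,
                momentum=MOMENTUM,
                nesterov=True)
solver = OptaxSolver(opt=opt,
                     fun=jax.value_and_grad(loss_fun, has_aux=True),
                     maxiter=EPOCHS * iter_per_epoch_train,
                     has_aux=True,
                     value_and_grad=True)
jitted_update = jax.jit(solver.update)
# The training loop itself is unchanged:
# params, state = jitted_update(params=params, state=state,
#                               l2reg=L2REG, data=train_minibatch, aux=batch_stats)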

Static baseline data (hard-coded below):

#@title Comparing a baseline optimizer vs. VeLO on resnet18 cifar100.
baseline_train_acc = [0.235, 0.333, 0.428, 0.430, 0.480,
                      0.528, 0.591, 0.617, 0.661, 0.709]
baseline_test_acc = [0.216, 0.298, 0.362, 0.343, 0.359,
                     0.371, 0.375, 0.377, 0.379, 0.399]
velo_train_acc = [0.170, 0.270, 0.346, 0.331, 0.466,
                  0.477, 0.551, 0.749, 0.848, 0.955]
velo_test_acc = [0.163, 0.255, 0.310, 0.290, 0.377,
                 0.369, 0.385, 0.458, 0.464, 0.492]

import matplotlib.pyplot as plt  # needed for plt.plot below; implicit in the original notebook
from matplotlib.pyplot import figure

figure(figsize=(8, 6), dpi=80)
plt.plot(range(10), baseline_train_acc, label="Baseline Train Accuracy", c='b', linestyle='dashed')
plt.plot(range(10), baseline_test_acc, label="Baseline Test Accuracy", c='b')
plt.plot(range(10), velo_train_acc, label="VeLO Train Accuracy", c='r', linestyle='dashed')
plt.plot(range(10), velo_test_acc, label="VeLO Test Accuracy", c='r')
plt.xlabel("Training Epochs")
plt.ylabel("Accuracy")
plt.title("Training Accuracy Curves for Resnet18 on Cifar100")
plt.legend()
plt.show()

#@title Comparing a baseline optimizer vs. VeLO on resnet18 cifar100.
baseline_train_loss = [3.470, 2.979, 2.535, 2.567, 2.351,
                       2.183, 1.970, 1.925, 1.781, 1.644]
baseline_test_loss = [3.571, 3.206, 2.899, 3.064, 3.055,
                      3.107, 3.170, 3.447, 3.530, 3.597]
velo_train_loss = [3.701, 3.071, 2.771, 2.948, 2.294,
                   2.287, 2.059, 1.268, 0.948, 0.645]
velo_test_loss = [3.739, 3.188, 2.974, 3.266, 2.797,
                  2.925, 3.062, 2.769, 2.950, 2.882]

figure(figsize=(8, 6), dpi=80)
plt.plot(range(10), baseline_train_loss, label="Baseline Train Loss", c='b', linestyle='dashed')
plt.plot(range(10), baseline_test_loss, label="Baseline Test Loss", c='b')
plt.plot(range(10), velo_train_loss, label="VeLO Train Loss", c='r', linestyle='dashed')
plt.plot(range(10), velo_test_loss, label="VeLO Test Loss", c='r')
plt.xlabel("Training Epochs")
plt.ylabel("Loss")
plt.title("Training Loss Curves for Resnet18 on Cifar100")
plt.legend()
plt.show()
