An Excellent Open-Source Project for Vision Neural Network Models: the PyTorch Image Models (timm) Library

PyTorch Image Models, abbreviated timm, is a large collection of PyTorch code that includes a range of:

  • image models
  • layers
  • utilities
  • optimizers
  • schedulers
  • data-loaders / augmentations
  • training / validation scripts

It aims to pull together a wide variety of SOTA models in one place, with the ability to reproduce ImageNet training results.

PyTorch Image Models (timm) is an excellent Python library for image classification, containing a large number of image models, optimizers, schedulers, augmentations, and more.

Besides torchvision.models, another common library of pretrained models is timm, created by Ross Wightman of Vancouver, Canada. It provides many SOTA computer-vision models and can be regarded as an extended version of torchvision, and its models also tend to reach higher accuracy. This chapter focuses on using the library's pretrained models; if you are interested in the other parts (data augmentation, optimizers, etc.), see the links below.

  • GitHub: https://github.com/rwightman/pytorch-image-models
  • Brief documentation: https://rwightman.github.io/pytorch-image-models/
  • Detailed documentation: https://fastai.github.io/timmdocs/

Installation

timm provides reference training and validation scripts for reproducing ImageNet training results, as well as official documentation and the timmdocs project:

https://rwightman.github.io/pytorch-image-models/

https://fastai.github.io/timmdocs/

However, timm offers so much functionality that it can be hard to know where to start when customizing it. This article provides an overview.

pip install timm==0.5.4

All development and testing was done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6, 3.7, 3.8, and 3.9.
PyTorch versions 1.4, 1.5.x, 1.6, 1.7.x, and 1.8 have been tested with this code.

import timm

Loading a pretrained model

Getting a model takes nothing more than a call to create_model; to load pretrained weights, just add the argument pretrained=True:

import timm

m = timm.create_model('mobilenetv3_large_100', pretrained=True)
m.eval()
MobileNetV3(
  (conv_stem): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): Hardswish()
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (se): Identity()
        (conv_pw): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Identity()
      )
    )
    ... (blocks 1-6: InvertedResidual and ConvBnAct stages, elided for brevity) ...
  )
  (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Identity())
  (conv_head): Conv2d(960, 1280, kernel_size=(1, 1), stride=(1, 1))
  (act2): Hardswish()
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (classifier): Linear(in_features=1280, out_features=1000, bias=True)
)

Listing models with pretrained weights

A quick look shows there are roughly 400+ pretrained models, all freely usable:

import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)
['adv_inception_v3',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 ...
 'xception',
 'xception41',
 'xception65',
 'xception71']

Selecting model architectures with wildcards

Wildcards make it quick to find the model you need and pass its exact name to create_model:

model_names = timm.list_models('*resne*t*')
pprint(model_names)
['bat_resnext26ts',
 'cspresnet50',
 'cspresnet50d',
 'cspresnet50w',
 'cspresnext50',
 'cspresnext50_iabn',
 ...
 'wide_resnet50_2',
 'wide_resnet101_2']

https://rwightman.github.io/pytorch-image-models/models/ describes the network models implemented in timm, with links to their papers and reference code.
https://paperswithcode.com/lib/timm lists them as well.

Models and papers

  • CNN models:

Classic models such as NFNet, RegNet, TResNet, Lambda Networks, GhostNet, and ByoaNet were added, along with ImageNet-21k-trained weights for TResNet, MobileNet-V3, and ViT, and both ImageNet-1k- and ImageNet-21k-trained weights for EfficientNet-V2.

  • Transformer models:

Classic models such as TNT, Swin Transformer, PiT, Bottleneck Transformers, Halo Nets, CoaT, CaiT, LeViT, Visformer, ConViT, Twins, and BiT were added.

  • MLP models:

Classic models such as MLP-Mixer, ResMLP, and gMLP were added.

  • Optimizers:

The AdaBelief optimizer was added, among other updates.

This article is therefore a current walkthrough of the timm library's code, not limited to vision transformer models.

All PyTorch models and their corresponding arXiv links are listed below:

  • Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723
  • Big Transfer ResNetV2 (BiT) - https://arxiv.org/abs/1912.11370
  • Bottleneck Transformers - https://arxiv.org/abs/2101.11605
  • CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239
  • CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
  • ConViT (Soft Convolutional Inductive Biases Vision Transformers) - https://arxiv.org/abs/2103.10697
  • CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
  • DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877
  • DenseNet - https://arxiv.org/abs/1608.06993
  • DLA - https://arxiv.org/abs/1707.06484
  • DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629
  • EfficientNet (MBConvNet Family)
    • EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252
    • EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665
    • EfficientNet (B0-B7) - https://arxiv.org/abs/1905.11946
    • EfficientNet-EdgeTPU (S, M, L) - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html
    • EfficientNet V2 - https://arxiv.org/abs/2104.00298
    • FBNet-C - https://arxiv.org/abs/1812.03443
    • MixNet - https://arxiv.org/abs/1907.09595
    • MNASNet B1, A1 (Squeeze-Excite), and Small - https://arxiv.org/abs/1807.11626
    • MobileNet-V2 - https://arxiv.org/abs/1801.04381
    • Single-Path NAS - https://arxiv.org/abs/1904.02877
  • GhostNet - https://arxiv.org/abs/1911.11907
  • gMLP - https://arxiv.org/abs/2105.08050
  • GPU-Efficient Networks - https://arxiv.org/abs/2006.14090
  • Halo Nets - https://arxiv.org/abs/2103.12731
  • HardCoRe-NAS - https://arxiv.org/abs/2102.11646
  • HRNet - https://arxiv.org/abs/1908.07919
  • Inception-V3 - https://arxiv.org/abs/1512.00567
  • Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
  • Lambda Networks - https://arxiv.org/abs/2102.08602
  • LeViT (Vision Transformer in ConvNet’s Clothing) - https://arxiv.org/abs/2104.01136
  • MLP-Mixer - https://arxiv.org/abs/2105.01601
  • MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
  • NASNet-A - https://arxiv.org/abs/1707.07012
  • NFNet-F - https://arxiv.org/abs/2102.06171
  • NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
  • PNasNet - https://arxiv.org/abs/1712.00559
  • Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
  • RegNet - https://arxiv.org/abs/2003.13678
  • RepVGG - https://arxiv.org/abs/2101.03697
  • ResMLP - https://arxiv.org/abs/2105.03404
  • ResNet/ResNeXt
    • ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385
    • ResNeXt - https://arxiv.org/abs/1611.05431
    • ‘Bag of Tricks’ / Gluon C, D, E, S variations - https://arxiv.org/abs/1812.01187
    • Weakly-supervised (WSL) Instagram pretrained / ImageNet tuned ResNeXt101 - https://arxiv.org/abs/1805.00932
    • Semi-supervised (SSL) / Semi-weakly Supervised (SWSL) ResNet/ResNeXts - https://arxiv.org/abs/1905.00546
    • ECA-Net (ECAResNet) - https://arxiv.org/abs/1910.03151v4
    • Squeeze-and-Excitation Networks (SEResNet) - https://arxiv.org/abs/1709.01507
    • ResNet-RS - https://arxiv.org/abs/2103.07579
  • Res2Net - https://arxiv.org/abs/1904.01169
  • ResNeSt - https://arxiv.org/abs/2004.08955
  • ReXNet - https://arxiv.org/abs/2007.00992
  • SelecSLS - https://arxiv.org/abs/1907.00837
  • Selective Kernel Networks - https://arxiv.org/abs/1903.06586
  • Swin Transformer - https://arxiv.org/abs/2103.14030
  • Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
  • TResNet - https://arxiv.org/abs/2003.13630
  • Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf
  • Vision Transformer - https://arxiv.org/abs/2010.11929
  • VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
  • Xception - https://arxiv.org/abs/1610.02357
  • Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611
  • Xception (Modified Aligned, TF) - https://arxiv.org/abs/1802.02611
  • XCiT (Cross-Covariance Image Transformers) - https://arxiv.org/abs/2106.09681

1. Models

timm provides a large collection of model architectures, many with pretrained weights, either trained in PyTorch or ported from Jax and TensorFlow, and convenient to download and use.

Model list: https://paperswithcode.com/lib/timm

Viewing the model list:

# Print the list of models timm provides
print(timm.list_models())
print(len(timm.list_models()))  # 739

# Models with pretrained weights
print(timm.list_models(pretrained=True))
print(len(timm.list_models(pretrained=True)))  # 592

The signature of the timm.list_models() function is:

list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False)

To view a specific family of models, e.g.:

print(timm.list_models('gluon_resnet*'))
print(timm.list_models('*resnext*', 'resnet'))
print(timm.list_models('resnet*', pretrained=True))
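
The signature above also accepts exclude_filters; for instance, to list pretrained ResNets while skipping ResNeXt variants (a quick illustration; exact results vary with the timm version):

# exclude_filters is part of the list_models signature shown above
model_names = timm.list_models('resnet*', pretrained=True, exclude_filters=['*resnext*'])
pprint(model_names)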

1.1. Basic usage of create_model

The simplest way to create a model with timm is create_model.

Take the ResNet-D model as an example (from the paper Bag of Tricks for Image Classification with Convolutional Neural Networks); it is a ResNet variant that uses average pooling for downsampling.

model = timm.create_model('resnet50d', pretrained=True)
print(model)

# View the model's default configuration
print(model.default_cfg)
'''
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
 'num_classes': 1000,
 'input_size': (3, 224, 224),
 'pool_size': (7, 7),
 'crop_pct': 0.875,
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'first_conv': 'conv1.0',
 'classifier': 'fc',
 'architecture': 'resnet50d'}
'''

1.2. Changing input channels with create_model

A very useful property of timm models is that they can handle input images with any number of channels, something most other libraries cannot do. How this is implemented is explained at:

https://fastai.github.io/timmdocs/models#So-how-is-timm-able-to-load-these-weights?
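
In short, timm adapts the pretrained weights of the first convolution to the requested number of channels, e.g. by summing the RGB kernels when a single-channel input is requested. A minimal sketch of the idea (an illustration only, not timm's actual code):

import torch

# Pretend this is a pretrained stem convolution: 64 filters over 3 input channels
conv_weight = torch.randn(64, 3, 7, 7)

# For in_chans=1, sum the kernels across the RGB dimension so the filters
# keep roughly the same response scale on grayscale input
single_channel_weight = conv_weight.sum(dim=1, keepdim=True)
print(single_channel_weight.shape)  # torch.Size([64, 1, 7, 7])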

Usage is unchanged apart from the in_chans argument:

model = timm.create_model('resnet50d', pretrained=True, in_chans=1)
print(model)

# Test with a single-channel image
x = torch.randn(1, 1, 224, 224)
out = model(x)
print(out.shape)  # torch.Size([1, 1000])

1.3. Customizing models with create_model

timm's create_model function offers many parameters for customizing a model. Its definition is:

create_model(model_name, pretrained=False, checkpoint_path='', scriptable=None, exportable=None, no_jit=None, **kwargs)

Example **kwargs parameters include:

  • global_pool - the type of global pooling to use before the final classification layer; whether it applies depends on whether the architecture uses a global pooling layer.
  • drop_rate - the dropout rate used during training; defaults to 0.
  • num_classes - the number of output classes (a quick example combining these follows this list).
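
For example, a one-liner combining all three (a sketch reusing the resnet50d model from this article):

# 10-class head, 50% dropout before the classifier, max global pooling
model = timm.create_model('resnet50d', pretrained=True,
                          num_classes=10, drop_rate=0.5, global_pool='max')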

1.3.1. Changing the number of classes

Inspect the current model's output layer:

# If the output layer is named fc:
print(model.fc)
# Linear(in_features=2048, out_features=1000, bias=True)

# Generic way to view the output layer:
print(model.get_classifier())

Change the number of output classes:

model = timm.create_model('resnet50d', pretrained=True, num_classes=10)
print(model)
print(model.get_classifier())
#Linear(in_features=2048, out_features=10, bias=True)

If no final layer is needed at all, num_classes can be set to 0; the model then uses an identity function as its final layer, which is useful for inspecting the output of the penultimate layer.

model = timm.create_model('resnet50d', pretrained=True, num_classes=0)
print(model)
print(model.get_classifier())
#Identity()
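
With the identity classifier in place, the forward pass returns the pooled penultimate features directly; for resnet50d these should be 2048-dimensional:

import torch

features = model(torch.randn(1, 3, 224, 224))
print(features.shape)  # torch.Size([1, 2048])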

1.3.2. Global pooling

The pool_size entry in model.default_cfg indicates that a global pooling layer is used before the classifier:

print(model.global_pool)
#SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))

The supported pool_type values are:

  • avg - average pooling
  • max - max pooling
  • avgmax - the sum of average and max pooling, each weighted by 0.5
  • catavgmax - concatenation of the average- and max-pooling outputs along the feature dimension; this doubles the feature dimension
  • '' - no pooling; the pooling layer is replaced with an identity operation (Identity)

pool_types = ['avg', 'max', 'avgmax', 'catavgmax', '']
x = torch.randn(1, 3, 224, 224)
for pool_type in pool_types:
    model = timm.create_model('resnet50d', pretrained=True, num_classes=0, global_pool=pool_type)
    model.eval()
    out = model(x)
    print(out.shape)

1.3.3. Modifying an existing model

For example:

model = timm.create_model('resnet50d', pretrained=True)
print(f'[INFO] Original Pooling: {model.global_pool}')
print(f'[INFO] Original Classifier: {model.get_classifier()}')

model.reset_classifier(10, 'max')  # modifies the model in place
print(f'[INFO] Modified Pooling: {model.global_pool}')
print(f'[INFO] Modified Classifier: {model.get_classifier()}')

1.3.4. Creating a new classification head

Although a single linear layer is usually enough for good results, a larger classification head is sometimes needed to squeeze out more performance.

import torch
import torch.nn as nn

model = timm.create_model('resnet50d', pretrained=True, num_classes=10, global_pool='catavgmax')
print(model)

num_in_features = model.get_classifier().in_features
print(num_in_features)

model.fc = nn.Sequential(
    nn.BatchNorm1d(num_in_features),
    nn.Linear(in_features=num_in_features, out_features=512, bias=False),
    nn.ReLU(),
    nn.BatchNorm1d(512),
    nn.Dropout(0.4),
    nn.Linear(in_features=512, out_features=10, bias=False))
model.eval()

x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape)

1.4. Feature extraction

timm provides mechanisms for accessing intermediate layers of many different network types, which is helpful for feature extraction in downstream tasks.

1.4.1. Final feature map

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch

image = Image.open('test.jpg')
image = torch.as_tensor(np.array(image, dtype=np.float32)).transpose(2, 0)[None]

model = timm.create_model("resnet50d", pretrained=True)
print(model.default_cfg)

# View only the final feature map, i.e. the output of the last conv layer before pooling
feature_output = model.forward_features(image)

def vis_feature_output(feature_output):
    plt.imshow(feature_output[0].transpose(0, 2).sum(-1).detach().numpy())
    plt.show()

vis_feature_output(feature_output)

1.4.2. Multiple feature outputs

model = timm.create_model("resnet50d", pretrained=True, features_only=True)print(model.feature_info.module_name())
#['act1', 'layer1', 'layer2', 'layer3', 'layer4']print(model.feature_info.reduction())
#[2, 4, 8, 16, 32]print(model.feature_info.channels())
#[64, 256, 512, 1024, 2048]out = model(image)
print(len(out)) # 5
for o in out:print(o.shape)plt.imshow(o[0].transpose(0, 2).sum(-1).detach().numpy())plt.show()

1.4.3. Using Torch FX

TorchVision added an FX module that makes it easier to access intermediate transformations of the input during the forward pass. By symbolically tracing the forward method, it produces a graph in which each node represents an operation. Since nodes are given human-readable names, it is easy to pinpoint the exact node you want.

https://pytorch.org/docs/stable/fx.html#module-torch.fx

https://pytorch.org/blog/FX-feature-extraction-torchvision/

# torchvision >= 0.11.0
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

model = timm.create_model("resnet50d", pretrained=True, exportable=True)
nodes, _ = get_graph_node_names(model)
print(nodes)

features = {'layer1.0.act2': 'out'}
feature_extractor = create_feature_extractor(model, return_nodes=features)
print(feature_extractor)

out = feature_extractor(image)
plt.imshow(out['out'][0].transpose(0, 2).sum(-1).detach().numpy())
plt.show()

1.5. Exporting models to different formats

After training, it is generally recommended to export the model to an optimized format for inference.

1.5.1. Exporting to TorchScript

https://pytorch.org/docs/stable/jit.html

https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

model = timm.create_model("resnet50d", pretrained=True, scriptable=True)
model.eval() #重要scripted_model = torch.jit.script(model)
print(scripted_model)
print(scripted_model(torch.rand(8, 3, 224, 224)).shape)

1.5.2. Exporting to ONNX

Open Neural Network eXchange (ONNX)

https://pytorch.org/docs/master/onnx.html

model = timm.create_model("resnet50d", pretrained=True, exportable=True)
model.eval() #重要x = torch.randn(2, 3, 224, 224, requires_grad=True)
torch_out = model(x)#Export the model
torch.onnx.export(model,                   #模型x,                        #输入'resnet50d.onnx',         #模型导出路径export_params=True,      #模型文件存储训练参数权重opset_version=10,        #ONNX 版本do_constant_folding=True,#是否执行不断折叠优化input_names=['input'],   #输入名output_names=['output'], #输出名dynamic_axes={'input': {0: 'batch_size'},'output': {0: 'batch_size'}})#验证导出模型
import onnxonnx_model = onnx.load('resnet50d.onnx')
onnx.checker.check_model(onnx_model)traced_model = torch.jit.trace(model, torch.rand(8, 3, 224, 224))
type(traced_model)print(traced_model(torch.rand(8, 3, 224, 224)).shape)

2. Augmentations

timm's data pipeline is similar to TorchVision's, taking PIL images as input.

from timm.data.transforms_factory import create_transform

print(create_transform(224))
'''
Compose(
    Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
'''

print(create_transform(224, is_training=True))
'''
Compose(
    RandomResizedCropAndInterpolation(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear)
    RandomHorizontalFlip(p=0.5)
    ColorJitter(brightness=[0.6, 1.4], contrast=[0.6, 1.4], saturation=[0.6, 1.4], hue=None)
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
'''

2.1. RandAugment

For a new task, it is hard to know which augmentations to use, and given how many augmentation strategies exist, the number of possible combinations is enormous.

A good starting point is an augmentation pipeline that has been shown to work on other tasks, such as RandAugment.

RandAugment is an automated data augmentation method that samples uniformly from a set of augmentations, e.g. equalization, rotation, solarization, color jittering, posterizing, changing contrast, changing brightness, changing sharpness, shearing, and translations, and applies a number of them in sequence.

RandAugment: Practical automated data augmentation with a reduced search space

RandAugment parameters:

  • N - the number of distortions uniformly sampled and applied per image
  • M - the distortion magnitude

In timm, RandAugment is specified with a config string whose parts are separated by -:

  • m - the magnitude of the random augmentations
  • n - the number of random transformations applied per image; defaults to 2
  • mstd - the standard deviation of the noise applied to the magnitude
  • mmax - the upper bound on the magnitude; defaults to 10
  • w - the index of a set of weights to influence choice of operation
  • inc - use augmentations whose severity increases with magnitude; defaults to 0

For example:

  • rand-m9-n3-mstd0.5 - magnitude 9, 3 augmentations per image, mstd 0.5
  • rand-mstd1-w0 - mstd 1.0, weight index 0, the default magnitude m of 10, 2 augmentations per image

print(create_transform(224, is_training=True, auto_augment='rand-m9-mstd0.5'))
'''
Compose(
    RandomResizedCropAndInterpolation(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear)
    RandomHorizontalFlip(p=0.5)
    RandAugment(n=2, ops=
        AugmentOp(name=AutoContrast, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=Equalize, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=Invert, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=Rotate, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=Posterize, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=Solarize, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=SolarizeAdd, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=Color, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=Contrast, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=Brightness, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=Sharpness, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=ShearX, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=ShearY, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=TranslateXRel, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=TranslateYRel, p=0.5, m=9, mstd=0.5))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
'''

The same transform can also be created with the rand_augment_transform function:

from timm.data.auto_augment import rand_augment_transform

tfm = rand_augment_transform(config_str='rand-m9-mstd0.5',
                             hparams={'img_mean': (124, 116, 104)})
print(tfm)
'''
RandAugment(n=2, ops=
    AugmentOp(name=AutoContrast, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=Equalize, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=Invert, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=Rotate, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=Posterize, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=Solarize, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=SolarizeAdd, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=Color, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=Contrast, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=Brightness, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=Sharpness, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=ShearX, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=ShearY, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=TranslateXRel, p=0.5, m=9, mstd=0.5)
    AugmentOp(name=TranslateYRel, p=0.5, m=9, mstd=0.5))
'''

2.2. CutMix and Mixup

timm supports both CutMix and Mixup through its Mixup class.

The Mixup class supports several mixing strategies:

  • batch - CutMix vs Mixup selection, lambda, and CutMix region sampling are performed per batch
  • pair - mixing, lambda, and region sampling are performed on sampled pairs within a batch
  • elem - mixing, lambda, and region sampling are performed per image within batch
  • half - the same as elementwise but one of each mixing pair is discarded so that each sample is seen once per epoch

Its arguments are:

  • mixup_alpha (float): mixup alpha value, mixup is active if > 0., (default: 1)
  • cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. (default: 0)
  • cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
  • prob (float): the probability of applying mixup or cutmix per batch or element (default: 1)
  • switch_prob (float): the probability of switching to cutmix instead of mixup when both are active (default: 0.5)
  • mode (str): how to apply mixup/cutmix params (default: batch)
  • label_smoothing (float): the amount of label smoothing to apply to the mixed target tensor (default: 0.1)
  • num_classes (int): the number of classes for the target variable

import torch
import torchvision
from timm.data import ImageDataset
from timm.data.mixup import Mixup
from torch.utils.data import DataLoader

def create_dataloader_iterator():
    dataset = ImageDataset('pets/images', transform=create_transform(224))
    dl = iter(DataLoader(dataset, batch_size=2))
    return dl

dataloader = create_dataloader_iterator()
inputs, classes = next(dataloader)

out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])  # imshow: a small plotting helper, not defined here

mixup_args = {
    'mixup_alpha': 1.,
    'cutmix_alpha': 1.,
    'prob': 1,
    'switch_prob': 0.5,
    'mode': 'batch',
    'label_smoothing': 0.1,
    'num_classes': 2}
mixup_fn = Mixup(**mixup_args)

mixed_inputs, mixed_classes = mixup_fn(inputs.to(torch.device('cuda:0')),
                                       classes.to(torch.device('cuda:0')))
out = torchvision.utils.make_grid(mixed_inputs)
imshow(out, title=mixed_classes)

3. Datasets

timm's create_dataset function expects two arguments:

  • name - the name of the dataset to load
  • root - the root directory where the dataset is stored

It supports several data sources:

  • TorchVision
  • TensorFlow datasets
  • Local folders

from timm.data import create_dataset

# TorchVision
ds = create_dataset('torch/cifar10', 'cifar10', download=True, split='train')
print(ds, type(ds))
print(ds[0])

# TensorFlow datasets
ds = create_dataset('tfds/beans', 'beans', download=True, split='train[:10%]', batch_size=2, is_training=True)
print(ds)
ds_iter = iter(ds)
image, label = next(ds_iter)

# Local folder
ds = create_dataset(name='', root='imagenette/imagenette2-320.tar', transform=create_transform(224))
image, label = ds[0]
print(image.shape)

3.1. The ImageDataset class

Besides create_dataset, timm also provides two dataset classes, ImageDataset and IterableImageDataset, to cover more scenarios.

from timm.data import ImageDataset

imagenette_ds = ImageDataset('imagenette/imagenette2-320/train')
print(len(imagenette_ds))
print(imagenette_ds.parser)
print(imagenette_ds.parser.class_to_idx)

from timm.data.parsers.parser_image_in_tar import ParserImageInTar

data_path = 'imagenette'
ds = ImageDataset(data_path, parser=ParserImageInTar(data_path))

3.1.1. Writing a custom Parser

ParserImageFolder is a good reference:

""" A dataset parser that reads images from foldersFolders are scannerd recursively to find image files. Labels are based
on the folder hierarchy, just leaf folders by default.Hacked together by / Copyright 2020 Ross Wightman
"""
import osfrom timm.utils.misc import natural_keyfrom .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONSdef find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):labels = []filenames = []for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):rel_path = os.path.relpath(root, folder) if (root != folder) else ''label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')for f in files:base, ext = os.path.splitext(f)if ext.lower() in types:filenames.append(os.path.join(root, f))labels.append(label)if class_to_idx is None:# building class indexunique_labels = set(labels)sorted_labels = list(sorted(unique_labels, key=natural_key))class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]if sort:images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))return images_and_targets, class_to_idxclass ParserImageFolder(Parser):def __init__(self,root,class_map=''):super().__init__()self.root = rootclass_to_idx = Noneif class_map:class_to_idx = load_class_map(class_map, root)self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)if len(self.samples) == 0:raise RuntimeError(f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')def __getitem__(self, index):path, target = self.samples[index]return open(path, 'rb'), targetdef __len__(self):return len(self.samples)def _filename(self, index, basename=False, absolute=False):filename = self.samples[index][0]if basename:filename = os.path.basename(filename)elif not absolute:filename = os.path.relpath(filename, self.root)return filename

For example:

from pathlib import Path

from timm.data.parsers.parser import Parser


class ParserImageName(Parser):

    def __init__(self, root, class_to_idx=None):
        super().__init__()
        self.root = Path(root)
        self.samples = list(self.root.glob("*.jpg"))
        if class_to_idx:
            self.class_to_idx = class_to_idx
        else:
            classes = sorted(
                set([self.__extract_label_from_path(p) for p in self.samples]),
                key=lambda s: s.lower(),
            )
            self.class_to_idx = {c: idx for idx, c in enumerate(classes)}

    def __extract_label_from_path(self, path):
        return "_".join(path.parts[-1].split("_")[0:-1])

    def __getitem__(self, index):
        path = self.samples[index]
        target = self.class_to_idx[self.__extract_label_from_path(path)]
        return open(path, "rb"), target

    def __len__(self):
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        filename = self.samples[index]
        if basename:
            filename = filename.parts[-1]
        elif not absolute:
            filename = filename.absolute()
        return filename


data_path = 'test'
ds = ImageDataset(data_path, parser=ParserImageName(data_path))
print(ds[0])
print(ds.parser.class_to_idx)

4. Optimizers

The optimizers timm supports include:

  • SGD
  • Adam
  • AdamW
  • AdamP
  • RMSPropTF
  • LAMB - a pure PyTorch version of the FusedLAMB optimizer from Apex
  • AdaBelief
  • MADGRAD
  • AdaHessian

import inspect
import timm.optim

optims_list = [cls_name for cls_name, cls_obj in inspect.getmembers(timm.optim)
               if inspect.isclass(cls_obj) and cls_name != 'Lookahead']
print(optims_list)

timm's create_optimizer_v2 function:

import torch

model = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Flatten(0, 1))

optimizer = timm.optim.create_optimizer_v2(model, opt='sgd', lr=0.01, momentum=0.8)
print(optimizer, type(optimizer))
'''
SGD (
Parameter Group 0
    dampening: 0
    lr: 0.01
    momentum: 0.8
    nesterov: True
    weight_decay: 0.0
)
<class 'torch.optim.sgd.SGD'>
'''

optimizer = timm.optim.create_optimizer_v2(model, opt='lamb', lr=0.01, weight_decay=0.01)
print(optimizer, type(optimizer))
'''
Lamb (
Parameter Group 0
    always_adapt: False
    betas: (0.9, 0.999)
    bias_correction: True
    eps: 1e-06
    grad_averaging: True
    lr: 0.01
    max_grad_norm: 1.0
    trust_clip: False
    weight_decay: 0.0

Parameter Group 1
    always_adapt: False
    betas: (0.9, 0.999)
    bias_correction: True
    eps: 1e-06
    grad_averaging: True
    lr: 0.01
    max_grad_norm: 1.0
    trust_clip: False
    weight_decay: 0.01
)
<class 'timm.optim.lamb.Lamb'>
'''

An optimizer can also be created by hand, e.g.:

optimizer = timm.optim.RMSpropTF(model.parameters(), lr=0.01)

4.1. Usage example

# replace
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# with
optimizer = timm.optim.AdamP(model.parameters(), lr=0.01)

for epoch in range(num_epochs):
    for batch in training_dataloader:
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# Adahessian is a second-order optimizer, so backward needs create_graph=True
optimizer = timm.optim.Adahessian(model.parameters(), lr=0.01)
is_second_order = hasattr(optimizer, "is_second_order") and optimizer.is_second_order  # True

for epoch in range(num_epochs):
    for batch in training_dataloader:
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward(create_graph=is_second_order)
        optimizer.step()
        optimizer.zero_grad()

4.2. Lookahead

Lookahead Optimizer: k steps forward, 1 step back

optimizer = timm.optim.create_optimizer_v2(model.parameters(), opt='lookahead_adam', lr=0.01)
# or wrap an existing optimizer:
optimizer = timm.optim.Lookahead(optimizer, alpha=0.5, k=6)
optimizer.sync_lookahead()

For example:

optimizer = timm.optim.AdamP(model.parameters(), lr=0.01)
optimizer = timm.optim.Lookahead(optimizer)

for epoch in range(num_epochs):
    for batch in training_dataloader:
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    optimizer.sync_lookahead()

5. Schedulers

The schedulers timm supports include:

  • StepLRScheduler: decays the learning rate every n steps, similar to torch.optim.lr_scheduler.StepLR
  • MultiStepLRScheduler: decays the learning rate at specified milestones, similar to torch.optim.lr_scheduler.MultiStepLR
  • PlateauLRScheduler: reduces the learning rate by a specified factor each time a specified metric plateaus, similar to torch.optim.lr_scheduler.ReduceLROnPlateau
  • CosineLRScheduler: cosine decay schedule with restarts, similar to torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
  • TanhLRScheduler: hyperbolic-tangent decay schedule with restarts
  • PolyLRScheduler: polynomial decay schedule

5.1. Usage example

Unlike PyTorch schedulers, timm schedulers are updated twice per epoch:

  • .step_update - called after each optimizer update
  • .step - called at the end of each epoch

training_epochs = 300
cooldown_epochs = 10
num_epochs = training_epochs + cooldown_epochs

optimizer = timm.optim.AdamP(my_model.parameters(), lr=0.01)
scheduler = timm.scheduler.CosineLRScheduler(optimizer, t_initial=training_epochs)

for epoch in range(num_epochs):
    num_steps_per_epoch = len(train_dataloader)
    num_updates = epoch * num_steps_per_epoch
    for batch in train_dataloader:
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()
        num_updates += 1
        scheduler.step_update(num_updates=num_updates)
        optimizer.zero_grad()
    scheduler.step(epoch + 1)

5.2. CosineLRScheduler

To dig into the options timm provides, this section uses the scheduler from timm's default training script, CosineLRScheduler, as an example.

timm's cosine scheduler differs from the implementation in PyTorch.

5.2.1. PyTorch CosineAnnealingWarmRestarts

CosineAnnealingWarmRestarts takes the following parameters:

  • T_0 (int): Number of iterations for the first restart.
  • T_mult (int): A factor that increases T_{i} after a restart. (Default: 1)
  • eta_min (float): Minimum learning rate. (Default: 0.)
  • last_epoch (int): The index of the last epoch. (Default: -1)

# args
num_epochs = 300
num_epoch_repeat = num_epochs // 2
num_steps_per_epoch = 10

def create_model_and_optimizer():
    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    return model, optimizer

# create learning rate scheduler
model, optimizer = create_model_and_optimizer()
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=num_epoch_repeat * num_steps_per_epoch,
    T_mult=1,
    eta_min=1e-6,
    last_epoch=-1)

# visualize
import matplotlib.pyplot as plt

lrs = []
for epoch in range(num_epochs):
    for i in range(num_steps_per_epoch):
        scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
plt.plot(lrs)
plt.show()

As the plot shows, the learning rate decays until epoch 150, restarts to its initial value at epoch 150, and then begins decaying again.

5.2.2. timm CosineLRScheduler

timm's CosineLRScheduler takes the following parameters:

  • t_initial (int): Number of iterations for the first restart, this is equivalent to T_0 in torch’s implementation
  • lr_min (float): Minimum learning rate, this is equivalent to eta_min in torch’s implementation (Default: 0.)
  • cycle_mul (float): A factor that increases T_{i} after a restart, this is equivalent to T_mult in torch’s implementation (Default: 1)
  • cycle_limit (int): Limit the number of restarts in a cycle (Default: 1)
  • t_in_epochs (bool): Whether the number of iterations is given in terms of epochs rather than the number of batch updates (Default: True)

# args
num_epochs = 300
num_epoch_repeat = num_epochs // 2
num_steps_per_epoch = 10

def create_model_and_optimizer():
    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    return model, optimizer

# create learning rate scheduler
model, optimizer = create_model_and_optimizer()
scheduler = timm.scheduler.CosineLRScheduler(
    optimizer,
    t_initial=num_epoch_repeat * num_steps_per_epoch,
    lr_min=1e-6,
    cycle_limit=num_epoch_repeat + 1,
    t_in_epochs=False)
# or
scheduler = timm.scheduler.CosineLRScheduler(
    optimizer,
    t_initial=num_epoch_repeat,
    lr_min=1e-6,
    cycle_limit=num_epoch_repeat + 1,
    t_in_epochs=True)

# visualize
import matplotlib.pyplot as plt

lrs = []
for epoch in range(num_epochs):
    num_updates = epoch * num_steps_per_epoch
    for i in range(num_steps_per_epoch):
        num_updates += 1
        scheduler.step_update(num_updates=num_updates)
    scheduler.step(epoch + 1)
    lrs.append(optimizer.param_groups[0]['lr'])
plt.plot(lrs)
plt.show()

Example schedules:

scheduler = timm.scheduler.CosineLRScheduler(
    optimizer,
    t_initial=num_epoch_repeat * num_steps_per_epoch,
    cycle_mul=2.,
    cycle_limit=num_epoch_repeat + 1,
    t_in_epochs=False)

scheduler = timm.scheduler.CosineLRScheduler(
    optimizer,
    t_initial=num_epoch_repeat * num_steps_per_epoch,
    lr_min=1e-5,
    cycle_limit=1)

scheduler = timm.scheduler.CosineLRScheduler(
    optimizer,
    t_initial=50,
    lr_min=1e-5,
    cycle_decay=0.8,
    cycle_limit=num_epoch_repeat + 1)

scheduler = timm.scheduler.CosineLRScheduler(
    optimizer,
    t_initial=num_epoch_repeat * num_steps_per_epoch,
    lr_min=1e-5,
    k_decay=0.5,
    cycle_limit=num_epoch_repeat + 1)

scheduler = timm.scheduler.CosineLRScheduler(
    optimizer,
    t_initial=num_epoch_repeat * num_steps_per_epoch,
    lr_min=1e-5,
    k_decay=2,
    cycle_limit=num_epoch_repeat + 1)

5.2.3. Adding warm-up

For example, with 20 warm-up epochs:

# args
num_epochs = 300
num_epoch_repeat = num_epochs // 2
num_steps_per_epoch = 10

def create_model_and_optimizer():
    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    return model, optimizer

# create learning rate scheduler
model, optimizer = create_model_and_optimizer()
scheduler = timm.scheduler.CosineLRScheduler(
    optimizer,
    t_initial=num_epoch_repeat,
    lr_min=1e-5,
    cycle_limit=num_epoch_repeat + 1,
    warmup_lr_init=0.01,
    warmup_t=20)

# visualize
import matplotlib.pyplot as plt

lrs = []
for epoch in range(num_epochs):
    num_updates = epoch * num_steps_per_epoch
    for i in range(num_steps_per_epoch):
        num_updates += 1
        scheduler.step_update(num_updates=num_updates)
    scheduler.step(epoch + 1)
    lrs.append(optimizer.param_groups[0]['lr'])
plt.plot(lrs)
plt.show()

5.2.4. Adding noise

# args
num_epochs = 300
num_epoch_repeat = num_epochs // 2
num_steps_per_epoch = 10

def create_model_and_optimizer():
    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    return model, optimizer

# create learning rate scheduler
model, optimizer = create_model_and_optimizer()
scheduler = timm.scheduler.CosineLRScheduler(
    optimizer,
    t_initial=num_epoch_repeat,
    lr_min=1e-5,
    cycle_limit=num_epoch_repeat + 1,
    noise_range_t=(0, 150),  # the range over which noise is applied
    noise_pct=0.1)           # the amount of noise

# visualize
import matplotlib.pyplot as plt

lrs = []
for epoch in range(num_epochs):
    num_updates = epoch * num_steps_per_epoch
    for i in range(num_steps_per_epoch):
        num_updates += 1
        scheduler.step_update(num_updates=num_updates)
    scheduler.step(epoch + 1)
    lrs.append(optimizer.param_groups[0]['lr'])
plt.plot(lrs)
plt.show()

5.3. timm's default settings

def create_model_and_optimizer():
    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    return model, optimizer

model, optimizer = create_model_and_optimizer()

# args
training_epochs = 300
cooldown_epochs = 10
num_epochs = training_epochs + cooldown_epochs
num_steps_per_epoch = 10

scheduler = timm.scheduler.CosineLRScheduler(
    optimizer,
    t_initial=training_epochs,
    lr_min=1e-6,
    t_in_epochs=True,
    warmup_t=3,
    warmup_lr_init=1e-4,
    cycle_limit=1)  # no restart

# visualize
import matplotlib.pyplot as plt

lrs = []
for epoch in range(num_epochs):
    num_updates = epoch * num_steps_per_epoch
    for i in range(num_steps_per_epoch):
        num_updates += 1
        scheduler.step_update(num_updates=num_updates)
    scheduler.step(epoch + 1)
    lrs.append(optimizer.param_groups[0]['lr'])
plt.plot(lrs)
plt.show()

5.4. Other schedulers

# TanhLRScheduler
scheduler = timm.scheduler.TanhLRScheduler(
    optimizer,
    t_initial=num_epoch_repeat,
    lr_min=1e-6,
    cycle_limit=num_epoch_repeat + 1)

# PolyLRScheduler
scheduler = timm.scheduler.PolyLRScheduler(
    optimizer,
    t_initial=num_epoch_repeat,
    lr_min=1e-6,
    cycle_limit=num_epoch_repeat + 1)

scheduler = timm.scheduler.PolyLRScheduler(
    optimizer,
    t_initial=num_epoch_repeat,
    lr_min=1e-6,
    cycle_limit=num_epoch_repeat + 1,
    k_decay=0.5)

scheduler = timm.scheduler.PolyLRScheduler(
    optimizer,
    t_initial=num_epoch_repeat,
    lr_min=1e-6,
    cycle_limit=num_epoch_repeat + 1,
    k_decay=2)

6. EMA: Exponential Moving Average of Model Weights

EMA stands for Exponential Moving Average model.

A good practice during training is to set the final model weights to a moving average of the parameters across the whole training run, rather than only using the weights from the last incremental update.

In practice, this is done by maintaining an EMA model: a copy of the model being trained.

However, rather than updating all of its parameters after every step, the EMA parameters are set to a linear combination of their current values and the updated model's values, using the following formula:
updated_EMA_model_weights = decay * EMA_model_weights + (1. - decay) * updated_model_weights

For example, with decay = 0.99:

updated_EMA_model_weights = 0.99 * EMA_model_weights + 0.01 * updated_model_weights
An example with timm's ModelEmaV2:

model = create_model().to(gpu_device)
ema_model = timm.utils.ModelEmaV2(model, decay=0.9998)

for epoch in range(num_epochs):
    for batch in training_dataloader:
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema_model.update(model)

    for batch in validation_dataloader:
        inputs, targets = batch
        outputs = model(inputs)
        validation_loss = loss_function(outputs, targets)

        # the EMA copy is evaluated via its .module attribute
        ema_model_outputs = ema_model.module(inputs)
        ema_model_validation_loss = loss_function(ema_model_outputs, targets)

