目录

  • 1. idea
    • 1.1 实验思路
    • 1.2 灵感来源
  • 2. 实验设置
  • 3. 实验结果
    • 3.1 结果
    • 3.2 结果分析
      • 3.2.1 一个奇怪的现象
      • 3.2.2 分析
  • 4. 代码

写在前面:本实验并未获得预期的结果,更多的是当作实验记录。

1. idea

1.1 实验思路

这个实验的思路是这样的:通过随机初始化(正态分布)的未经过训练的ResNet、ViT和SwinTransformer,来对ImangeNet-1k(2012)的验证集(val,共50000张图片,1000类别)进行预测,对比预测结果和随机猜(准确率为1‰)的区别。

1.2 灵感来源

灵感来自Deep Clustering for Unsupervised Learning of Visual Features这篇论文,论文中提到这么一段话:

这段话说的是:“当模型(如CNN)中的参数 θ \theta θ使用高斯分布随机初始化时,在未经训练时,预测出的结果很差。但是,要比随即猜测(对于ImageNet-1k来讲随机猜对的概率为千分之1)要好不少。”

作者给出的解释是,这是因为我们模型,如CNN,引入了先验知识(即CNN关于平移等变性、局部性等归纳偏置),所以随机初始化后,虽然没有经过数据训练,但得出的结果也要比随机猜要好。

所以我就在想,用未经训练的随机初始化的ResNet、ViT和SwinTransformer来预测ImageNet-1k(val),准确率越高,从某种程度上来讲说明引入的归纳偏置越强。

2. 实验设置

实验环境 Pytorch1.12、2xRTX3090(24G)
模型 ResNet、ViT、SwinTransformer
测试数据集 ImageNet-1k(2012,val,50000张,1000类别)
实验模型 参数量
ResNet50 25.56M
ResNet101 44.55M
ViT_B_16 86.57M
ViT_L_16 304.33M
swin_t 28.29M
swin_s 49.61M

其中初始化基本都是采用各种类型的正态分布进行初始化,比如

nn.init.trunc_normal_
nn.init.normal_
nn.init.kaiming_normal_

具体用法参考Pytorch官方文档

3. 实验结果

3.1 结果

实验结果是:未经训练的、随机初始化的ResNet、ViT和SwinTransformer,最后预测的结果都和随机猜的差不多(千分之一)。

模型 预测正确数量 测试集数量 准确率 训练时间 batch_size 显卡使用情况
resnet50 50 50000 1‰ 35.485s 256
resnet101 50 50000 1‰ 45.556s 256
vit_b_16 50 50000 1‰ 83.587s 256
vit_l_16 50 50000 1‰ 370.091s 64
swin_t 43 50000 ~1‰ 44.616s 256
swin_s 49 50000 ~1‰ 67.754s 256

3.2 结果分析

3.2.1 一个奇怪的现象

有个比较有意思的现象,我们可以看到对于ResNet系列和ViT系列,都是预测对了50个样本。这是因为对于未经训练的随机初始化的ResNet和ViT,不论输入什么样本,预测出的标签是一样的。

注意:不是说输出都是一样的,而是说最后经过softmax输出的1000维特征,最大值对应的索引都相同。 所以才会有正好50个预测对的。(Imagnet-1K的val有50000张图片,1000个类别,每个类别都有50张图片)。

而对于SwinTransformer,最后预测的标签确实看上去是随机分布的,但也跟瞎猜的一样(都是1‰),并没有像Deep Cluster的作者说的那样:要比瞎猜的好上不少。

3.2.2 分析

我觉得Deep Cluster的作者说的应该是没错的(毕竟是大佬写的文章,还是顶会),问题应该出在我的代码实现的细节上,比如初始化的方式?我是使用TorchVision源码中自带的初始化,感觉也没什么问题啊。

这个现象我不打算再深入了,因为本身就只是想读一下Deep Cluster,看到作者说的这个现象觉得很有趣就像试一试,但没出现预期的效果。如果想深入找一下原因的话可以看看Deep Cluster的引用文献。

4. 代码

import os
import timeimport torch
import torch.nn as nn
from torchvision.models import resnet50, resnet101, vit_b_16, vit_l_16, swin_t, swin_s
from torchvision import datasets, transforms
from tqdm import tqdmdef get_dataloader(data_dir=None, batch_size=64):''':param data_dir: val dataset direction:return: imageNet'''assert data_dir is not Nonetransform = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.2, 1.)),transforms.ToTensor(),# 这里我在考虑是否要进行标准化,可以做个对比实验transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])val_dataset = datasets.ImageFolder(data_dir,transform,)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,num_workers=8,shuffle=True,)return val_loaderdef get_models():''':return: res50, res101, ViT, SwinTransformer'''# load model with random initializationres50 = resnet50()res101 = resnet101()vit_B_16 = vit_b_16()vit_L_16 = vit_l_16()swin_T = swin_t()  # 参数量和res50相近swin_B = swin_s()  # 参数量和res101相近model_list = [res50, res101, vit_B_16, vit_L_16, swin_T, swin_B]model_names = ['res50', 'res101', 'vit_B_16', 'vit_L_16', 'swin_T', 'swin_B']for name, model in zip(model_names, model_list):print(f'{name:10}parametersize is {compute_params(model): .2f}M')return model_list, model_namesdef compute_params(model):''':param model: nn.Module, model:return: float, model parameter size'''total = sum(p.numel() for p in model.parameters())size = total / 1e6# print("Total params: %.2fM" % (total / 1e6))return sizedef model_evaluate(model, data_loader):''':param model: test model:param data_loader: val_loader:return: list[float] acc list'''device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')device_ids = [1, 2]  # use 2 GPUsmodel = nn.DataParallel(model, device_ids=device_ids)model.to(device)model.eval()total = 0correct = 0loop = tqdm((data_loader), total=len(data_loader))for imgs, labels in loop:imgs.to(device)outputs = model(imgs)outputs = outputs.argmax(dim=1)labels = labels.to(device)# print(outputs.shape, '\n', outputs, outputs.argmax(dim=1))# print(labels.shape, '\n', labels)total += len(labels)res = outputs==labelscorrect += res.sum().item()loop.set_description(f'inference test:')loop.set_postfix(total=total, correct=correct, acc=f'{correct/total:.2f}')if __name__ == '__main__':seed = 2022torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)data_dir = os.path.join('..', '..', '..', 'data', 'ImageNet2012', 'imagenet', 'val')val_loader = get_dataloader(data_dir, batch_size=256)get_models()  # 输出模型的大小,本来想写一个循环自动训练所有的model,但测试会爆显存,所以就单独测试每个model了net = swin_s()  # 这里换成我们测试的模型,可以用resnet50, resnet101, vit_b_16, vit_l_16, swin_t, swin_st1 = time.time()model_evaluate(net, val_loader)print(f'total time: {time.time() - t1:.3f}s')# nets, net_names = get_models()# for net, name in zip(nets, net_names):#     if name == 'vit_L_16':#         val_loader = get_dataloader(data_dir, batch_size=16)#     val_loader = get_dataloader(data_dir, batch_size=128)##     t1 = time.time()#     model_evaluate(net, val_loader)#     print(f'{name:.10} total time: {time.time()-t1:.3f}s')

参考:
1. Deep Clustering for Unsupervised Learning of Visual Features

2. Pytorch官方文档

【小实验1】比较ResNet、ViT、SwinTransformer的归纳偏置(然而并没有达到预期结果)相关推荐

  1. CMS垃圾收集器小实验之CMSInitiatingOccupancyFraction参数

    点击上方"方志朋",选择"设为星标" 回复"666"获取新整理的面试文章 背景 测试CMSInitiatingOccupancyFracti ...

  2. 秒懂JVM的三大参数类型,就靠这十个小实验了

    来源 | 悟空聊架构(ID:PassJava666) 本实验的目的是讲解 JVM 的三大参数类型.在JVM调优中用到的最多的 XX 参数,而如何去查看和设置 JVM 的 XX 参数也是调优的基本功,本 ...

  3. 小实验:用创建进程()打开计算器,然后关闭进程句柄。再用打开进程(进程ID),使用两次,得到两个进程句柄。实验目的:这两个进程句柄都能控制这个进程吗?通过该试验加深对句柄的理解!!...

    小实验:用创建进程()打开计算器,然后关闭进程句柄.再用打开进程(进程ID),使用两次,得到两个进程句柄.实验目的:这两个进程句柄都能控制这个进程吗? .版本 2 .程序集 窗口程序集1 .子程序 _ ...

  4. 【 FPGA 】超声波测距小实验(一)

    超声波测距原理: 超声波测距原理是在超声波发射装置发出超声波,它的根据是接收器接到超声波时的时间差,与雷达测距原理相似. 超声波发射器向某一方向发射超声波,在发射时刻的同时开始计时,超声波在空气中传播 ...

  5. 【 FPGA 】按键消抖与LED灯流动小实验

    记录一个小实验吧,实验的目的是仅仅是塞塞牙缝而已,没其他意思,很简单. 功能:拨码开关控制led灯工作与否,拨码开关为on,led灯工作,否则不工作:导航按键up和down,也就是独立按键而已,控制l ...

  6. [na]出口选路pbr小实验视频

    什么是策略路由? 一般都是部署在出口路由器,用于路径强制分发的, 优先级高于路由表. 策略路由小实验视频 这个是读书时候录的一个策略路由小实验 转载于:https://www.cnblogs.com/ ...

  7. 用计算机做科学实验评课,科学小实验课程听课心得

    010在线为您甄选多篇描写科学小实验课程听课心得,科学小实验课程听课心得精选,科学小实验课程听课心得大全,有议论,叙事 ,想象等形式.文章字数有400字.600字.800字....缓存时间: 2021 ...

  8. cisco 路由器监控路由连通性_Cisco-路由器配置DHCP小实验

    ​本文介绍在思科路由器上配置DHCP服务端的小实验,旨在让大家掌握其配置命令和配置思路.实际上,DHCP作为一种非常常用的网络服务,对于小型网络.家用网络都是部署在网关设备上. 1 拓扑: 配置DHC ...

  9. 菜鸟学习JavaScript小实验之函数引用

      function tt()         {             alert(11);         }         var b = tt;         var b1 = tt() ...

最新文章

  1. discoGAN 论文解读
  2. 第四周课程总结实验报告(二)
  3. 看懂通信协议:自定义通信协议设计之TLV编码应用
  4. 【CF603E】Pastoral Oddities cdq分治+并查集
  5. php 判断是否文件,php 判断是否一个文件的函数is_file()应用举例
  6. (2021) 22 [持久化] 1-Bit的存储
  7. oppo 手机侧滑快捷菜单_oppo手机如何截图 oppo手机快捷键截屏方法【教程】
  8. 数据工作者的福音:Google 发布正式版数据搜索工具啦!
  9. 响应式下的雪碧图解决方案 - 活用background-size / background-position
  10. 字符串连接符(Java)
  11. cvs导入oracle缺失逗号,pandas教程:使用read_csv()导入数据
  12. 汇编程序:成绩分段统计
  13. 公司邮箱域名注册申请,域名邮箱如何解析?邮箱域名是什么?
  14. 东信杯题解详细版本附带代码(还有日常琐碎bb)
  15. Latex:大于等于号和小于等于号
  16. 2010.4 计算机二级等级考试 vb上机试题 第一套 的答案,2012年计算机二级VB上机试题及解题思路第44套...
  17. 浅谈混迹力扣和codeforces上的几个月
  18. 2023 XL软件库App后端源码 可自定义易支付 完整版
  19. 数据库两表联查、多表联查,多重联查
  20. 缩略图方式下, 资源管理器,不能显示文件名

热门文章

  1. 来自阿里前端的一些中肯建议
  2. 「得到」的竞品是谁?「王者荣耀」啊!
  3. 《道德经》与“低熵”思想炫酷实现(.html)
  4. reading 摘录一
  5. 【英语考研词汇训练营】Day 17 —— espresso,ultimate,gradually,detect,dimension
  6. 脚手架的logo字符图片生成
  7. PAT 自学题解 B1033【测试点4超时】
  8. 在线客服系统IM即时通讯聊天源码
  9. 数据存储的声音 - 第9集:与Stephen Foskett的对话
  10. jquery实现按钮倒数7秒后才可以点击