RuntimeError: size mismatch
RuntimeError: size mismatch出现于pytorch框架下加载VGG11预训练模型时出现的错误。笔者初期认为,出现该错误的主要原因是输出和输入的维度不匹配。带着疑惑,我们通过输出网络模型结构来观察:
我们可以看到(18): Conv2d输出的维度是512,(avgpool): AdaptiveAvgPool2d输出的维度是7*7,(0): Linear输入的维度是25088。在这里,我们将25088分解,即25088=512*7*7。当时笔者设置的是model.avgpool = nn.AdaptiveAvgPool2d(1),因此512*1*1=512!=25088。随后,model.avgpool = nn.AdaptiveAvgPool2d(7)调试成功。
建议:加载预训练模型时若出现size mismatch的问题,先观察网络结构,重点观察
model.avgpool = nn.AdaptiveAvgPool2d(1)
model.fc = nn.Linear(2048,config.num_classes)
中的参数是否吻合原网络。
这里补充如果去除预训练网络某几层,自定义添加网络层方法。
import torchvision.models as models
from torchsummary import summary
import torch
import torchvision
import torch.nn.functional as F
from torch import nn
from config import configclass Net(nn.Module):def __init__(self, model):super(Net, self).__init__()# 去除原网络最后两层self.resnet_layer = nn.Sequential(*list(model.children())[:-2])# 自定义添加网络层self.transion_layer = nn.ConvTranspose2d(2048, 2048, kernel_size=14, stride=3)self.pool_layer = nn.MaxPool2d(32)self.Linear_layer = nn.Linear(2048, 8)def forward(self, x):x = self.resnet_layer(x)x = self.transion_layer(x)x = self.pool_layer(x)x = x.view(x.size(0), -1)x = self.Linear_layer(x)return x
定义好自己的网络后,加载预训练网络,传入到Net中即可。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg19bnft = models.vgg19_bn().to(device)
vgg19bnft = Net(vgg19bnft)
print(vgg19bnft)
若想单独修改最后的全连接层,可以尝试如下方法:
removed = list(vgg19bn.classifier.children())[:-1]
vgg19bn.classifier = torch.nn.Sequential(*removed)
vgg19bn.add_module('fc', torch.nn.Linear(4096, 7))
print(vgg19bn)
因为加载预训练模型的种类是1000,这里根据自己的需求,修改为了7。
最后,倘若尝试以上方法未果,可以留意输入图片尺寸。注意训练时输入图片尺寸与加载模型时使用的输入图片尺寸是否一致。一般来说输入图片的大小为224的倍数(要找一个72的指数次方,72的5次方等于7*32=224)
向IT工作者致敬,后丹之喜碧CatBrother欢迎吐槽:
后丹-喜碧CatBrother
RuntimeError: size mismatch相关推荐
- 做项目遇到问题 2 AWS NLP 剽窃RuntimeError: size mismatch, m1: [10 x 3], m2: [2 x 10]检测部署报错
报错 RuntimeError: size mismatch, m1: [10 x 3], m2: [2 x 10] 原因: train.csv 为100x4 4列 第一列 标签是否剽窃 ...
- 【error】RuntimeError: size mismatch与全连接fc层
今天跑代码的时候遇到了这个错误: RuntimeError: size mismatch, m1:[1152 x 1] ,m2:[576 x 192] ,at /opt/conda/conda-bld ...
- RuntimeError: size mismatch, m1: [80 x 4], m2: [320 x 50] at ..\aten\src\TH/generic/THTensorMath.cpp
RuntimeError: size mismatch, m1: [80 x 4], m2: [320 x 50] at -\aten\src\TH/generic/THTensorMath.cpp: ...
- PyTorch RuntimeError: size mismatch, m1:
在查看torch的FastRCNNPredictor官方实现时,想弄清楚一些细节,其中nn.Linear使用的时候需要给定(in_channels, num_classes). 随便打一点测试代码,报 ...
- pytorch搭建cnn报错:RuntimeError: size mismatch, m1: [10 x 43264], m2: [10816 x 2] at C...
具体报错信息: Traceback (most recent call last):File "E:/Program Files/PyCharm 2019.2/machinelearning ...
- RuntimeError: size mismatch, m1: [512 x 12800], m2: [2048 x 1024] at C
RuntimeError: size mismatch, m1: [512 x 12800], m2: [2048 x 1024] at C 说一下这错误,意思就是m1和m2两个数组尺寸不一样. 为啥 ...
- pytorch RuntimeError: size mismatch, m1: [64 x 784], m2: [784 x 10] at
from torch import nnclass Mnist_Logistic(nn.Module):def __init__(self):super().__init__()self.lin=nn ...
- pytorch神经网络,解决输入图像大小不匹配问题 size mismatch
问题如下:RuntimeError: size mismatch, m1: [4 x 512], m2: [64 x 128]-- RuntimeError: size mismatch, m1: [ ...
- strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur
strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur ...
最新文章
- Saltstack系列之一——安装篇
- close和shutdown的区别
- android javamail获取邮件太多太慢_结合 Spring 发送邮件的4种正确姿势,你知道几种?...
- 高并发编程-Daemon Thread的创建以及使用场景分析
- print函数python_带有结束参数的Python print()函数
- extras mibs php7,ubuntu编译安装php7遇到的问题及解决方案
- 永辉生活APP卖茅台只收款不发货,永辉超市回应...
- BAT程序员必备技能调研,你中了几招?
- CSS 实现文字两端对齐
- FX2LP与FPGA的简单批量回环
- c语言心算抢答系统,心算抢答系统2.doc
- oracle导出导入同义词,oracle同义词语句备份
- 永中文档在线预览集群部署方案
- mysql数据库用户密码_修改mysql数据库的用户名和密码
- hp 800 g4 twr linux,【拆机】HP EliteDesk 800 G4 TWR—探究塔式机箱的秘密
- 手机打车APP的机遇与挑战
- 【CSDN英雄会】黄帅:安全不是独行侠而是系统性的运维过程
- 番茄社区多门店系统介绍
- 没有什么能够毁灭一个人的灵魂
- 北洋网络口打印机设置
热门文章
- elementui组件中,树形组件的使用
- 女神异闻录5(p5)系统拆解
- selenium如何通过快捷键关闭浏览器打开的新页签
- 手机浏览器java_三款最热java手机浏览器横评(组图)
- Windows Server 2016-客户端加域端口汇总
- 《科研诚信与学术规范》参考答案最新版
- Java常见面试题含答案(第一期)
- 跟NAS斗智斗勇的个人文件整理日常(没写完)
- 短信网关通道对接及分流策略说明
- CMStudio中出现‘$错误‘ is not a vaild integer value如何解决