目录

1. CIFAR-10数据集介绍

2. 问题说明

3. 模型训练过程

4. 结果可视化


1. CIFAR-100数据集介绍

这个数据集就像CIFAR-10,除了它有100个类,每个类包含600个图像。,每类各有500个训练图像和100个测试图像。CIFAR-100中的100个类被分成20个超类。每个图像都带有一个“精细”标签(它所属的类)和一个“粗糙”标签(它所属的超类)以下是CIFAR-100中的类别列表:

Cifar10
Superclass Classes
aquatic mammals beaver, dolphin, otter, seal, whale
fish aquarium fish, flatfish, ray, shark, trout
flowers orchids, poppies, roses, sunflowers, tulips
food containers bottles, bowls, cans, cups, plates
fruit and vegetables apples, mushrooms, oranges, pears, sweet peppers
household electrical devices clock, computer keyboard, lamp, telephone, television
household furniture bed, chair, couch, table, wardrobe
insects bee, beetle, butterfly, caterpillar, cockroach
large carnivores bear, leopard, lion, tiger, wolf
large man-made outdoor things bridge, castle, house, road, skyscraper
large natural outdoor scenes cloud, forest, mountain, plain, sea
large omnivores and herbivores camel, cattle, chimpanzee, elephant, kangaroo
medium-sized mammals fox, porcupine, possum, raccoon, skunk
non-insect invertebrates crab, lobster, snail, spider, worm
people baby, boy, girl, man, woman
reptiles crocodile, dinosaur, lizard, snake, turtle
small mammals hamster, mouse, rabbit, shrew, squirrel
trees maple, oak, palm, pine, willow
vehicles 1 bicycle, bus, motorcycle, pickup truck, train
vehicles 2 lawn-mower, rocket, streetcar, tank, tractor

CIFAR-100标签

参考

2. 问题说明

  • 自由选择算法完成 cifar-100 数据集解析分类
  • 要求准确率至少大于60%

3. 模型训练过程

1.初步准备

采用ResNet模型

两种残差网络块定义

'''
不带瓶颈段的残差块
'''
class BasicBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.network = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride),nn.BatchNorm2d(out_channels),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),nn.MaxPool2d(3, stride=1, padding=1))self.downSample = nn.Sequential()if in_channels != out_channels:self.downSample = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),nn.BatchNorm2d(out_channels))def forward(self, x):out = self.network(x) + self.downSample(x)return out
'''
带瓶颈段的残差块
in channel = [inChannel1,inChannel2,inChannel3]
in channel = [outChannel1,outChannel2,outChannel3]
stride = 1 长宽不变
stride = 2 长宽减半
'''
class BasicBlock(nn.Module):def __init__(self, in_channels, out_channels, stride):super(BasicBlock, self).__init__()self.network = nn.Sequential(nn.Conv2d(in_channels, out_channels[0], kernel_size=1, padding=0, stride=1),nn.BatchNorm2d(out_channels[0]),nn.ReLU(),nn.Conv2d(out_channels[0], out_channels[1], kernel_size=3, padding=1, stride=stride),nn.BatchNorm2d(out_channels[1]),nn.ReLU(),nn.Conv2d(out_channels[1], out_channels[2], kernel_size=1, padding=0, stride=1),nn.BatchNorm2d(out_channels[2]),)self.downSample = nn.Sequential(nn.Conv2d(in_channels, out_channels[2], kernel_size=1,  padding=0, stride=stride),nn.BatchNorm2d(out_channels[2]))

训练,测试函数定义如下

def train():net.train()acc = 0.0sum = 0.0loss_sum = 0for batch, (data, target) in enumerate(trainLoader):data, target = data.to(device), target.to(device)net.optimizer.zero_grad()output = net(data)loss = net.lossFunc(output, target)loss.backward()net.optimizer.step()acc += torch.sum(torch.argmax(output, dim=1) == target).item()sum += len(target)loss_sum += loss.item()writer.add_scalar('Cifar100_model_log/trainAccuracy', 100 * acc / sum, epoch + 1)writer.add_scalar('Cifar100_model_log/trainLoss', loss_sum / (batch + 1), epoch + 1)print('train accuracy: %.2f%%, loss: %.4f' % (100 * acc / sum, loss_sum / (batch + 1)))def test():net.eval()acc = 0.0sum = 0.0loss_sum = 0step = 0for batch, (data, target) in enumerate(testLoader):initData = datadata = testTransform2(data)data, target = data.to(device), target.to(device)output = net(data)
'''
用于通过tensorboard可视化训练结果,使用时要注意将epoch设置为1for i in range(len(output)):writer.add_image(fineLabelList[torch.argmax(output, dim=1)[i]], initData[i], step)step = step + 1
'''loss = net.lossFunc(output, target)acc += torch.sum(torch.argmax(output, dim=1) == target).item()sum += len(target)loss_sum += loss.item()writer.add_scalar('Cifar100_model_log/testAccuracy', 100 * acc / sum, epoch + 1)writer.add_scalar('Cifar100_model_log/trainLoss', loss_sum / (batch + 1), epoch + 1)print('test accuracy: %.2f%%, loss: %.4f' % (100 * acc / sum, loss_sum / (batch + 1)))

准备数据

'''
采用GPU加速训练,没有则用cpu
'''
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainTransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])'''
数据集扩充采用
trainTransform = transforms.Compose([transforms.ToTensor(),transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=(0,1), contrast=(0,1), saturation=(0,1), hue=0),transforms.RandomVerticalFlip(p=0.5),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
'''testTransform1 = transforms.Compose([transforms.ToTensor(),
])
testTransform2 = transforms.Compose([transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])trainDataset = datasets.cifar.CIFAR100(root='cifar100', train=True, transform=trainTransform, download=False)
testDataset = datasets.cifar.CIFAR100(root='cifar100', train=False, transform=testTransform1, download=False)trainLoader = DataLoader(trainDataset, batch_size=200, shuffle=True)
testLoader = DataLoader(testDataset, batch_size=200, shuffle=False)

2.模型训练及改进思路

训练过程中发现网络的过拟合问题情况严重,训练集正确率可达90%,测试集正确率卡在45%上下。

解决方案

(1)使用L1、L2正则化

(2)增加Dropout

(3)扩充数据集

(4)删除部分残差块,缩减网络

(5)迁移学习

经过测试,发现采用了缩减网络,扩充数据集,增加Dropout的网络可以使正确率从45%提升到了61.79%(迁移学习作用也比较明显,实验中的一种迁移学习方式在原训练基础上将正确率提升至63.13%)而相较于Dropout,L2正则化可以更好地防止过拟合,但是正确率上升缓慢且出现停滞。

最后网络定义如下,剩余残差块和卷积层可进行迁移训练

进行的部分迁移学习思路过程如下:

第一次迁移学习:
冻结block1~4
增加
block5:BasicBlock(128, 128)
block6:BasicBlock(128, 512, 2)
重设全连接层
正确率浮动变化不大第二次迁移学习:
第一次的基础上冻结所有层
补充cov1单层卷积2*2*512
(如果增加至三层则会下降严重)
正确率提升至58%第三次迁移学习
冻结block1~4
原版基础上重设全连接层单层
目的为了训练卷积层第四次迁移学习(较为成功)
冻结block1~4
在原版基础上增加
block5:BasicBlock(128, 1024)
block6:BasicBlock(1024, 128)
进行一次放大将正确率提升至63.13%
经测试若采用缩小,正确率提高幅度不大  

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.block1 = BasicBlock(3, 16)self.block2 = BasicBlock(16, 64, 2)self.block3 = BasicBlock(64, 64)self.block4 = BasicBlock(64, 128, 2)self.block5 = nn.Sequential()self.block6 = nn.Sequential()self.block7 = nn.Sequential()self.block8 = nn.Sequential()self.block9 = nn.Sequential()self.cov1 = nn.Sequential()self.linear = nn.Sequential(nn.Flatten(),nn.Linear(8 * 8 * 128, 2048),nn.Dropout(0.1),nn.BatchNorm1d(2048),nn.Linear(2048, 1024),nn.Dropout(0.1),nn.ReLU(),nn.Linear(1024, 100))self.optimizer = torch.optim.SGD(self.parameters(), lr=0.08)self.lossFunc = torch.nn.CrossEntropyLoss()

训练结果(可以发现还是过拟合严重)

4. 结果可视化

受版面限制,部分展示如下

通过拖动发现,在bed标签下出现一只错误识别的兔子

CIFAR-100数据集 卷积神经网络训练相关推荐

  1. 【神经网络与深度学习】CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——[附完整训练代码]

    [神经网络与深度学习]CIFAR-10数据集介绍,并使用卷积神经网络训练模型--[附完整代码] 一.CIFAR-10数据集介绍 1.1 CIFAR-10数据集的内容 1.2 CIFAR-10数据集的结 ...

  2. 记录|深度学习100例-卷积神经网络(CNN)彩色图片分类 | 第2天

    记录|深度学习100例-卷积神经网络(CNN)彩色图片分类 | 第2天 1. 彩色图片分类效果图 数据集如下: 测试图1如下 训练/验证精确图如下: 优化后:测试图--打印预测标签: 优化后:测试图- ...

  3. 神经网络 卷积神经网络,卷积神经网络训练太慢

    深度学习为什么加入卷积神经网络之后程序运行速度反而变慢了 谷歌人工智能写作项目:神经网络伪原创 卷积神经网络训练精度高,测试精度很低的原因 过拟合了,原因很多,解决方案也有很多写作猫.百度/谷歌搜索过 ...

  4. 卷积神经网络学习——第二部分:卷积神经网络训练的基本流程

    卷积神经网络学习--第二部分:卷积神经网络训练的基本流程 一.序言 二.训练流程 1.数据集引入 2.构建网络 (1)四层卷积神经网络 (2)两层全连接层网络 3.模型训练 4.模型评估 三.总结 一 ...

  5. 卷积云神经网络_用于卷积神经网络训练的地基云图数据库构建方法与流程

    本发明涉及人工智能模式识别领域,具体涉及一种用于卷积神经网络训练的地基云图数据库构建方法. 背景技术: 云是地球上水文循环的一个重要环节,它与地面辐射相互作用共同影响着局地和全球尺度的能量平衡.云分类 ...

  6. Pytorch 实现全连接神经网络/卷积神经网络训练MNIST数据集,并将训练好的模型在制作自己的手写图片数据集上测试

    使用教程 代码下载地址:点我下载 模型在训练过程中会自动显示训练进度,如果您的pytorch是CPU版本的,代码会自动选择CPU训练,如果有cuda,则会选择GPU训练. 项目目录说明: CNN文件夹 ...

  7. 卷积神经网络训练准确率突然下降_详解卷积神经网络:手把手教你训练一个新项目...

    作者:Tirmidzi Aflahi 原文链接:https://thedatamage.com/convolutional-neural-network-explained/Tirmidzi Afla ...

  8. 卷积神经网络的Python实现,python卷积神经网络训练

    如何才能自学好python? 对于想要自学Python的小伙伴,这里整理了一份系统全面的学习路线,按照这份大纲来安排学习可以少走弯路,事半功倍. 第一阶段:专业核心基础阶段目标:1.熟练掌握Pytho ...

  9. 深度学习100例-卷积神经网络(LeNet-5)深度学习里的“Hello Word” | 第22天

    大家好,我是「K同学啊 」! 前几天翻译了一篇讲十大CNN结构的文章(「多图」图解10大CNN架构),原作者思路十分清晰,从时间线上,将近年来CNN发展过程中一些比较重要的网络模型做了一一介绍.我发现 ...

最新文章

  1. 霍尔传感器与直流无刷电机换相
  2. 软件测试中性能瓶颈是什么,性能测试中如何定位性能瓶颈
  3. PyTorch基础-猫狗分类实战-10
  4. 085 Maximal Rectangle 最大矩形
  5. Star Schema完全参考手册学习笔记六
  6. mysql sysdate 格式化_MySQL函数汇总
  7. 限制只允许某个进程调用库
  8. 小白能读懂的 《手把手教你学DSP(TMS320X281X)》第七章 CPU定时器
  9. 域名转出与转入,以新网到万网之间的转移为例
  10. K3 ERP 系统财务管理 - 账结法、表结法
  11. 2019 Multi-University Training Contest 2:Beauty Of Unimodal Sequence(DP + 贪心构造)
  12. 计算机二级需要报班,计算机二级需要报班培训吗
  13. 巨杉数据库兼容mysql_SequoiaDB 巨杉数据库
  14. PDF编辑器首选工具Acrobat Pro DC
  15. Swagger3.0官方starter诞生,可以扔掉那些野生starter了
  16. Introducing a forensics data type taxonomy of acquirable artefacts from PLCs
  17. 子组件向父组件传递数据_如何将元素引用向下传递到角度的组件树中
  18. 将pdf文件压缩到指定大小
  19. 开学季学生党买什么蓝牙耳机好?高性价比无线蓝牙耳机推荐
  20. 树莓派使用FlashFxp SSH 连接

热门文章

  1. 相约未名湖畔,百度商业AI技术创新大赛携手北大学子共探AI发展
  2. 第一本书《瓦尔登湖》
  3. [剑指offer]JT19---顺时针打印矩阵(正方形打野)
  4. netfx_使用netfx平台增强视觉效果社区的能力
  5. linux命令 scp怎么用,linux scp命令怎么用
  6. 时间间隔感测试器2.0
  7. php代码 扫描,PHP代码安全扫描工具(AutoPHPCheck)
  8. 文档翻译软件哪个好?安利三个好用的文档翻译软件给你
  9. 日记侠:如何优化你的微信头像?
  10. 字典学习算法K-SVD详解