VGG

VGG在2014年由牛津大学著名研究组vGG (Visual Geometry Group)提出,斩获该年lmageNet竞赛中Localization Task (定位任务)第一名和 Classification Task (分类任务)第二名。

感受野

首先介绍一下感受野的概念。在卷积神经网络中,决定某一层输出结果中一个元素所对应的输入层的区域大小,被称作感受野(receptive field)。通俗的解释是,输出feature map上的一个单元对应输入层上的区域大小。

VGG亮点

通过堆叠多个3x3的卷积核来替代大尺度卷积核(减少所需参数)。论文中提到,可以通过堆叠两个3x3的卷积核替代5x5的卷积核,堆叠三个3x3的卷积核替代7x7的卷积核,是因为他们具有相同的感受野。

我们来计算一下使用7x7卷积核所需参数和堆叠三个3x3卷积核所需参数。假设输入输出channel为C, 7x7卷积核所需参数为7x7xC2,一个3x3卷积核所需参数为3x3xC2,三个就是3x3x3xC2。差不多减少了一半的参数0.0。

VGG结构

VGGNet模型有A-E五种结构网络,深度分别为11,11,13,16,19。其中较为典型的网络结构主要有vgg16和vgg19,本篇文章主要讲VGG16,并分享VGG16的Pytorch实现。

1.两层conv3-64

输入图片大小为224x224x3,卷积核大小为3×3,stride为1,padding为1,卷积核个数为64,卷积得到输出为224×224x64。

2.maxpool

maxpool的size为2,stride为2。输入为224x224x64,池化得到输出为112x112x64。

3.两层conv3-128

输入为112x112x64,卷积核大小为3×3,stride为1,padding为1,卷积核个数为128,卷积得到输出为112×112x128。

4.maxpool

maxpool的size为2,stride为2。输入为112×112x128,池化得到输出为56×56x128。

5.三层conv3-256

输入为56x56x128,卷积核大小为3×3,stride为1,padding为1,卷积核个数为256,卷积得到输出为56×56x256。

6.maxpool

maxpool的size为2,stride为2。输入为56×56x256,池化得到输出为28×28x256。

7.三层conv3-512

输入为28×28x256,卷积核大小为3×3,stride为1,padding为1,卷积核个数为512,卷积得到输出为28×28x512。

6.maxpool

maxpool的size为2,stride为2。输入为28×28x512,池化得到输出为14×14x512。

7.三层conv3-512

输入为14×14x512,卷积核大小为3×3,stride为1,padding为1,卷积核个数为512,卷积得到输出为14×14x512。

8.maxpool

maxpool的size为2,stride为2。输入为14×14x512,池化得到输出为7×7x512。

9.三层全连接层

与两层1x1x4096,一层1x1x1000进行全连接+ReLU(共三层),通过softmax输出1000个预测结果。7x7x512的层要跟4096个神经元的层做全连接,则替换为对7x7x512的层作通道数为4096、卷积核为1x1的卷积。这是全连接转卷积的思路。

实现猫狗识别

1.建立模型
import torch
import torch.nn as nnfrom torch.hub import load_state_dict_from_urlmodel_urls = {'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth','vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth','vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth','vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}class VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=True, dropout=0.5):super(VGG, self).__init__()self.features = featuresself.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(True),nn.Dropout(p=dropout),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(p=dropout),nn.Linear(4096, num_classes),)if init_weights:for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def forward(self, x):x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return xdef make_layers(cfg, batch_norm=False):layers = []in_channels = 3for v in cfg:if v == "M":layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)if batch_norm:layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]else:layers += [conv2d, nn.ReLU(inplace=True)]in_channels = vreturn nn.Sequential(*layers)cfgs = {'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],}def vgg16(pretrained=True, progress=True, num_classes=2):model = VGG(make_layers(cfgs['vgg16']))if pretrained:state_dict = load_state_dict_from_url(model_urls['vgg16'], model_dir='model', progress=progress)model.load_state_dict(state_dict)if num_classes != 1000:model.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, num_classes),)return model}

值得一提的是

'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

64代表64个3x3的卷积核,M指的是maxpool池化层。

2.训练前准备工作

获取图片路径并分类:

import os
from os import getcwdclasses = ['cat', 'dog']
sets = ['train']if __name__ == '__main__':wd = getcwd()for se in sets:list_file = open('cls_' + se + '.txt', 'w')datasets_path = setypes_name = os.listdir(datasets_path)  # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表for type_name in types_name:if type_name not in classes:continuecls_id = classes.index(type_name)  # 输出0-1photos_path = os.path.join(datasets_path, type_name)photos_name = os.listdir(photos_path)for photo_name in photos_name:_, postfix = os.path.splitext(photo_name)  # 该函数用于分离文件名与拓展名if postfix not in ['.jpg', '.png', '.jpeg']:continuelist_file.write(str(cls_id) + ';' + '%s/%s' % (wd, os.path.join(photos_path, photo_name)))list_file.write('\n')list_file.close()

图像处理类:

import cv2
import numpy as np
import torch.utils.data as data
from PIL import Imagedef preprocess_input(x):x /= 127.5x -= 1.return xdef cvtColor(image):if len(np.shape(image)) == 3 and np.shape(image)[-2] == 3:return imageelse:image = image.convert('RGB')return imageclass DataGenerator(data.Dataset):def __init__(self, annotation_lines, inpt_shape, random=True):self.annotation_lines = annotation_linesself.input_shape = inpt_shapeself.random = randomdef __len__(self):return len(self.annotation_lines)def __getitem__(self, index):annotation_path = self.annotation_lines[index].split(';')[1].split()[0]image = Image.open(annotation_path)image = self.get_random_data(image, self.input_shape, random=self.random)image = np.transpose(preprocess_input(np.array(image).astype(np.float32)), [2, 0, 1])y = int(self.annotation_lines[index].split(';')[0])return image, ydef rand(self, a=0, b=1):return np.random.rand() * (b - a) + adef get_random_data(self, image, inpt_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):image = cvtColor(image)iw, ih = image.sizeh, w = inpt_shapeif not random:scale = min(w / iw, h / ih)nw = int(iw * scale)nh = int(ih * scale)dx = (w - nw) // 2dy = (h - nh) // 2image = image.resize((nw, nh), Image.BICUBIC)new_image = Image.new('RGB', (w, h), (128, 128, 128))new_image.paste(image, (dx, dy))image_data = np.array(new_image, np.float32)return image_datanew_ar = w / h * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter)scale = self.rand(.75, 1.25)if new_ar < 1:nh = int(scale * h)nw = int(nh * new_ar)else:nw = int(scale * w)nh = int(nw / new_ar)image = image.resize((nw, nh), Image.BICUBIC)# 将图像多余的部分加上灰条dx = int(self.rand(0, w - nw))dy = int(self.rand(0, h - nh))new_image = Image.new('RGB', (w, h), (128, 128, 128))new_image.paste(image, (dx, dy))image = new_image# 翻转图像flip = self.rand() < .5if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)rotate = self.rand() < .5if rotate:angle = np.random.randint(-15, 15)a, b = w / 2, h / 2M = cv2.getRotationMatrix2D((a, b), angle, 1)image = cv2.warpAffine(np.array(image), M, (w, h), borderValue=[128, 128, 128])# 色域扭曲hue = self.rand(-hue, hue)sat = self.rand(1, sat) if self.rand() < .5 else 1 / self.rand(1, sat)val = self.rand(1, val) if self.rand() < .5 else 1 / self.rand(1, val)x = cv2.cvtColor(np.array(image, np.float32) / 255, cv2.COLOR_RGB2HSV)  # 颜色空间转换x[..., 1] *= satx[..., 2] *= valx[x[:, :, 0] > 360, 0] = 360x[:, :, 1:][x[:, :, 1:] > 1] = 1x[x < 0] = 0image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB) * 255return image_data
3.训练模型

数据集我上传到百度网盘里,可自行下载解压到根目录下。

链接:https://pan.baidu.com/s/1v14gSYa5S0CH0GYDnKjb7Q?pwd=xhd0
提取码:xhd0

为了方便,我们把25000张图片全部放在train文件夹下,从中取10分之1也就是2500张做测试集,剩余22500做训练集

import torch
import torch.nn as nn
from net import vgg16
from torch.utils.data import DataLoader
from data import *'''数据集'''
annotation_path = 'cls_train.txt'
with open(annotation_path, 'r') as f:lines = f.readlines()
np.random.seed(10101)
np.random.shuffle(lines)  # 打乱数据
np.random.seed(None)
num_val = int(len(lines) * 0.1)
num_train = len(lines) - num_val
# 输入图像大小
input_shape = [224, 224]
train_data = DataGenerator(lines[:num_train], input_shape, True)
val_data = DataGenerator(lines[num_train:], input_shape, False)
val_len = len(val_data)
"""加载数据"""
gen_train = DataLoader(train_data, batch_size=4)
gen_test = DataLoader(val_data, batch_size=4)
'''构建网络'''
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
net = vgg16(pretrained=True, progress=True, num_classes=2)
net.to(device)
'''选择优化器和学习率的调整方法'''
lr = 0.0001
optim = torch.optim.Adam(net.parameters(), lr=lr)
sculer = torch.optim.lr_scheduler.StepLR(optim, step_size=1)
'''训练'''
epochs = 20
for epoch in range(epochs):print("===========", epoch, "==============")total_train = 0for data in gen_train:img, label = datawith torch.no_grad():img = img.to(device)label = label.to(device)optim.zero_grad()output = net(img)train_loss = nn.CrossEntropyLoss()(output, label).to(device)train_loss.backward()optim.step()total_train += train_lossprint("训练集上的损失:{}".format(train_loss))total_test = 0total_accuracy = 0for data in gen_test:img, label = datawith torch.no_grad():img = img.to(device)label = label.to(device)optim.zero_grad()out = net(img)test_loss = nn.CrossEntropyLoss()(out, label).to(device)total_test += test_lossaccuracy = (out.argmax(1) == label).sum()total_accuracy += accuracyprint("测试集上的精度:{:.1%}".format(total_accuracy / val_len))print("===============================================")print("训练集上的损失:{}".format(total_train))print("测试集上的损失:{}".format(total_test))print("测试集上的精度:{:.1%}".format(total_accuracy / val_len))# torch.save(net,"dogandcat.{}.pt".format(epoch+1))torch.save(net.state_dict(), "Adogandcat.{}.pth".format(epoch + 1))print("模型已保存")

放到老师的深度学习工作台上训练:


可以看到取得了不错的准确率。每一个epoch保存一下模型,一共20个epoch,最后拿Adogandcat.20.pth来测试一下模型泛化能力。

4.测试泛化能力

去网上找了几张图片,简直恐怖。这就是深度学习的魅力吧。



import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from net import vgg16img_pth = './test/test4.jpg'
img = Image.open(img_pth)
'''处理图片'''
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
image = transform(img)
'''加载网络'''
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = vgg16()
model = torch.load("./Adogandcat.20.pth", map_location=device)
net.load_state_dict(model)
net.eval()
image = torch.reshape(image, (1, 3, 224, 224))
with torch.no_grad():out = net(image)
out = F.softmax(out, dim=1)
out = out.data.cpu().numpy()
print(out)
a = int(out.argmax(1))
plt.figure()
list = ["cat", 'dog']
plt.suptitle("Classes:{}:{:.1%}".format(list[a], out[0, a]))
plt.imshow(img)
plt.show()

VGG网络详解(实现猫猫和狗狗识别)相关推荐

  1. 基于CIFAR100的VGG网络结构详解

    基于CIFAR100的VGG网络详解 码字不易,点赞收藏 1 数据集概况 1.1 CIFAR100 cifar100包含20个大类,共100类,train集50000张图片,test集10000张图片 ...

  2. 第十六章 ConvNeXt网络详解

    系列文章目录 第一章 AlexNet网络详解 第二章 VGG网络详解 第三章 GoogLeNet网络详解 第四章 ResNet网络详解 第五章 ResNeXt网络详解 第六章 MobileNetv1网 ...

  3. ResNet网络详解并使用pytorch搭建模型、并基于迁移学习训练

    1.ResNet网络详解 网络中的创新点: (1)超深的网络结构(突破1000层) (2)提出residual模块 (3)使用Batch Normalization加速训练(丢弃dropout) (1 ...

  4. MobileNetv1、v2网络详解、使用pytorch搭建模型MobileNetv2并基于迁移学习训练

    1.MobileNetv1网络详解 传统卷积神经网络专注于移动端或者嵌入式设备中的轻量级CNN网络,相比于传统卷积神经网络,在准确率小幅降低的前提下大大减少模型参数与运算量.(相比VGG16准确率减少 ...

  5. GoogLeNet网络详解并使用pytorch搭建模型

    1.GoogLeNet网络详解 网络中的创新点: (1)引入了Inception结构(融合不同尺度的特征信息) (2)使用1x1的卷积核进行降维以及映射处理 (虽然VGG网络中也有,但该论文介绍的更详 ...

  6. ResNet网络详解与keras实现

    ResNet网络详解与keras实现 ResNet网络详解与keras实现 Resnet网络的概览 Pascal_VOC数据集 第一层目录 第二层目录 第三层目录 梯度退化 Residual Lear ...

  7. GoogleNet网络详解与keras实现

    GoogleNet网络详解与keras实现 GoogleNet网络详解与keras实现 GoogleNet系列网络的概览 Pascal_VOC数据集 第一层目录 第二层目录 第三层目录 Incepti ...

  8. Linux系统下ifconfig和route配置网络详解

    Linux系统下ifconfig和route配置网络详解 ifconfig和route合用于配置网络(ip命令综合二者功能,此处不讲),通常在前者设置好ip地址等信息后,采用route命令配置路由.( ...

  9. EfficientNetV2网络详解

    原论文名称:EfficientNetV2: Smaller Models and Faster Training 论文下载地址:https://arxiv.org/abs/2104.00298 原论文 ...

最新文章

  1. Mysql隐藏命令_mysql常用命令整理
  2. php启用日志记录,PHP SDK启用日志功能报错
  3. 用C#实现支持gmail邮件发送
  4. OpenLayers事件处理Event.js(七)
  5. MySQL utf8mb4与emoji表情
  6. 未来教育python视频百度云-2019年计算机二级Python语言程序设计考试大纲
  7. 数据结构源码笔记(C语言):二路归并排序
  8. TFS 2008 中文版下载及安装完整图解
  9. [跨平台系列三Docker篇]:ASP.NET Core应用
  10. vue组件化通信之父向子传值
  11. 二叉树的深度_十七:二叉树的最小深度
  12. .net如何引用该命名空间
  13. windows平台下subversion服务器端配置
  14. 一个超级简单的HTML模板框架源代码以及使用示例
  15. C# 异步定时器,可以重载; System.Timers.Timer
  16. 计算机基础与程序设计
  17. oracle对时间范围比较的语句
  18. NAFSM中值滤波器讲解与实现
  19. PymongoDB报错MongoError: The dotted field .. is not valid for storage
  20. Python数据处理DataFrame小记

热门文章

  1. Uniapp商城项目【详细笔记文档】
  2. 【51nod_1381】硬币游戏
  3. 黑猴子的家:stop-cluster.sh
  4. GO语言基础入门(二)
  5. 《淘宝店铺设计装修一册通》一2.1 Photoshop界面
  6. okhttp3测试框架_Okhttp3的使用详解
  7. 2.G3-PLC PHY
  8. 轻松让你的nginx服务器支持HTTP2协议
  9. Windows2012R2 远程桌面服务(RDP)3389存在SSL漏洞的解决办法
  10. 在腾讯工作是一种怎样的体验?