This chapter uses PyTorch to train a ResNet-50 on the CIFAR-10 dataset.

Dataset:

Code project:

1.train.py


import torch
from torch import nn, optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from resnet50 import ResNet50


# Experiment with the CIFAR-10 dataset
def main():
    batchsz = 2
    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=1, shuffle=True)

    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cpu')
    model = ResNet50().to(device)
    print(*list(model.children())[-3:-2])   # print one of the trailing sub-modules

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # print(model)
    print(next(iter(cifar_test))[0].shape)

    for epoch in range(1):
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            if batchidx <= 2:   # only run a few batches as a quick sanity check
                x, label = x.to(device), label.to(device)
                logits = model(x)
                loss = criteon(logits, label)
                # backprop
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print("epoch:", epoch, "index:", batchidx)
            else:
                continue

        print(epoch, 'loss:', loss.item())

        # PATH = "model/test.pth"
        torch.save(model, "model2/test.pth")                 # save the whole network
        torch.save(model.state_dict(), "model2/test2.pth")   # save only the parameters

        model.eval()
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for idx, (x, label) in enumerate(cifar_test):
                if idx <= 5:   # only evaluate a few samples
                    x, label = x.to(device), label.to(device)
                    logits = model(x)
                    pred = logits.argmax(dim=1)
                    correct = torch.eq(pred, label).float().sum().item()
                    total_correct += correct
                    total_num += x.size(0)
                    # print(pred)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()

# # Save the whole network
# torch.save(net, PATH)
# # Save only the parameters: faster and smaller on disk
# torch.save(net.state_dict(), PATH)
# # --------------------------------------------------
# # For the two save methods above, the corresponding load calls are:
# model_dict = torch.load(PATH)
# model_dict = model.load_state_dict(torch.load(PATH))
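For reference, a minimal sketch of loading the two files saved above back into memory (assuming resnet50.py is importable and the model2/ paths from train.py are used):

import torch
from resnet50 import ResNet50

# Option 1: load the whole pickled model (needs the ResNet50 class definition on the import path)
model_a = torch.load("model2/test.pth", map_location='cpu')
model_a.eval()

# Option 2: rebuild the network and load only the saved parameters
model_b = ResNet50()
model_b.load_state_dict(torch.load("model2/test2.pth", map_location='cpu'))
model_b.eval()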

2.test_pth.py

from resnet50 import ResNet50
import torch
from PIL import Image
from torchvision import transforms
import cv2
import numpy as np


def predict(img_path):
    device = torch.device('cpu')
    # load the whole saved model
    model = torch.load("model2/test.pth")
    model = model.to(device)

    # Alternatively, load only the saved parameters:
    # model = ResNet50()
    # weight = torch.load("model2/test2.pth")
    # model.load_state_dict(weight)
    # model = model.to(device)

    img = cv2.imread(img_path)
    img = cv2.resize(img, (224, 224))
    img = np.reshape(img, (1, 224, 224, 3))
    img = img.transpose(0, 3, 1, 2).copy()   # NHWC -> NCHW
    print(img.shape)
    img_ = torch.Tensor(img)

    model.eval()
    with torch.no_grad():
        outputs = model(img_)
    _, predicted = torch.max(outputs, 1)
    print('pred :', outputs, predicted)


if __name__ == '__main__':
    img_path = "img/dog2.jpg"
    predict(img_path)
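Note that this script feeds raw BGR pixel values in [0, 255], while training normalized RGB images with the ImageNet mean/std, so predictions here can differ from what the training-time test loop reports. A minimal sketch of preprocessing that matches train.py (the preprocess helper below is illustrative, not part of the original project):

import cv2
import numpy as np
import torch

def preprocess(img_path):
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224)).astype(np.float32) / 255.0
    img = (img - mean) / std                      # same normalization as train.py
    img = img.transpose(2, 0, 1)[np.newaxis]      # HWC -> NCHW, add batch dim
    return torch.from_numpy(img.astype(np.float32))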

3.resnet50.py

import torch
import torch.nn as nn
from torch.nn import functional as F


class ResNet50BasicBlock(nn.Module):
    def __init__(self, in_channel, outs, kernerl_size, stride, padding):
        super(ResNet50BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, outs[0], kernel_size=kernerl_size[0], stride=stride[0], padding=padding[0])
        self.bn1 = nn.BatchNorm2d(outs[0])
        self.conv2 = nn.Conv2d(outs[0], outs[1], kernel_size=kernerl_size[1], stride=stride[0], padding=padding[1])
        self.bn2 = nn.BatchNorm2d(outs[1])
        self.conv3 = nn.Conv2d(outs[1], outs[2], kernel_size=kernerl_size[2], stride=stride[0], padding=padding[2])
        self.bn3 = nn.BatchNorm2d(outs[2])

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(self.bn1(out))

        out = self.conv2(out)
        out = F.relu(self.bn2(out))

        out = self.conv3(out)
        out = self.bn3(out)

        # identity shortcut: input and output channels must match
        return F.relu(out + x)


class ResNet50DownBlock(nn.Module):
    def __init__(self, in_channel, outs, kernel_size, stride, padding):
        super(ResNet50DownBlock, self).__init__()
        # out1, out2, out3 = outs
        # print(outs)
        self.conv1 = nn.Conv2d(in_channel, outs[0], kernel_size=kernel_size[0], stride=stride[0], padding=padding[0])
        self.bn1 = nn.BatchNorm2d(outs[0])
        self.conv2 = nn.Conv2d(outs[0], outs[1], kernel_size=kernel_size[1], stride=stride[1], padding=padding[1])
        self.bn2 = nn.BatchNorm2d(outs[1])
        self.conv3 = nn.Conv2d(outs[1], outs[2], kernel_size=kernel_size[2], stride=stride[2], padding=padding[2])
        self.bn3 = nn.BatchNorm2d(outs[2])

        # projection shortcut: 1x1 convolution to match channels (and stride)
        self.extra = nn.Sequential(
            nn.Conv2d(in_channel, outs[2], kernel_size=1, stride=stride[3], padding=0),
            nn.BatchNorm2d(outs[2])
        )

    def forward(self, x):
        x_shortcut = self.extra(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        return F.relu(x_shortcut + out)


class ResNet50(nn.Module):
    def __init__(self):
        super(ResNet50, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = nn.Sequential(
            ResNet50DownBlock(64, outs=[64, 64, 256], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50BasicBlock(256, outs=[64, 64, 256], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50BasicBlock(256, outs=[64, 64, 256], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
        )

        self.layer2 = nn.Sequential(
            ResNet50DownBlock(256, outs=[128, 128, 512], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
            ResNet50BasicBlock(512, outs=[128, 128, 512], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50BasicBlock(512, outs=[128, 128, 512], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50DownBlock(512, outs=[128, 128, 512], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0])
        )

        self.layer3 = nn.Sequential(
            ResNet50DownBlock(512, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
            ResNet50BasicBlock(1024, outs=[256, 256, 1024], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50BasicBlock(1024, outs=[256, 256, 1024], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0])
        )

        self.layer4 = nn.Sequential(
            ResNet50DownBlock(1024, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
            ResNet50DownBlock(2048, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50DownBlock(2048, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0])
        )

        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1, ceil_mode=False)
        # self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        self.fc = nn.Linear(2048, 10)
        # Use a 1x1 convolution in place of the fully connected layer
        self.conv11 = nn.Conv2d(2048, 10, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        out = self.conv1(x)
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = self.conv11(out)
        out = out.reshape(x.shape[0], -1)
        # out = self.fc(out)
        return out


if __name__ == '__main__':
    x = torch.randn(1, 3, 224, 224)
    net = ResNet50()
    out = net(x)
    print('out.shape: ', out.shape)
    print(out)
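As a quick sanity check, the sketch below compares the output shape and parameter count of this hand-written network against torchvision's resnet50. Note that several positions in layer2-layer4 above use ResNet50DownBlock with an extra 1x1 shortcut convolution where the reference architecture uses identity blocks, so the parameter counts are not expected to match exactly:

import torch
from torchvision.models import resnet50
from resnet50 import ResNet50

x = torch.randn(1, 3, 224, 224)
mine = ResNet50()
ref = resnet50(num_classes=10)
print('custom out:', mine(x).shape, 'torchvision out:', ref(x).shape)   # both should be [1, 10]
print('custom params:', sum(p.numel() for p in mine.parameters()))
print('torchvision params:', sum(p.numel() for p in ref.parameters()))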

4.pth2onnx.py

import torch
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("model2/test.pth")   # load the saved PyTorch model
model = model.to(device)                # keep the model and the dummy input on the same device
model.eval()
for name in model.state_dict():
    print(name)
summary(model, (3, 224, 224))

input_shape = list(map(int, "1,3,224,224".split(",")))
x = torch.randn(input_shape)            # dummy input tensor
x = x.to(device)
export_onnx_file = "model2/test.onnx"   # target ONNX file name
torch.onnx.export(model, x, export_onnx_file, verbose=True)
# torch.onnx.export(model, x, export_onnx_file, verbose=True, export_params=True, do_constant_folding=True, opset_version=11)
# input_names=['boxes']
# output_names=['layer1.1.conv1.bias']
# torch.onnx.export(model, x, export_onnx_file,
#                   export_params=True,
#                   do_constant_folding=True,
#                   input_names=input_names,
#                   output_names=output_names
#                   )
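If the downstream runtime needs stable tensor names or a variable batch size, the export call can be extended roughly as below. This is a sketch only; the names "input" and "output" are arbitrary choices, not something the original scripts rely on:

torch.onnx.export(model, x, export_onnx_file,
                  export_params=True,
                  do_constant_folding=True,
                  opset_version=11,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})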

5.test_onnx_v1.py

import cv2
import numpy as np
import onnxruntime as rt


def image_process(image_path):
    mean = np.array([[[0.485, 0.456, 0.406]]])      # mean and std used during training
    std = np.array([[[0.229, 0.224, 0.225]]])

    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224))               # (224, 224, 3)

    image = img.astype(np.float32) / 255.0
    image = (image - mean) / std
    image = image.transpose((2, 0, 1))              # (3, 224, 224)
    image = image[np.newaxis, :, :, :]              # (1, 3, 224, 224)
    image = np.array(image, dtype=np.float32)
    return image


def onnx_runtime():
    imgdata = image_process('img/test.jpg')

    sess = rt.InferenceSession("model2/test.onnx")
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    pred_onnx = sess.run([output_name], {input_name: imgdata})
    print("outputs:", np.array(pred_onnx)[0].shape)


onnx_runtime()
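The script only prints the output shape. A small sketch, placed inside onnx_runtime(), for turning the logits into a CIFAR-10 class label (the class names follow the standard CIFAR-10 order):

logits = np.array(pred_onnx)[0]                    # shape (1, 10)
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
pred_idx = int(np.argmax(logits, axis=1)[0])
print("predicted class:", classes[pred_idx])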

6.test_onnx_v2.py

import numpy as np
import torch
import onnx
import onnxruntime
import pickle

# test data
x = torch.randn(1, 3, 224, 224, requires_grad=False)
print(type(x), x.shape)

# check the ONNX model with the ONNX API
onnx_model = onnx.load("model2/test.onnx")
onnx.checker.check_model(onnx_model)

# run the ONNX model
ort_session = onnxruntime.InferenceSession("model2/test.onnx")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# outputs
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
ort_out = ort_outs[0]
print(x.shape, ort_out.shape)

# run the PyTorch model
# model = torch.load("test/person_reid.pth", map_location='cpu')
# model.eval()
# torch_out = model(x)

# compare the ONNX and PyTorch results
# np.testing.assert_allclose(to_numpy(torch_out), ort_out, rtol=1e-03, atol=1e-05)
# print("The two outputs do not differ significantly!")

7.onnx2pb.py

import onnx
from onnx_tf.backend import prepare


def onnx2pb(onnx_input_path, pb_output_path):
    onnx_model = onnx.load(onnx_input_path)   # load the onnx model
    tf_exp = prepare(onnx_model)              # prepare the tf representation
    tf_exp.export_graph(pb_output_path)       # export the model


if __name__ == "__main__":
    # onnx_input_path = 'test/person_reid.onnx'
    # pb_output_path = 'test/person_reid2.pb'
    onnx_input_path = 'model2/test.onnx'
    pb_output_path = 'model2/test.pb'

    onnx2pb(onnx_input_path, pb_output_path)
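The frozen graph's tensor names are needed in the next step (test_pb.py fetches tensors with get_tensor_by_name). A sketch for listing them from the exported file, assuming the same TF 1.x-style API that test_pb.py below already uses:

import tensorflow as tf

graph_def = tf.GraphDef()
with open('model2/test.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
for node in graph_def.node:
    print(node.name, node.op)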

8.test_pb.py  (onnx+pb)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow
import cv2
import numpy as np
import torch
import onnx
import onnxruntime
import pickle


def recognize(img, pb_file_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            # The main steps are marked below; steps 1 and 2 read the graph
            output_graph_def.ParseFromString(f.read())            # 1. parse the model file into the graph_def object
            _ = tf.import_graph_def(output_graph_def, name="")    # 2. import it into the current graph

        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            graph = tf.get_default_graph()                        # 3. get the current graph

            # 4. get_tensor_by_name fetches the tensors we need
            # x = graph.get_tensor_by_name("IteratorGetNext_1:0")
            # y_out = graph.get_tensor_by_name("resnet_v1_50_1/predictions/Softmax:0")
            x = graph.get_tensor_by_name("data:0")
            y_out = graph.get_tensor_by_name("reid_embedding:0")

            # img = np.random.normal(size=(1, 224, 224, 3))
            # img = cv2.imread(jpg_path)
            # img = cv2.resize(img, (128, 256))
            # img = np.reshape(img, (1, 128, 256, 3))
            # img = img.transpose(0, 3, 1, 2).copy()
            # print(img.shape)

            # run the graph
            output = sess.run(y_out, feed_dict={x: img})
            pred = np.argmax(output, axis=1)
            return output
            # print("prediction:", output.shape, output, "predicted label:", pred)


jpg_path = "img/test.jpg"
img = cv2.imread(jpg_path)
img = cv2.resize(img, (128, 256))
img = np.reshape(img, (1, 128, 256, 3))
img = img.transpose(0, 3, 1, 2).copy()
print(img.shape)

x = torch.randn(1, 3, 256, 128, requires_grad=False)
img = x

# test the pb model
a = recognize(img, "test/gg.pb")
print(a.shape)
# b = recognize(img, "test/person_reid2.pb")
# np.testing.assert_allclose(a, b, rtol=1e-03, atol=1e-05)
# print(a.shape, a[0][4], b[0][4])

# # # test data
# # x = torch.randn(1, 3, 256, 128, requires_grad=False)
# # # x = torch.from_numpy(img)
# # # x.requires_grad = False

# # check the ONNX model with the ONNX API
# onnx_model = onnx.load("test/person_reid.onnx")
# onnx.checker.check_model(onnx_model)

# # run the ONNX model
# ort_session = onnxruntime.InferenceSession("test/person_reid.onnx")
# def to_numpy(tensor):
#     return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# # outputs
# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
# ort_outs = ort_session.run(None, ort_inputs)
# ort_out = ort_outs[0]
# print(ort_out.shape, ort_out[0][4])
# np.testing.assert_allclose(a, ort_out, rtol=1e-03, atol=1e-05)
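Note that the hardcoded tensor names ("data:0", "reid_embedding:0"), the test/gg.pb path and the 256x128 input above come from an earlier person-reid project. A rough sketch of adapting the call to the ResNet-50 pb exported in this chapter; the tensor names inside recognize() must first be replaced with the real input/output names of this graph (e.g. as printed by the listing snippet in section 7):

img = np.random.randn(1, 3, 224, 224).astype(np.float32)   # NCHW, matching the ONNX export
# out = recognize(img, "model2/test.pb")   # after swapping "data:0" / "reid_embedding:0" in recognize()
#                                          # for this graph's actual input/output tensor names
# print(out.shape)                         # expected (1, 10)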
