pytorch 网络搭建简要步骤
目录
- 一、数据集加载
- 二、搭建模型
- 1.继承torch.nn.Module
- 2.利用容器torch.nn.Sequential
- 3.利用现有的预训练网络
- 三、配置模型
- 四、训练模型
- 参考链接
一、数据集加载
用DataLoader
和ImageFolder
函数
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoaderdata_transform = transforms.Compose([transforms.ToTensor(), #Converts a PIL.Image or numpy.ndarray to torch.FloatTensortransforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5, 0.5, 0.5]),transforms.ConvertImageDtype(torch.float)
])
dataset = ImageFolder("YOUR IMAGE DIRECTORY",transform = data_transform)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
二、搭建模型
1.继承torch.nn.Module
继承Module,重写forward函数。
例:
class net(nn.Module):def __init__(self, in_size, out_size):super(unetUp, self).__init__()self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)self.up = nn.UpsamplingBilinear2d(scale_factor=2)self.relu = nn.ReLU(inplace=True)def forward(self, inputs1, inputs2):outputs = torch.cat([inputs1, self.up(inputs2)], 1)outputs = self.conv1(outputs)outputs = self.relu(outputs)outputs = self.conv2(outputs)outputs = self.relu(outputs)return outputs
2.利用容器torch.nn.Sequential
例:
model = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())
3.利用现有的预训练网络
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
efficientnet_b0 = models.efficientnet_b0(pretrained=True)
efficientnet_b1 = models.efficientnet_b1(pretrained=True)
efficientnet_b2 = models.efficientnet_b2(pretrained=True)
efficientnet_b3 = models.efficientnet_b3(pretrained=True)
efficientnet_b4 = models.efficientnet_b4(pretrained=True)
efficientnet_b5 = models.efficientnet_b5(pretrained=True)
efficientnet_b6 = models.efficientnet_b6(pretrained=True)
efficientnet_b7 = models.efficientnet_b7(pretrained=True)
regnet_y_400mf = models.regnet_y_400mf(pretrained=True)
regnet_y_800mf = models.regnet_y_800mf(pretrained=True)
regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True)
regnet_y_3_2gf = models.regnet_y_3_2gf(pretrained=True)
regnet_y_8gf = models.regnet_y_8gf(pretrained=True)
regnet_y_16gf = models.regnet_y_16gf(pretrained=True)
regnet_y_32gf = models.regnet_y_32gf(pretrained=True)
regnet_x_400mf = models.regnet_x_400mf(pretrained=True)
regnet_x_800mf = models.regnet_x_800mf(pretrained=True)
regnet_x_1_6gf = models.regnet_x_1_6gf(pretrained=True)
regnet_x_3_2gf = models.regnet_x_3_2gf(pretrained=True)
regnet_x_8gf = models.regnet_x_8gf(pretrained=True)
regnet_x_16gf = models.regnet_x_16gf(pretrainedTrue)
regnet_x_32gf = models.regnet_x_32gf(pretrained=True)
三、配置模型
配置:
#多块显卡并行
device_ids = [0, 1]
net = torch.nn.DataParallel(net, device_ids=device_ids)
#优化开启
cudnn.benchmark = True
#cuda
net = net.cuda()
损失函数和优化器:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # 学习率为0.01
criterion = nn.CrossEntropyLoss()
内置优化器:
Adadelta:实现 Adadelta 算法。Adagrad:实现 Adagrad 算法。Adam:实现Adam算法。AdamW:实现 AdamW 算法。SparseAdam:实现适用于稀疏张量的 Adam 算法的惰性版本。Adamax:实现 Adamax 算法(基于无穷范数的 Adam 变体)。ASGD:实现平均随机梯度下降。LBFGS:实现 L-BFGS 算法,参考minFunc 。NAdam:实现 NAdam 算法。RAdam:实现 RAdam 算法。RMSprop:实现 RMSprop 算法。Rprop:实现弹性反向传播算法。SGD:实现随机梯度下降(可选动量)。
四、训练模型
训练
简单例子:
for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()
Unet训练的例程:
for i in epochs:total_loss = 0.0for iteration, batch in enumerate(trainloader, 0):if iteration >= epoch_size: breakimgs, pngs, labels = batchwith torch.no_grad():imgs = torch.from_numpy(imgs).type(torch.FloatTensor)pngs = torch.from_numpy(pngs).type(torch.FloatTensor).long()labels = torch.from_numpy(labels).type(torch.FloatTensor)#use cudaimgs = imgs.cuda()pngs = pngs.cuda()labels = labels.cuda()optimizer.zero_grad()outputs = net(imgs)loss = CE_Loss(outputs, pngs, num_classes = NUM_CLASSES)loss.backward()optimizer.step()total_loss += loss.item()
保存模型:
#保存模型结构和参数
torch.save(net, 'net.pkl')
# 只保存神经网络的模型参数
torch.save(net.state_dict(), 'net_params.pkl')
#保存为ONNX
dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda') #网络输入
input_names = ["inputs"]
output_names = ["outpus"]
torch_out = torch.onnx.export(net, dummy_input, "net.onnx", export_params=True, verbose=True,input_names=input_names, output_names=output_names)
tensorboard记录
例:
from torch.utils.tensorboard import SummaryWriter
import numpy as npwriter = SummaryWriter()for n_iter in range(100):writer.add_scalar('Loss/train', np.random.random(), n_iter)writer.add_scalar('Loss/test', np.random.random(), n_iter)writer.add_scalar('Accuracy/train', np.random.random(), n_iter)writer.add_scalar('Accuracy/test', np.random.random(), n_iter)
参考链接
https://pytorch.org/docs/stable/index.html
pytorch 网络搭建简要步骤相关推荐
- 大型综合网络搭建详细步骤教程
1.1 问题 现有网络问题分析: 接入层交换机只与同一个三层交换机相连,存在单点故障而影响网络通信. 互联网连接单一服务商 现有网络需求: 随着企业发展,为了保证网络的高可用性,需要使用很多的冗余技术 ...
- 怎样检查python环境是否安装好_如何搭建pytorch环境的方法步骤
1.conda创建虚拟环境pytorch_gpu conda create -n pytorch_gpu python=3.6 创建虚拟环境还是相对较快的,它会自动为本环境安装一些基本的库,等待时间无 ...
- <计算机视觉四> pytorch版yolov3网络搭建
鼠标点击下载 项目源代码免费下载地址 <计算机视觉一> 使用标定工具标定自己的目标检测 <计算机视觉二> labelme标定的数据转换成yolo训练格式 <计算机 ...
- fabric2.3.2 test-network测试网络搭建 超详细步骤
搭建好fabric网络后的第一步一定是练习一下测试网络.如果需要ubuntu下安装fabric环境的可以看下面两篇文章: Ubuntu16.04+fabric1.4.3 (15条消息) fabric1 ...
- 搭建nexus3私库简要步骤
搭建nexus私库 简要步骤: 安装nexus 登录nexus页面端 默认地址http://loaclhost:8081 登录nexus账号 默认admin/admin123 maven-centra ...
- 虚拟机屏幕显示不全(界面大小更改 )虚拟机Ubuntu18.04 的超详细环境搭建教程/步骤 SDN软件定义网络实验
打开虚拟机后,我们可能发现,桌面周围有大量黑边,且有些界面无法完整显示,影响我们的感受和操作!!! 解决方法: (1)点击箭头所指,进入目录 (2)点击箭头所指的齿轮,进入"setting ...
- 2022年网络搭建与应用——国赛FTP搭建 (解题步骤答案)
2022年网络搭建与应用 FTP搭建 需要其他部分 全部解析私聊. [任务描述]为了提高文件的共享性,对用户进行透明和可靠高效地 [任务描述]为了提高文件的共享性,对用户进行透明和可靠高效地]数据传送 ...
- 运用PyTorch动手搭建一个共享单车预测器
本文摘自 <深度学习原理与PyTorch实战> 我们将从预测某地的共享单车数量这个实际问题出发,带领读者走进神经网络的殿堂,运用PyTorch动手搭建一个共享单车预测器,在实战过程中掌握神 ...
- [转]vmware 域网络搭建
最近给一个客户做网络搭建项目,要是实现网络内部办公安全,实现文件服务器,域控.用户监控.邮件服务器等(真的是狮子大开口啊).但是,却只提供给我一台普通的服务器.在我一番摆事实,将道理的说服下,老板最终 ...
- 第十二章_网络搭建及训练
文章目录 第十二章 网络搭建及训练 CNN训练注意事项 第十二章 TensorFlow.pytorch和caffe介绍 12.1 TensorFlow 12.1.1 TensorFlow是什么? 12 ...
最新文章
- 女神推荐, 卡片,广告图 ,点击查看更多
- mxnet制作人脸识别训练集
- 计算机桌面运行慢,电脑越来越慢原因 电脑运行慢解决方法【详解】
- MAVEN 傻瓜式快速教程
- google drive的压缩包直接解压到google drive
- P1032-字串变换【bfs】
- qbytearry有数据上限吗_金仕达大数据开发岗位面试题
- Linux设备驱动模型-Bus
- atitit.查看预编译sql问号 本质and原理and查看原生sql语句
- linux 脚本 日志文件,在linux下用脚本输出日志
- 电磁场理论基础 01-17
- 雷达原理(第五版)常见公式
- 博帝 boost和威刚S102哪个好详细原创评测
- 云计算机技术的运用,三分钟为你详细解析云计算技术与应用
- AAAI-2021-RE-Progressive Multitask Learning with Controlled Information Flow for Joint Entity and Re
- bootstrap 滚动 进度条_Bootstrap中的进度条
- 10个SaaS的常见问题解答告诉你SaaS是什么
- javascript高级程序设计(python编程代码大全)
- 常用的企业管理软件有哪些?
- 串行 RapidIO接口介绍
热门文章
- python怎么在gui中显示图片_Python 3-如何从Web检索图像并使用TKINTER在GUI中显示?...
- jquery 元素第二个_jQuery知识总结
- vue 配置sass、scss全局变量
- 《计算机系统:系统架构与操作系统的高度集成》——2.5 高级数据抽象
- Message启动菜单个性化制作工具V1.0.3.1最终版
- nginx 配置php
- Linux目录结构、bash的基础命令学习
- 移动互联网赌博的大礼包触发
- 领域驱动设计系列 (六):CQRS
- UserDefault使用