目录

  • 一、数据集加载
  • 二、搭建模型
    • 1.继承torch.nn.Module
    • 2.利用容器torch.nn.Sequential
    • 3.利用现有的预训练网络
  • 三、配置模型
  • 四、训练模型
  • 参考链接

一、数据集加载

DataLoaderImageFolder函数

from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoaderdata_transform = transforms.Compose([transforms.ToTensor(), #Converts a PIL.Image or numpy.ndarray to torch.FloatTensortransforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5, 0.5, 0.5]),transforms.ConvertImageDtype(torch.float)
])
dataset = ImageFolder("YOUR IMAGE DIRECTORY",transform = data_transform)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

二、搭建模型

1.继承torch.nn.Module

继承Module,重写forward函数。
例:

class net(nn.Module):def __init__(self, in_size, out_size):super(unetUp, self).__init__()self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)self.up = nn.UpsamplingBilinear2d(scale_factor=2)self.relu = nn.ReLU(inplace=True)def forward(self, inputs1, inputs2):outputs = torch.cat([inputs1, self.up(inputs2)], 1)outputs = self.conv1(outputs)outputs = self.relu(outputs)outputs = self.conv2(outputs)outputs = self.relu(outputs)return outputs

2.利用容器torch.nn.Sequential

例:

model = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())

3.利用现有的预训练网络

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
efficientnet_b0 = models.efficientnet_b0(pretrained=True)
efficientnet_b1 = models.efficientnet_b1(pretrained=True)
efficientnet_b2 = models.efficientnet_b2(pretrained=True)
efficientnet_b3 = models.efficientnet_b3(pretrained=True)
efficientnet_b4 = models.efficientnet_b4(pretrained=True)
efficientnet_b5 = models.efficientnet_b5(pretrained=True)
efficientnet_b6 = models.efficientnet_b6(pretrained=True)
efficientnet_b7 = models.efficientnet_b7(pretrained=True)
regnet_y_400mf = models.regnet_y_400mf(pretrained=True)
regnet_y_800mf = models.regnet_y_800mf(pretrained=True)
regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True)
regnet_y_3_2gf = models.regnet_y_3_2gf(pretrained=True)
regnet_y_8gf = models.regnet_y_8gf(pretrained=True)
regnet_y_16gf = models.regnet_y_16gf(pretrained=True)
regnet_y_32gf = models.regnet_y_32gf(pretrained=True)
regnet_x_400mf = models.regnet_x_400mf(pretrained=True)
regnet_x_800mf = models.regnet_x_800mf(pretrained=True)
regnet_x_1_6gf = models.regnet_x_1_6gf(pretrained=True)
regnet_x_3_2gf = models.regnet_x_3_2gf(pretrained=True)
regnet_x_8gf = models.regnet_x_8gf(pretrained=True)
regnet_x_16gf = models.regnet_x_16gf(pretrainedTrue)
regnet_x_32gf = models.regnet_x_32gf(pretrained=True)

三、配置模型

配置:

#多块显卡并行
device_ids = [0, 1]
net = torch.nn.DataParallel(net, device_ids=device_ids)
#优化开启
cudnn.benchmark = True
#cuda
net = net.cuda()

损失函数和优化器:

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)   # 学习率为0.01
criterion = nn.CrossEntropyLoss()

内置优化器:

Adadelta:实现 Adadelta 算法。Adagrad:实现 Adagrad 算法。Adam:实现Adam算法。AdamW:实现 AdamW 算法。SparseAdam:实现适用于稀疏张量的 Adam 算法的惰性版本。Adamax:实现 Adamax 算法(基于无穷范数的 Adam 变体)。ASGD:实现平均随机梯度下降。LBFGS:实现 L-BFGS 算法,参考minFunc 。NAdam:实现 NAdam 算法。RAdam:实现 RAdam 算法。RMSprop:实现 RMSprop 算法。Rprop:实现弹性反向传播算法。SGD:实现随机梯度下降(可选动量)。

四、训练模型

训练
简单例子:

for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()

Unet训练的例程:

for i in epochs:total_loss = 0.0for iteration, batch in enumerate(trainloader, 0):if iteration >= epoch_size: breakimgs, pngs, labels = batchwith torch.no_grad():imgs = torch.from_numpy(imgs).type(torch.FloatTensor)pngs = torch.from_numpy(pngs).type(torch.FloatTensor).long()labels = torch.from_numpy(labels).type(torch.FloatTensor)#use cudaimgs = imgs.cuda()pngs = pngs.cuda()labels = labels.cuda()optimizer.zero_grad()outputs = net(imgs)loss    = CE_Loss(outputs, pngs, num_classes = NUM_CLASSES)loss.backward()optimizer.step()total_loss += loss.item()

保存模型:

#保存模型结构和参数
torch.save(net, 'net.pkl')
# 只保存神经网络的模型参数
torch.save(net.state_dict(), 'net_params.pkl')
#保存为ONNX
dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda') #网络输入
input_names = ["inputs"]
output_names = ["outpus"]
torch_out = torch.onnx.export(net, dummy_input, "net.onnx", export_params=True, verbose=True,input_names=input_names, output_names=output_names)

tensorboard记录
例:

from torch.utils.tensorboard import SummaryWriter
import numpy as npwriter = SummaryWriter()for n_iter in range(100):writer.add_scalar('Loss/train', np.random.random(), n_iter)writer.add_scalar('Loss/test', np.random.random(), n_iter)writer.add_scalar('Accuracy/train', np.random.random(), n_iter)writer.add_scalar('Accuracy/test', np.random.random(), n_iter)

参考链接

https://pytorch.org/docs/stable/index.html

pytorch 网络搭建简要步骤相关推荐

  1. 大型综合网络搭建详细步骤教程

    1.1 问题 现有网络问题分析: 接入层交换机只与同一个三层交换机相连,存在单点故障而影响网络通信. 互联网连接单一服务商 现有网络需求: 随着企业发展,为了保证网络的高可用性,需要使用很多的冗余技术 ...

  2. 怎样检查python环境是否安装好_如何搭建pytorch环境的方法步骤

    1.conda创建虚拟环境pytorch_gpu conda create -n pytorch_gpu python=3.6 创建虚拟环境还是相对较快的,它会自动为本环境安装一些基本的库,等待时间无 ...

  3. <计算机视觉四> pytorch版yolov3网络搭建

    鼠标点击下载     项目源代码免费下载地址 <计算机视觉一> 使用标定工具标定自己的目标检测 <计算机视觉二> labelme标定的数据转换成yolo训练格式 <计算机 ...

  4. fabric2.3.2 test-network测试网络搭建 超详细步骤

    搭建好fabric网络后的第一步一定是练习一下测试网络.如果需要ubuntu下安装fabric环境的可以看下面两篇文章: Ubuntu16.04+fabric1.4.3 (15条消息) fabric1 ...

  5. 搭建nexus3私库简要步骤

    搭建nexus私库 简要步骤: 安装nexus 登录nexus页面端 默认地址http://loaclhost:8081 登录nexus账号 默认admin/admin123 maven-centra ...

  6. 虚拟机屏幕显示不全(界面大小更改 )虚拟机Ubuntu18.04 的超详细环境搭建教程/步骤 SDN软件定义网络实验

    打开虚拟机后,我们可能发现,桌面周围有大量黑边,且有些界面无法完整显示,影响我们的感受和操作!!!  解决方法: (1)点击箭头所指,进入目录 (2)点击箭头所指的齿轮,进入"setting ...

  7. 2022年网络搭建与应用——国赛FTP搭建 (解题步骤答案)

    2022年网络搭建与应用 FTP搭建 需要其他部分 全部解析私聊. [任务描述]为了提高文件的共享性,对用户进行透明和可靠高效地 [任务描述]为了提高文件的共享性,对用户进行透明和可靠高效地]数据传送 ...

  8. 运用PyTorch动手搭建一个共享单车预测器

    本文摘自 <深度学习原理与PyTorch实战> 我们将从预测某地的共享单车数量这个实际问题出发,带领读者走进神经网络的殿堂,运用PyTorch动手搭建一个共享单车预测器,在实战过程中掌握神 ...

  9. [转]vmware 域网络搭建

    最近给一个客户做网络搭建项目,要是实现网络内部办公安全,实现文件服务器,域控.用户监控.邮件服务器等(真的是狮子大开口啊).但是,却只提供给我一台普通的服务器.在我一番摆事实,将道理的说服下,老板最终 ...

  10. 第十二章_网络搭建及训练

    文章目录 第十二章 网络搭建及训练 CNN训练注意事项 第十二章 TensorFlow.pytorch和caffe介绍 12.1 TensorFlow 12.1.1 TensorFlow是什么? 12 ...

最新文章

  1. 女神推荐, 卡片,广告图 ,点击查看更多
  2. mxnet制作人脸识别训练集
  3. 计算机桌面运行慢,电脑越来越慢原因 电脑运行慢解决方法【详解】
  4. MAVEN 傻瓜式快速教程
  5. google drive的压缩包直接解压到google drive
  6. P1032-字串变换【bfs】
  7. qbytearry有数据上限吗_金仕达大数据开发岗位面试题
  8. Linux设备驱动模型-Bus
  9. atitit.查看预编译sql问号 本质and原理and查看原生sql语句
  10. linux 脚本 日志文件,在linux下用脚本输出日志
  11. 电磁场理论基础 01-17
  12. 雷达原理(第五版)常见公式
  13. 博帝 boost和威刚S102哪个好详细原创评测
  14. 云计算机技术的运用,三分钟为你详细解析云计算技术与应用
  15. AAAI-2021-RE-Progressive Multitask Learning with Controlled Information Flow for Joint Entity and Re
  16. bootstrap 滚动 进度条_Bootstrap中的进度条
  17. 10个SaaS的常见问题解答告诉你SaaS是什么
  18. javascript高级程序设计(python编程代码大全)
  19. 常用的企业管理软件有哪些?
  20. 串行 RapidIO接口介绍

热门文章

  1. python怎么在gui中显示图片_Python 3-如何从Web检索图像并使用TKINTER在GUI中显示?...
  2. jquery 元素第二个_jQuery知识总结
  3. vue 配置sass、scss全局变量
  4. 《计算机系统:系统架构与操作系统的高度集成》——2.5 高级数据抽象
  5. Message启动菜单个性化制作工具V1.0.3.1最终版
  6. nginx 配置php
  7. Linux目录结构、bash的基础命令学习
  8. 移动互联网赌博的大礼包触发
  9. 领域驱动设计系列 (六):CQRS
  10. UserDefault使用