文章目录

  • 前言
  • 一、pytorch静态量化(手动版)
    • 踩坑:
  • 二、使用FX量化
    • 1.版本
    • 2.代码如下:
  • 总结

前言

以前面文章写到的mobilenet图像分类为例,本文主要记录一下pytorchh训练后静态量化的过程。


一、pytorch静态量化(手动版)

静态量化是最常用的量化形式,float32的模型量化成int8,模型大小大概变为原来的1/4,推理速度我在intel 8700k CPU上测试速度正好快4倍,但是在AMD的5800h CPU 上测试速度反而慢了两倍,应该是AMD不支持某些指令集加速。

踩坑:

之前手动添加量化节点的方式搞了好几天,最后模型是出来了,但是推理时候报错,大多数时候是RuntimeError: Could not run ‘--------’ with arguments from the ‘CPU’ backend,网上是说推理的时候没有安插QuantStub()和DeQuantStub(),可能是用的这个MobilenetV3网络结构复杂,某些地方没有手动添加到,这种方式肯定是可以成功的,只是比较麻烦容易出错。

# 加载模型
model = MobileNetV3_Large(2).to(device)  # 加载一个网络,我这边是二分类传了一个2
checkpoint = torch.load(weights, map_location=device)
model.load_state_dict(checkpoint)
model.to('cpu').eval()

合并层对于一些重复使用的Block和nn.Sequential要打印出来看,然后append到mix_list 里面
比如

# 打印model
for name, module in model.named_children():print(name, module)

比如这里Sequential里面存在conv+bn+relu,append进去的应该是[‘bneck.0.conv1’, ‘bneck.0.bn1’,‘nolinear1’],但是nolinear1是个变量,也就是说某些时候是relu某些时候又不是,这种时候就要一个个分析判断好然后写代码,稍微复杂点就容易出错或者遗漏。

backend = "fbgemm"  # x86平台
model.qconfig = torch.quantization.get_default_qconfig(backend)
mix_list = [['conv1','bn1'], ['conv2','bn2']] # 合并层只支持conv+bn conv+relu conv+bn+relu等操作,具体可以查一下,网络中存在的这些操作都append到mix_list里面
model = torch.quantization.fuse_modules(model,listmix) # 合并某些层
model_fp32_prepared = torch.quantization.prepare(model)
model_int8 = torch.quantization.convert(model_fp32_prepared)

有时候存在不支持的操作relu6这些要替换成relu,加法操作也要替换,最后还要输入一批图像校准模型等

self.skip_add = nn.quantized.FloatFunctional()
# forward的时候比如return a+b 改为return self.skip_add.add(a, b)

一系列注意事项操作完毕,最后推理各种报错,放弃了

二、使用FX量化

1.版本

fx量化版本也有坑,之前在torch 1.7版本操作总是报错搞不定,换成1.12.0版本就正常了,这一点非常重要。

2.代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import torchvision
from torchvision import transforms
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.quantization import get_default_qconfig
from torch import optim
import os
import time
from utils import load_data
from models.mobilenetv3copy import MobileNetV3_Largedef evaluate_model(model, test_loader, device, criterion=None):model.eval()model.to(device)running_loss = 0running_corrects = 0for inputs, labels in test_loader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)if criterion is not None:loss = criterion(outputs, labels).item()else:loss = 0# statisticsrunning_loss += loss * inputs.size(0)running_corrects += torch.sum(preds == labels.data)eval_loss = running_loss / len(test_loader.dataset)eval_accuracy = running_corrects / len(test_loader.dataset)return eval_loss, eval_accuracydef quant_fx(model, data_loader):model_to_quantize = copy.deepcopy(model)model_to_quantize.eval()qconfig = get_default_qconfig("fbgemm")qconfig_dict = {"": qconfig}prepared_model = prepare_fx(model_to_quantize, qconfig_dict)print("开始校准")calibrate(prepared_model, data_loader)  # 这是输入一批有代表性的数据来校准print("校准完毕")quantized_model = convert_fx(prepared_model)  # 转换return quantized_modeldef calibrate(model, data_loader):model.eval()with torch.no_grad():for image, target in train_loader:model(image)if __name__ == "__main__":cuda_device = torch.device("cuda:0")cpu_device = torch.device("cpu:0")model = MobileNetV3_Large(2)  # 加载自己的网络train_loader, test_loader = load_data(64, 8)  # 自己写一个pytorch加载数据的方法# quantizationstate_dict = torch.load('./mymodel.pth')  # 加载一个正常训练好的模型model.load_state_dict(state_dict)model.to('cpu')model.eval()quant_model = quant_fx(model, train_loader)  # 执行量化代码quant_model.eval()print("开始验证")eval_loss, eval_accuracy = evaluate_model(model=quant_model,test_loader=test_loader,device=cpu_device,criterion=nn.CrossEntropyLoss())print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(-1, eval_loss, eval_accuracy))torch.jit.save(torch.jit.script(quant_model), 'outQuant.pth')  # 保存量化后的模型# 加载量化模型推理loaded_quantized_model = torch.jit.load('outQuant.pth')eval_loss, eval_accuracy = evaluate_model(model=quant_model,test_loader=test_loader,device=cpu_device,criterion=nn.CrossEntropyLoss())print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(-1, eval_loss, eval_accuracy))

fx量化也不用管里面什么算子不支持之类的,开箱即用,以上代码参考pytorch官网https://pytorch.org/docs/stable/fx.html
最后验证模型精度下降0.02%可以忽略不计,pytorch量化的模型是不支持gpu推理的,只能在arm或者x86平台实现压缩提速。要用cuda的话要上tensorrt+onnx,以后完成了再讲。完整的训练模型量化模型的代码后面会放到github上面。
完整代码:https://github.com/Ysnower/pytorch-static-quant


总结

刚开始搞量化坑比较多,一个是某些操作不支持,合并层麻烦,另外有版本问题导致的报错可能搞很久,觉得有用的各位吴彦祖麻烦送个免费三连

pytorch FX模型静态量化相关推荐

  1. onnxruntime 模型静态量化

    文章目录 前言 代码 测试结果 前言 安装onnxruntime pip install -i https://mirror.baidu.com/pypi/simple onnxruntime==1. ...

  2. PyTorch模型训练完毕后静态量化、保存、加载int8量化模型

    1. PyTorch模型量化方法 Pytorch模型量化方法介绍有很多可以参考的,这里推荐两篇文章写的很详细可以给大家一个大致的参考Pytorch的量化,官方量化文档 Pytorch的量化大致分为三种 ...

  3. torch.fx 简介与量化

    pytroch发布的torch.fx工具包可以说是很好的消除一些动态图和静态图的Gap,可以使得我们对于nn.Module的各种变化操作变得非常简单. 动态图和静态图: 动态意味着程序将按照我们编写命 ...

  4. 适配PyTorch FX,OneFlow让量化感知训练更简单

    作者 | 刘耀辉 审稿 | BBuf.许啸宇 1 背景 近年来,量化感知训练是一个较为热点的问题,可以大大优化量化后训练造成精度损失的问题,使得训练过程更加高效. Torch.fx在这一问题上走在了前 ...

  5. 模型量化(3):ONNX 模型的静态量化和动态量化

    转自AI Studio,原文链接:模型量化(3):ONNX 模型的静态量化和动态量化 - 飞桨AI Studio 1. 引入 前面介绍了模型量化的基本原理 也介绍了如何使用 PaddleSlim 对 ...

  6. ONNX 模型的静态量化和动态量化

    文章目录 1 量化介绍 1.1 量化概述 1.2 量化方式 1.3 量化类型 1.4 量化格式 2. 量化实践 2.1 安装依赖 2.2 模型准备 2.3 动态量化 2.4 静态量化 3. 对比测试 ...

  7. 用沐神的方法阅读PyTorch FX论文

    [GiantPandaCV导语]torch.fx对于PyTorch来说确实是一个比较好的工作,因为它消除了一些动态图和静态图的Gap.比如在图改写方面,torch.fx让PyTorch想做一些其它静态 ...

  8. 神经网络(模型)量化介绍 - PTQ 和 QAT

    神经网络(模型)量化介绍 - PTQ 和 QAT 1. 需求目的 2. 量化简介 3. 三种量化模式 3.1 Dynamic Quantization - 动态量化 3.2 Post-Training ...

  9. c++list遍历_小白学PyTorch | 6 模型的构建访问遍历存储(附代码)

    关注一下不迷路哦~喜欢的点个星标吧~<> 小白学PyTorch | 5 torchvision预训练模型与数据集全览 小白学PyTorch | 4 构建模型三要素与权重初始化 小白学PyT ...

最新文章

  1. SpringBoot-web开发(四): SpringMVC的拓展、接管(源码分析)
  2. 如何快速融入一个团队?
  3. 我的世界java版幻翼_我的世界:熬夜3天能见到“幻翼”?你错了,还要满足这7个条件!...
  4. NYOJ 248 BUYING FEED (贪心)
  5. 软件测试面试选择判断提,软件测试面试常考判断题
  6. mybatis 一对一 一对多 级联查询
  7. 北京计算机一级2020,2020北京市一级计算机基础及MS Office应用考试在线自测试题库(不限设备,登陆即可做题)...
  8. vue 微信开发工具 Maximum call stack size exceeded
  9. 搞笑日常:有位程序员的老爸是个什么感觉?过程你绝对意想不到!
  10. 关于SQLite.org网站给黑...
  11. 记录表类型 oracle,[转]关于oracle的记录类型
  12. c语言char float混合类型运算,求int long char double float 在混合运算中的自动转换规则 win 32位系统。...
  13. 修复桌面快捷方式箭头图标
  14. h5 a标签下载链接下载文件
  15. 暴力破解rar和zip加密压缩包
  16. 使用cerebro可视化ElasticSearch集群信息
  17. java格式化时间间隔_用Java本地化格式化时间间隔
  18. html中diy的背景怎么透明,自制复古几何无缝纹案背景_html/css_WEB-ITnose
  19. css3遮罩——新功能引导层
  20. viewpager 与 pageradapter

热门文章

  1. Word转Chm 在线转换应用
  2. SEO效果评估(7大综合指标)
  3. 无人机激光雷达系统在森林资源调查中的应用
  4. 服务正在启动或停止中
  5. 约瑟夫(Joseph)问题
  6. 一个大二学生送给大一学弟学妹的建议
  7. 1200 -- 无聊又简单的游戏
  8. 被解放的姜戈07 马不停蹄
  9. 焦作护理学校计算机应用专业,焦作护理学校
  10. excel常用函数大全(一)