【PyTorch】模型 FPS 测试 Benchmark(参考 MMDetection 实现)
引言
深度学习中,模型的速度和性能具有同等重要的地位,因为这直接关系到模型是否能在实际生产应用中落地。在计算机视觉领域,FPS(模型每秒能够处理的图像帧数)是一个重要且直观地反映模型处理速度的指标,基本在所有图像处理类任务中都有用到,例如图像超分,图像修复和目标检测等等。本文从 MMDetection 中抽取了 FPS Benchmark,并做了微小的修改,以便快速测试。
代码
参数 | 描述 |
---|---|
model |
继承 torch.nn.Module 类实例化的 PyTorch 模型。
|
input_size |
模型可接受的输入维度。注意第一个维度是 batch_size ,必须为 1,余下的维度根据模型来设置。
|
device |
选择在 GPU 或 CPU 上测试 FPS。默认是在 CPU 上测试,也支持 GPU,例如 cuda:0 是在机器的第一张独立显卡上测试。
|
warmup_num | 预热次数。因为模型刚开始测试的几轮速度很慢,会影响 FPS 的测试结果,所以我们直接跳过。 |
log_interval | 打印日志的频率,即每隔多少轮打印计算的平均 FPS 值。 |
iterations | 单次测试的总迭代次数。程序会汇总该迭代次数内的所有 FPS 值,并取平均作为我们最终的结果。 |
repeat_num | 重复测试的次数。为进一步缓解测试结果的偶然性,可进行多次重复的测试实验。 |
import torch
import timeclass FPSBenchmark():def __init__(self,model: torch.nn.Module,input_size: tuple,device: str = "cpu",warmup_num: int = 5,log_interval: int = 10,iterations: int = 100,repeat_num: int = 1,) -> None:"""FPS benchmark.Ref:MMDetection: https://mmdetection.readthedocs.io/en/stable/useful_tools.html#fps-benchmark.Args:model (torch.nn.Module): model to be tested.input_size (tuple): model acceptable input size, e.g. `BCHW`, make sure `batch_size` is 1.device (str): device for test. Default to "cpu".warmup_num (int, optional): the first several iterations may be very slow so skip them. Defaults to 5.iterations (int, optional): numer of iterations in a single test. Defaults to 100.repeat_num (int, optional): number of repeat tests. Defaults to 1."""# Parameters for `load_model`self.model = modelself.input_size = input_sizeself.device = device# Parameters for `measure_inference_speed`self.warmup_num = warmup_numself.log_interval = log_intervalself.iterations = iterations# Parameters for `repeat_measure_inference_speed`self.repeat_num = repeat_numdef load_model(self):model = self.model.to(self.device)model.eval()return modeldef measure_inference_speed(self):model = self.load_model()pure_inf_time = 0fps = 0for i in range(self.iterations):input_data = torch.randn(self.input_size, device=self.device)if "cuda" in self.device:torch.cuda.synchronize()start_time = time.perf_counter()with torch.no_grad():model(input_data)torch.cuda.synchronize()elif "cpu" in self.device:start_time = time.perf_counter()with torch.no_grad():model(input_data)else:NotImplementedError(f"{self.device} hasn't been implemented yet.")elapsed = time.perf_counter() - start_timeif i >= self.warmup_num:pure_inf_time += elapsedif (i + 1) % self.log_interval == 0:fps = (i + 1 - self.warmup_num) / pure_inf_timeprint(f'Done image [{i + 1:0>3}/{self.iterations}], 'f'FPS: {fps:.2f} img/s, 'f'Times per image: {1000 / fps:.2f} ms/img',flush=True,)else:passelse:passfps = (self.iterations - self.warmup_num) / pure_inf_timeprint(f'Overall FPS: {fps:.2f} img/s, 'f'Times per image: {1000 / fps:.2f} ms/img',flush=True,)return fpsdef repeat_measure_inference_speed(self):assert self.repeat_num >= 1fps_list = []for _ in range(self.repeat_num):fps_list.append(self.measure_inference_speed())if self.repeat_num > 1:fps_list_ = [round(fps, 2) for fps in fps_list]times_pre_image_list_ = [round(1000 / fps, 2) for fps in fps_list]mean_fps_ = sum(fps_list_) / len(fps_list_)mean_times_pre_image_ = sum(times_pre_image_list_) / len(times_pre_image_list_)print(f'Overall FPS: {fps_list_}[{mean_fps_:.2f}] img/s, 'f'Times per image: 'f'{times_pre_image_list_}[{mean_times_pre_image_:.2f}] ms/img',flush=True,)return fps_listelse:return fps_list[0]if __name__ == '__main__':FPSBenchmark(model=torch.nn.Conv2d(3, 64, 3, 1, 1),input_size=(1, 3, 224, 224),device="cuda:0",).repeat_measure_inference_speed()
参考
https://github.com/open-mmlab/mmdetection/blob/master/tools/analysis_tools/benchmark.py
【PyTorch】模型 FPS 测试 Benchmark(参考 MMDetection 实现)相关推荐
- TensorRT和PyTorch模型的故事
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨伯恩legacy 来源丨https://zhuanlan.zh ...
- 基于C++的PyTorch模型部署
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 引言 PyTorch作为一款端到端的深度学习框架,在1.0版本之后 ...
- 如何使用TensorRT对训练好的PyTorch模型进行加速?
点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨伯恩legacy@知乎 来源丨https://zhuanlan.zhihu.com/p/8831 ...
- pytorch模型转onnx-量化rknn(bisenet)
1.pytorch模型转化onnx 先把pytorch的.pth模型转成onnx,例如我这个是用Bisenet转的,执行export_onnx.py import argparse import os ...
- pytorch | 深度学习分割网络U-net的pytorch模型实现
原文:https://blog.csdn.net/u014722627/article/details/60883185 pytorch | 深度学习分割网络U-net的pytorch模型实现 这个是 ...
- PyTorch模型部署:pth转onnx跨框架部署详解+代码
文章目录 引言 基础概念 onnx:跨框架的模型表达标准 onnxruntime:部署模型的推理引擎 示例代码 0)安装onnx和onnxruntime 1)pytorch模型转onnx模型 2)on ...
- pytorch 模型同一轮两次预测结果不一样_2020年的最新深度学习模型可解释性综述[附带代码]...
最近low-level vision的炼丹经常出现各种主观评测上的效果问题,无法定位出其对于输入数据的对应关系,出现了问题之后很难进行针对性解决. 这个时候一个很自然的问题就是,都2020年了,深度学 ...
- python吃显卡还是内存不足_解决Pytorch 训练与测试时爆显存(out of memory)的问题
Pytorch 训练时有时候会因为加载的东西过多而爆显存,有些时候这种情况还可以使用cuda的清理技术进行修整,当然如果模型实在太大,那也没办法. 使用torch.cuda.empty_cache() ...
- 《Pytorch 模型推理及多任务通用范式》第三节作业
1 课程学习 本节课主要对于大白AI课程:https://mp.weixin.qq.com/s/STbdSoI7xLeHrNyLlw9GOg <Pytorch 模型推理及多任务通用范式>课 ...
最新文章
- js 定时任务,定时器
- 146. LRU Cache--java,python解法
- Messages 贪心,期望,概率,模拟(2000)
- mysql两条记录合成一条数据_踩坑记录之csv数据导入MySQL
- Javascript:原型模式类继承
- php基础教程 第八步循环补充
- [导入]C#实现WEB浏览器
- Oracle意外赢官司,程序员或过苦日子
- android优雅的一个侧滑
- 新浪第一时间视频直播全球火炬接力
- 记一次VS Code崩溃的解决(Win10扫描自动回复系统文件)
- ENVI入门系列教程---一、数据预处理---4.3自定义RPC文件图像正射校正
- Ubuntu 无法mount解决办法
- 如何破解无法炸开的CAD加密图纸
- 三菱plc程序三菱FX3U画圆程序,只要弄明白这个程序,就可以非常了解整个项目的程序如何去编写
- Oracle Data Guard官方说明
- 大学c语言程序设计听不懂,C语言听不懂?那你还不点进来看看?
- Spring问题研究之bean的属性xml注入List类型不匹配
- MySQL模糊查询 结果按匹配度 排序
- 高性能web平台【OpenResty入门与实战】