引言

深度学习中,模型的速度和性能具有同等重要的地位,因为这直接关系到模型是否能在实际生产应用中落地。在计算机视觉领域,FPS(模型每秒能够处理的图像帧数)是一个重要且直观地反映模型处理速度的指标,基本在所有图像处理类任务中都有用到,例如图像超分,图像修复和目标检测等等。本文从 MMDetection 中抽取了 FPS Benchmark,并做了微小的修改,以便快速测试。

代码

参数 描述
model 继承 torch.nn.Module 类实例化的 PyTorch 模型。
input_size 模型可接受的输入维度。注意第一个维度是 batch_size,必须为 1,余下的维度根据模型来设置。
device 选择在 GPU 或 CPU 上测试 FPS。默认是在 CPU 上测试,也支持 GPU,例如 cuda:0 是在机器的第一张独立显卡上测试。
warmup_num 预热次数。因为模型刚开始测试的几轮速度很慢,会影响 FPS 的测试结果,所以我们直接跳过。
log_interval 打印日志的频率,即每隔多少轮打印计算的平均 FPS 值。
iterations 单次测试的总迭代次数。程序会汇总该迭代次数内的所有 FPS 值,并取平均作为我们最终的结果。
repeat_num 重复测试的次数。为进一步缓解测试结果的偶然性,可进行多次重复的测试实验。
import torch
import timeclass FPSBenchmark():def __init__(self,model: torch.nn.Module,input_size: tuple,device: str = "cpu",warmup_num: int = 5,log_interval: int = 10,iterations: int = 100,repeat_num: int = 1,) -> None:"""FPS benchmark.Ref:MMDetection: https://mmdetection.readthedocs.io/en/stable/useful_tools.html#fps-benchmark.Args:model (torch.nn.Module): model to be tested.input_size (tuple): model acceptable input size, e.g. `BCHW`, make sure `batch_size` is 1.device (str): device for test. Default to "cpu".warmup_num (int, optional): the first several iterations may be very slow so skip them. Defaults to 5.iterations (int, optional): numer of iterations in a single test. Defaults to 100.repeat_num (int, optional): number of repeat tests. Defaults to 1."""# Parameters for `load_model`self.model = modelself.input_size = input_sizeself.device = device# Parameters for `measure_inference_speed`self.warmup_num = warmup_numself.log_interval = log_intervalself.iterations = iterations# Parameters for `repeat_measure_inference_speed`self.repeat_num = repeat_numdef load_model(self):model = self.model.to(self.device)model.eval()return modeldef measure_inference_speed(self):model = self.load_model()pure_inf_time = 0fps = 0for i in range(self.iterations):input_data = torch.randn(self.input_size, device=self.device)if "cuda" in self.device:torch.cuda.synchronize()start_time = time.perf_counter()with torch.no_grad():model(input_data)torch.cuda.synchronize()elif "cpu" in self.device:start_time = time.perf_counter()with torch.no_grad():model(input_data)else:NotImplementedError(f"{self.device} hasn't been implemented yet.")elapsed = time.perf_counter() - start_timeif i >= self.warmup_num:pure_inf_time += elapsedif (i + 1) % self.log_interval == 0:fps = (i + 1 - self.warmup_num) / pure_inf_timeprint(f'Done image [{i + 1:0>3}/{self.iterations}], 'f'FPS: {fps:.2f} img/s, 'f'Times per image: {1000 / fps:.2f} ms/img',flush=True,)else:passelse:passfps = (self.iterations - self.warmup_num) / pure_inf_timeprint(f'Overall FPS: {fps:.2f} img/s, 'f'Times per image: {1000 / fps:.2f} ms/img',flush=True,)return fpsdef repeat_measure_inference_speed(self):assert self.repeat_num >= 1fps_list = []for _ in range(self.repeat_num):fps_list.append(self.measure_inference_speed())if self.repeat_num > 1:fps_list_ = [round(fps, 2) for fps in fps_list]times_pre_image_list_ = [round(1000 / fps, 2) for fps in fps_list]mean_fps_ = sum(fps_list_) / len(fps_list_)mean_times_pre_image_ = sum(times_pre_image_list_) / len(times_pre_image_list_)print(f'Overall FPS: {fps_list_}[{mean_fps_:.2f}] img/s, 'f'Times per image: 'f'{times_pre_image_list_}[{mean_times_pre_image_:.2f}] ms/img',flush=True,)return fps_listelse:return fps_list[0]if __name__ == '__main__':FPSBenchmark(model=torch.nn.Conv2d(3, 64, 3, 1, 1),input_size=(1, 3, 224, 224),device="cuda:0",).repeat_measure_inference_speed()

参考

https://github.com/open-mmlab/mmdetection/blob/master/tools/analysis_tools/benchmark.py

【PyTorch】模型 FPS 测试 Benchmark(参考 MMDetection 实现)相关推荐

  1. TensorRT和PyTorch模型的故事

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨伯恩legacy 来源丨https://zhuanlan.zh ...

  2. 基于C++的PyTorch模型部署

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 引言 PyTorch作为一款端到端的深度学习框架,在1.0版本之后 ...

  3. 如何使用TensorRT对训练好的PyTorch模型进行加速?

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨伯恩legacy@知乎 来源丨https://zhuanlan.zhihu.com/p/8831 ...

  4. pytorch模型转onnx-量化rknn(bisenet)

    1.pytorch模型转化onnx 先把pytorch的.pth模型转成onnx,例如我这个是用Bisenet转的,执行export_onnx.py import argparse import os ...

  5. pytorch | 深度学习分割网络U-net的pytorch模型实现

    原文:https://blog.csdn.net/u014722627/article/details/60883185 pytorch | 深度学习分割网络U-net的pytorch模型实现 这个是 ...

  6. PyTorch模型部署:pth转onnx跨框架部署详解+代码

    文章目录 引言 基础概念 onnx:跨框架的模型表达标准 onnxruntime:部署模型的推理引擎 示例代码 0)安装onnx和onnxruntime 1)pytorch模型转onnx模型 2)on ...

  7. pytorch 模型同一轮两次预测结果不一样_2020年的最新深度学习模型可解释性综述[附带代码]...

    最近low-level vision的炼丹经常出现各种主观评测上的效果问题,无法定位出其对于输入数据的对应关系,出现了问题之后很难进行针对性解决. 这个时候一个很自然的问题就是,都2020年了,深度学 ...

  8. python吃显卡还是内存不足_解决Pytorch 训练与测试时爆显存(out of memory)的问题

    Pytorch 训练时有时候会因为加载的东西过多而爆显存,有些时候这种情况还可以使用cuda的清理技术进行修整,当然如果模型实在太大,那也没办法. 使用torch.cuda.empty_cache() ...

  9. 《Pytorch 模型推理及多任务通用范式》第三节作业

    1 课程学习 本节课主要对于大白AI课程:https://mp.weixin.qq.com/s/STbdSoI7xLeHrNyLlw9GOg <Pytorch 模型推理及多任务通用范式>课 ...

最新文章

  1. js 定时任务,定时器
  2. 146. LRU Cache--java,python解法
  3. Messages 贪心,期望,概率,模拟(2000)
  4. mysql两条记录合成一条数据_踩坑记录之csv数据导入MySQL
  5. Javascript:原型模式类继承
  6. php基础教程 第八步循环补充
  7. [导入]C#实现WEB浏览器
  8. Oracle意外赢官司,程序员或过苦日子
  9. android优雅的一个侧滑
  10. 新浪第一时间视频直播全球火炬接力
  11. 记一次VS Code崩溃的解决(Win10扫描自动回复系统文件)
  12. ENVI入门系列教程---一、数据预处理---4.3自定义RPC文件图像正射校正
  13. Ubuntu 无法mount解决办法
  14. 如何破解无法炸开的CAD加密图纸
  15. 三菱plc程序三菱FX3U画圆程序,只要弄明白这个程序,就可以非常了解整个项目的程序如何去编写
  16. Oracle Data Guard官方说明
  17. 大学c语言程序设计听不懂,C语言听不懂?那你还不点进来看看?
  18. Spring问题研究之bean的属性xml注入List类型不匹配
  19. MySQL模糊查询 结果按匹配度 排序
  20. 高性能web平台【OpenResty入门与实战】

热门文章

  1. STM32F105双CAN双FIFO通讯心得体会
  2. 性能测试,你需要了解这款工具
  3. 区块链亲民应用场景大猜想 第一次或将献给超大文件传输
  4. mysql查询IN索引无效的问题【已解决】
  5. 搭建一个属于自己的博客平台
  6. 使用Calibre转换任意格式为支持KF8的mobi文件
  7. Spark日志,及设置日志输出级别
  8. 内行人看鸿蒙系统,如何看待华为终端2020年全线搭载鸿蒙系统?内行人“一语道破”...
  9. Week8 作业 C - 班长竞选 SCC Kosaraju HDU - 3639
  10. Cesium自定义编辑多边形