Original url:

https://zhuanlan.zhihu.com/p/33992733

从网上各种资料加上自己实践的可用工具。

主要包括:

模型层数:print_layers_num

模型参数总量:print_model_parm_nums

模型的计算图:def print_autograd_graph():或者参见tensorboard

模型滤波器可视化:show_save_tensor

模型在具体的输入下的尺寸信息summary以及参数量:show_summary

模型计算量:print_model_parm_flops

格式较混乱,但上述代码均可用,后续会继续整理。

# coding: utf8
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
import numpy as np


def test():
    """Quick sanity demo: inspect a resnet18 conv layer and a random input's size."""
    model = models.resnet18()
    print(model.layer1[0].conv1.weight.data)
    print(model.layer1[0].conv1.__class__)  # <class 'torch.nn.modules.conv.Conv2d'>
    print(model.layer1[0].conv1.kernel_size)
    input = torch.autograd.Variable(torch.randn(20, 16, 50, 100))
    print(input.size())
    print(np.prod(input.size()))


def print_model_parm_nums():
    """Print the total number of parameters of alexnet, in millions."""
    model = models.alexnet()
    total = sum(param.nelement() for param in model.parameters())
    print('  + Number of params: %.2fM' % (total / 1e6))


def print_model_parm_flops():
    """Estimate the FLOPs of alexnet for a single 3x224x224 input.

    Forward hooks are registered on every leaf Conv2d/Linear/BatchNorm2d/
    ReLU/MaxPool2d/AvgPool2d module; each hook appends that layer's FLOP
    count to a shared list, and one forward pass accumulates the totals.
    """
    # --- alternative example hooks (not registered below, kept for reference) ---
    prods = {}

    def save_hook(name):
        def hook_per(self, input, output):
            prods[name] = np.prod(input[0].shape)
        return hook_per

    list_1 = []

    def simple_hook(self, input, output):
        list_1.append(np.prod(input[0].shape))

    list_2 = {}

    def simple_hook2(self, input, output):
        list_2['names'] = np.prod(input[0].shape)

    # --- the hooks actually used for the FLOP estimate ---
    # Count multiply and add as separate ops when True; as one MAC when False.
    multiply_adds = False

    list_conv = []

    def conv_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        # output[0] is one sample of the output batch: (C, H, W)
        output_channels, output_height, output_width = output[0].size()
        # // keeps the per-output-channel op count integral under Python 3
        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels // self.groups) * (2 if multiply_adds else 1)
        bias_ops = 1 if self.bias is not None else 0
        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width
        list_conv.append(flops)

    list_linear = []

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1
        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
        # Linear layers may be created with bias=False
        bias_ops = self.bias.nelement() if self.bias is not None else 0
        flops = batch_size * (weight_ops + bias_ops)
        list_linear.append(flops)

    list_bn = []

    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement())

    list_relu = []

    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    list_pooling = []

    def pooling_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()
        # Pool2d kernel_size is a single int here
        kernel_ops = self.kernel_size * self.kernel_size
        bias_ops = 0
        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width
        list_pooling.append(flops)

    def foo(net):
        """Recursively register the counting hooks on every leaf module."""
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, torch.nn.Conv2d):
                net.register_forward_hook(conv_hook)
            if isinstance(net, torch.nn.Linear):
                net.register_forward_hook(linear_hook)
            if isinstance(net, torch.nn.BatchNorm2d):
                net.register_forward_hook(bn_hook)
            if isinstance(net, torch.nn.ReLU):
                net.register_forward_hook(relu_hook)
            if isinstance(net, (torch.nn.MaxPool2d, torch.nn.AvgPool2d)):
                net.register_forward_hook(pooling_hook)
            return
        for c in childrens:
            foo(c)

    resnet = models.alexnet()
    foo(resnet)
    input = Variable(torch.rand(3, 224, 224).unsqueeze(0), requires_grad=True)
    out = resnet(input)
    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) +
                   sum(list_relu) + sum(list_pooling))
    print('  + Number of FLOPs: %.2fG' % (total_flops / 1e9))


def print_forward():
    """Capture and print the input of one conv layer via a forward hook."""
    model = torchvision.models.resnet18()
    select_layer = model.layer1[0].conv1
    grads = {}

    def save_grad(name):
        # Returns a hook that stores the layer input under `name`
        def hook(self, input, output):
            grads[name] = input
        return hook

    select_layer.register_forward_hook(save_grad('select_layer'))
    input = Variable(torch.rand(3, 224, 224).unsqueeze(0), requires_grad=True)
    out = model(input)
    print(grads)


def print_value():
    """Capture intermediate gradients with Variable.register_hook."""
    grads = {}

    def save_grad(name):
        # save_grad('y') returns a hook (a function) that keeps 'y' as name
        def hook(grad):
            grads[name] = grad
        return hook

    x = Variable(torch.randn(1, 1), requires_grad=True)
    y = 3 * x
    z = y ** 2
    y.register_hook(save_grad('y'))
    z.register_hook(save_grad('z'))
    z.backward()
    print('HW')
    print("grads['y']: {}".format(grads['y']))
    print(grads['z'])


def print_layers_num():
    """Print the number of leaf modules of resnet18."""
    resnet = models.resnet18()

    def foo(net):
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, torch.nn.Conv2d):
                print(' ')  # can be used to count layers of a given type
            return 1
        count = 0
        for c in childrens:
            count += foo(c)
        return count

    print(foo(resnet))


def check_summary():
    """Print a keras-like textual summary of alexnet."""

    def torch_summarize(model, show_weights=True, show_parameters=True):
        """Summarizes torch model by showing trainable parameters and weights."""
        from torch.nn.modules.module import _addindent
        tmpstr = model.__class__.__name__ + ' (\n'
        for key, module in model._modules.items():
            # if it contains layers, call recursively to get params and weights
            if type(module) in [torch.nn.modules.container.Container,
                                torch.nn.modules.container.Sequential]:
                modstr = torch_summarize(module)
            else:
                modstr = module.__repr__()
            modstr = _addindent(modstr, 2)
            params = sum(np.prod(p.size()) for p in module.parameters())
            weights = tuple(tuple(p.size()) for p in module.parameters())
            tmpstr += '  (' + key + '): ' + modstr
            if show_weights:
                tmpstr += ', weights={}'.format(weights)
            if show_parameters:
                tmpstr += ', parameters={}'.format(params)
            tmpstr += '\n'
        tmpstr = tmpstr + ')'
        return tmpstr

    # Test
    import torchvision.models as models
    model = models.alexnet()
    print(torch_summarize(model))
# https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7
# summarize a torch model like in keras, showing parameters and output shape
def show_summary():
    """Print a keras-style DataFrame summary (name, shapes, #params) of alexnet."""
    from collections import OrderedDict
    import pandas as pd
    import numpy as np
    import torch
    from torch.autograd import Variable
    import torch.nn.functional as F
    from torch import nn

    def get_names_dict(model):
        """Recursive walk to get names including path."""
        names = {}

        def _get_names(module, parent_name=''):
            for key, module in module.named_children():
                name = parent_name + '.' + key if parent_name else key
                names[name] = module
                if isinstance(module, torch.nn.Module):
                    _get_names(module, parent_name=name)

        _get_names(model)
        return names

    def torch_summarize_df(input_size, model, weights=False, input_shape=True, nb_trainable=False):
        """Summarizes torch model by showing trainable parameters and weights.

        author: wassname
        url: https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7
        license: MIT
        Modified from:
        - https://github.com/pytorch/pytorch/issues/2001#issuecomment-313735757
        - https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7/

        Usage:
            import torchvision.models as models
            model = models.alexnet()
            df = torch_summarize_df(input_size=(3, 224, 224), model=model)
            print(df)
        """
        def register_hook(module):
            def hook(module, input, output):
                # Recover the pathed name of this module from the names dict
                name = ''
                for key, item in names.items():
                    if item == module:
                        name = key
                # e.g. <class 'torch.nn.modules.conv.Conv2d'> -> 'Conv2d'
                class_name = str(module.__class__).split('.')[-1].split("'")[0]
                module_idx = len(summary)
                m_key = module_idx + 1
                summary[m_key] = OrderedDict()
                summary[m_key]['name'] = name
                summary[m_key]['class_name'] = class_name
                if input_shape:
                    summary[m_key]['input_shape'] = (-1,) + tuple(input[0].size())[1:]
                summary[m_key]['output_shape'] = (-1,) + tuple(output.size())[1:]
                if weights:
                    summary[m_key]['weights'] = list([tuple(p.size()) for p in module.parameters()])
                if nb_trainable:
                    params_trainable = sum([torch.LongTensor(list(p.size())).prod()
                                            for p in module.parameters() if p.requires_grad])
                    summary[m_key]['nb_trainable'] = params_trainable
                params = sum([torch.LongTensor(list(p.size())).prod()
                              for p in module.parameters()])
                summary[m_key]['nb_params'] = params

            # Hook only real layers, not containers or the model itself
            if not isinstance(module, nn.Sequential) and \
               not isinstance(module, nn.ModuleList) and \
               not (module == model):
                hooks.append(module.register_forward_hook(hook))

        # Names are stored in parent; path+name is unique, not the name alone
        names = get_names_dict(model)

        # check if there are multiple inputs to the network
        if isinstance(input_size[0], (list, tuple)):
            x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
        else:
            x = Variable(torch.rand(1, *input_size))
        if next(model.parameters()).is_cuda:
            x = x.cuda()

        # create properties
        summary = OrderedDict()
        hooks = []
        # register hooks, run one forward pass, then clean the hooks up
        model.apply(register_hook)
        model(x)
        for h in hooks:
            h.remove()

        df_summary = pd.DataFrame.from_dict(summary, orient='index')
        return df_summary

    # Test on alexnet
    import torchvision.models as models
    model = models.alexnet()
    df = torch_summarize_df(input_size=(3, 224, 224), model=model)
    print(df)
    # Example output:
    #              name class_name        input_shape       output_shape  nb_params
    # 1     features=>0     Conv2d  (-1, 3, 224, 224)   (-1, 64, 55, 55)      23296
    # 2     features=>1       ReLU   (-1, 64, 55, 55)   (-1, 64, 55, 55)          0
    # ...


def show_save_tensor():
    """Visualize and save the first conv layer's filters of a pretrained resnet18."""
    import torch
    from torchvision import utils
    import torchvision.models as models
    from matplotlib import pyplot as plt

    def vis_tensor(tensor, ch=0, all_kernels=False, nrow=8, padding=2):
        '''
        ch: channel for visualization
        all_kernels: all kernels for visualization
        '''
        n, c, h, w = tensor.shape
        if all_kernels:
            tensor = tensor.view(n * c, -1, w, h)
        elif c != 3:
            tensor = tensor[:, ch, :, :].unsqueeze(dim=1)
        rows = np.min((tensor.shape[0] // nrow + 1, 64))
        grid = utils.make_grid(tensor, nrow=nrow, normalize=True, padding=padding)
        plt.imshow(grid.numpy().transpose((1, 2, 0)))  # CHW -> HWC

    def save_tensor(tensor, filename, ch=0, all_kernels=False, nrow=8, padding=2):
        n, c, h, w = tensor.shape
        if all_kernels:
            tensor = tensor.view(n * c, -1, w, h)
        elif c != 3:
            tensor = tensor[:, ch, :, :].unsqueeze(dim=1)
        utils.save_image(tensor, filename, nrow=nrow, normalize=True, padding=padding)

    vgg = models.resnet18(pretrained=True)
    mm = vgg.double()
    filters = mm.modules
    body_model = [i for i in mm.children()][0]  # the first child: conv1
    layer1 = body_model
    tensor = layer1.weight.data.clone()
    vis_tensor(tensor)
    save_tensor(tensor, 'test.png')
    plt.axis('off')
    plt.ioff()
    plt.show()


def print_autograd_graph():
    """Render the autograd graph of a resnet18 forward pass with graphviz."""
    from graphviz import Digraph
    import torch
    from torch.autograd import Variable

    def make_dot(var, params=None):
        """Produces Graphviz representation of PyTorch autograd graph.

        Blue nodes are the Variables that require grad, orange are Tensors
        saved for backward in torch.autograd.Function.

        Args:
            var: output Variable
            params: dict of (name, Variable) to add names to nodes that
                require grad (TODO: make optional)
        """
        if params is not None:
            param_map = {id(v): k for k, v in params.items()}

        node_attr = dict(style='filled',
                         shape='box',
                         align='left',
                         fontsize='12',
                         ranksep='0.1',
                         height='0.2')
        dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
        seen = set()

        def size_to_str(size):
            return '(' + (', ').join(['%d' % v for v in size]) + ')'

        def add_nodes(var):
            if var not in seen:
                if torch.is_tensor(var):
                    dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
                elif hasattr(var, 'variable'):
                    u = var.variable
                    node_name = '%s\n %s' % (param_map.get(id(u.data)), size_to_str(u.size()))
                    dot.node(str(id(var)), node_name, fillcolor='lightblue')
                else:
                    dot.node(str(id(var)), str(type(var).__name__))
                seen.add(var)
                if hasattr(var, 'next_functions'):
                    for u in var.next_functions:
                        if u[0] is not None:
                            dot.edge(str(id(u[0])), str(id(var)))
                            add_nodes(u[0])
                if hasattr(var, 'saved_tensors'):
                    for t in var.saved_tensors:
                        dot.edge(str(id(t)), str(id(var)))
                        add_nodes(t)

        add_nodes(var.grad_fn)
        return dot

    from torchvision import models
    torch.manual_seed(1)
    inputs = torch.randn(1, 3, 224, 224)
    model = models.resnet18(pretrained=False)
    y = model(Variable(inputs))
    g = make_dot(y, params=model.state_dict())
    g.view()


if __name__ == '__main__':
    import fire
    fire.Fire()

pytorch实用工具总结(GFLOPs如何计算)相关推荐

  1. 产品价值公式,一个重要的实用工具

    这是一个神奇的公式!它不仅可以帮助我们理解互联网的发展大势,同时也可以是日常产品工作中一个重要的实用工具.信与不信,请看下文分解. 1. 产品价值 1.1 产品价值的概念 在谈论互联网产品时,你是否遇 ...

  2. 如何使用 ASP.NET 实用工具加密凭据和会话状态连接字符串

    文章编号 : 329290 最后修改 : 2006年4月10日 修订 : 8.0 重要说明:本文包含有关如何修改注册表的信息.修改注册表之前,一定要先进行备份,并且一定要知道在发生问题时如何还原注册表 ...

  3. windows installer清理实用工具

    今天编译了一个程序,一个Windows Service卸载命令写错了,导致这个程序不能正常卸载. 只好手动卸载这个Windows Service,之后再卸载程序,依然不能卸载.最后只好使用MSICUU ...

  4. Unix实用工具教程:《sed与awk》修订第三版清晰版

    为什么80%的码农都做不了架构师?>>>    Unix实用工具教程:<sed与awk>修订第三版清晰版 本书介绍了一组名字奇特的Unix实用工具sed和awk,这组实用 ...

  5. ServiceModel 元数据实用工具 (Svcutil.exe)

    ServiceModel 元数据实用工具用于依据元数据文档生成服务模型代码,以及依据服务模型代码生成元数据文档. 在win7系统中的路径为C:\Program Files\Microsoft SDKs ...

  6. Scott Hanselman's 推荐的的实用工具集合(2011版)

    Scott Hanselman活跃于.NET社区,这篇文章来自于它的工具列表,地址是http://www.hanselman.com/tools .NET开发人员应该收藏的工具 LINQPad 快速理 ...

  7. wps右键新建里面没有word和excel_WPS竟然出过这么多实用工具?每个都免费无广告,简直相见恨晚...

    没想到,总被各种吐槽的WPS,竟然还出过这么多实用工具!向你分享4款WPS出品的软件和网站,其中1款让我又爱又恨,而另外几款免费无广告,和三顿一起来看看吧!WPS图片电脑上到底有没有好用的看图软件?这 ...

  8. 十款让 Web 前端开发人员更轻松的实用工具

    这篇文章介绍十款让 Web 前端开发人员生活更轻松的实用工具.每个 Web 开发人员都有自己的工具箱,这样工作中碰到的每个问题都有一个好的解决方案供选择. 对于每一项工作,开发人员需要特定的辅助工具, ...

  9. 实用工具类库java.util

    本章介绍Java的实用工具类库java.util包.在这个包中,Java提供了一些实用的方法和数据结构.例如,Java提供日期(Data)类.日历(Calendar)类来产生和获取日期及时间,提供随机 ...

  10. 工作中常用,实用工具推荐!

    原文:工作中常用,实用工具推荐! Red Gate 家族 大名鼎鼎的RedGate,相信大家都不会陌生,Reflector就是它家做的.这里就不介绍了.我本地安装的是09年下的一个套装,我介绍下常用的 ...

最新文章

  1. MFC获取系统当前时间
  2. 网络技术工程师专业核心 | 网络技术工程师就业方向
  3. HDU - 3068 最长回文(manacher)
  4. C# 通过拼音检索中文名称
  5. GitLab CI/CD 因git凭据踩坑
  6. Python接口自动化之requests请求封装
  7. 使用socks5代理实现SSH安全登录
  8. ubuntu无法打开系统设置的解决办法
  9. 能查阅国外文献的8个论文网站(最新整理)
  10. 图解图库JanusGraph系列-janusgraph图数据库的本地源码编译教程(janusgraph source code compile)
  11. Variable used in lambda expression should be final or effectively final报错解决方案
  12. Auto.js Pro安卓免ROOT引流脚本开发系列教程23网易公开课(1)-前言
  13. 【面试总结】JNI层MediaScanner的分析,挥泪整理面经
  14. JvisualVM使用教程
  15. orge_src版编译与安装
  16. ==和equals的基本注意事项
  17. 95后阿里P7晒出工资单:狠补了这个,真香...
  18. Python安利一个会化学方程式的消灭泡泡小游戏~
  19. 基础博弈论题和一些题解
  20. Illustrator CS中字体丢失或缺失问题的解决方法

热门文章

  1. mysql计算同比和环比的区别_【面试真题】Mysql实现计算同比、环比
  2. 如何用计算机算分数乘法,分数乘法怎么算
  3. 【PTA】中M22春C、Java入门练习7-138 质因子分解
  4. 怎样复制秀米html码,来,今天学习秀米的“复制粘贴”快捷键~
  5. 在xml添加红色的星号android,在文本输入框中输入编辑文本(红色星号)的必填符号...
  6. 10X Genomics单细胞转录组测序
  7. [电机控制话题] 精辟!伺服电机、舵机、步进电机的区别
  8. SpringCound-Alibaba
  9. gitbook生成目录toc
  10. 计算机工作月度个人总结怎么写,计算机*学生个人实习工作总结范文