借鉴:https://github.com/gwding/draw_convnet

直接上代码:

import os
import numpy as np
import matplotlib.pyplot as plt
plt.rcdefaults()
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from matplotlib.patches import CircleNumDots = 4
NumConvMax = 8
NumFcMax = 20
White = 1.
Light = 0.7
Medium = 0.5
Dark = 0.3
Darker = 0.15
Black = 0.def add_layer(patches, colors, size=(24, 24), num=5,top_left=[0, 0],loc_diff=[3, -3],):# add a rectangletop_left = np.array(top_left)loc_diff = np.array(loc_diff)loc_start = top_left - np.array([0, size[0]])for ind in range(num):patches.append(Rectangle(loc_start + ind * loc_diff, size[1], size[0]))if ind % 2:colors.append(Medium)else:colors.append(Light)def add_layer_with_omission(patches, colors, size=(24, 24),num=5, num_max=8,num_dots=4,top_left=[0, 0],loc_diff=[3, -3],):# add a rectangletop_left = np.array(top_left)loc_diff = np.array(loc_diff)loc_start = top_left - np.array([0, size[0]])this_num = min(num, num_max)start_omit = (this_num - num_dots) // 2end_omit = this_num - start_omitstart_omit -= 1for ind in range(this_num):if (num > num_max) and (start_omit < ind < end_omit):omit = Trueelse:omit = Falseif omit:patches.append(Circle(loc_start + ind * loc_diff + np.array(size) / 2, 0.5))else:patches.append(Rectangle(loc_start + ind * loc_diff,size[1], size[0]))if omit:colors.append(Black)elif ind % 2:colors.append(Medium)else:colors.append(Light)def add_mapping(patches, colors, start_ratio, end_ratio, patch_size, ind_bgn,top_left_list, loc_diff_list, num_show_list, size_list):start_loc = top_left_list[ind_bgn] \+ (num_show_list[ind_bgn] - 1) * np.array(loc_diff_list[ind_bgn]) \+ np.array([start_ratio[0] * (size_list[ind_bgn][1] - patch_size[1]),- start_ratio[1] * (size_list[ind_bgn][0] - patch_size[0])])end_loc = top_left_list[ind_bgn + 1] \+ (num_show_list[ind_bgn + 1] - 1) * np.array(loc_diff_list[ind_bgn + 1]) \+ np.array([end_ratio[0] * size_list[ind_bgn + 1][1],- end_ratio[1] * size_list[ind_bgn + 1][0]])patches.append(Rectangle(start_loc, patch_size[1], -patch_size[0]))colors.append(Dark)patches.append(Line2D([start_loc[0], end_loc[0]],[start_loc[1], end_loc[1]]))colors.append(Darker)patches.append(Line2D([start_loc[0] + patch_size[1], end_loc[0]],[start_loc[1], end_loc[1]]))colors.append(Darker)patches.append(Line2D([start_loc[0], end_loc[0]],[start_loc[1] - patch_size[0], end_loc[1]]))colors.append(Darker)patches.append(Line2D([start_loc[0] + patch_size[1], end_loc[0]],[start_loc[1] - patch_size[0], end_loc[1]]))colors.append(Darker)def label(xy, text, xy_off=[0, 4]):plt.text(xy[0] + xy_off[0], xy[1] + xy_off[1], text,family='sans-serif', size=8)if __name__ == '__main__':fc_unit_size = 2layer_width = 40flag_omit = Truepatches = []colors = []fig, ax = plt.subplots()############################# conv layerssize_list = [(28, 28),(28, 28), (28, 28), (14, 14), (14, 14),(14, 14), (7, 7)]#从输入到卷积最后的输出的图像尺寸num_list = [1, 32, 32, 32,64, 64,64]#每一层的特征图的数量x_diff_list = [0, layer_width, layer_width, layer_width,layer_width, layer_width, layer_width]#对应上面的list的个数text_list = ['Inputs'] + ['Feature\nmaps'] * (len(size_list) - 1)loc_diff_list = [[3, -3]] * len(size_list)num_show_list = list(map(min, num_list, [NumConvMax] * len(num_list)))top_left_list = np.c_[np.cumsum(x_diff_list), np.zeros(len(x_diff_list))]for ind in range(len(size_list)-1,-1,-1):if flag_omit:add_layer_with_omission(patches, colors, size=size_list[ind],num=num_list[ind],num_max=NumConvMax,num_dots=NumDots,top_left=top_left_list[ind],loc_diff=loc_diff_list[ind])else:add_layer(patches, colors, size=size_list[ind],num=num_show_list[ind],top_left=top_left_list[ind], loc_diff=loc_diff_list[ind])label(top_left_list[ind], text_list[ind] + '\n{}@\n{}x{}'.format(num_list[ind], size_list[ind][0], size_list[ind][1]))############################# in between layersstart_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]#对应list的个数,这里是6end_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]#对应list的个数,这里是6patch_size_list = [(3, 3), (3, 3), (2, 2), (3, 3), (3, 3),(2, 2)]#卷积或池化核的尺寸,对应list的个数,这里是6ind_bgn_list = range(len(patch_size_list))text_list = ['Conv','Conv', 'pool', 'Conv','Conv', 'pool']#结构图的说明,这里是6个for ind in range(len(patch_size_list)):add_mapping(patches, colors, start_ratio_list[ind], end_ratio_list[ind],patch_size_list[ind], ind,top_left_list, loc_diff_list, num_show_list, size_list)label(top_left_list[ind], text_list[ind] + '\n{}x{}'.format(patch_size_list[ind][0], patch_size_list[ind][1]), xy_off=[65, -65]##通过图上比较相对位置来修改坐标)############################# fully connected layerssize_list = [(fc_unit_size, fc_unit_size)] * 3num_list = [3136,256,10 ]num_show_list = list(map(min, num_list, [NumFcMax] * len(num_list)))x_diff_list = [sum(x_diff_list) + layer_width, layer_width, layer_width]top_left_list = np.c_[np.cumsum(x_diff_list), np.zeros(len(x_diff_list))]loc_diff_list = [[fc_unit_size, -fc_unit_size]] * len(top_left_list)text_list = ['Hidden\nunits'] * (len(size_list) - 1) + ['Outputs']for ind in range(len(size_list)):if flag_omit:add_layer_with_omission(patches, colors, size=size_list[ind],num=num_list[ind],num_max=NumFcMax,num_dots=NumDots,top_left=top_left_list[ind],loc_diff=loc_diff_list[ind])else:add_layer(patches, colors, size=size_list[ind],num=num_show_list[ind],top_left=top_left_list[ind],loc_diff=loc_diff_list[ind])label(top_left_list[ind], text_list[ind] + '\n{}'.format(num_list[ind]))text_list = ['Flatten\n', 'Fully\nconnected', 'Fully\nconnected']for ind in range(len(size_list)):label(top_left_list[ind], text_list[ind], xy_off=[30, -65])#通过图上比较相对位置来修改坐标############################for patch, color in zip(patches, colors):patch.set_color(color * np.ones(3))if isinstance(patch, Line2D):ax.add_line(patch)else:patch.set_edgecolor(Black * np.ones(3))ax.add_patch(patch)plt.tight_layout()plt.axis('equal')plt.axis('off')plt.show()fig.set_size_inches(8, 2.5)fig_dir = './'fig_ext = '.png'fig.savefig(os.path.join(fig_dir, 'convnet_fig' + fig_ext),bbox_inches='tight', pad_inches=0)

这里实际上就是对该代码的说明书吧,知道怎么去修改用来绘画自己的CNN。

size_list = [(28, 28),(28, 28), (28, 28), (14, 14), (14, 14),(14, 14), (7, 7)]#从输入到卷积最后的输出的图像尺寸
num_list = [1, 32, 32, 32,64, 64,64]#每一层的特征图的数量
x_diff_list = [0, layer_width, layer_width, layer_width,layer_width, layer_width, layer_width]#对应上面的list的个数
start_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]#对应list的个数,这里是6
end_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]#对应list的个数,这里是6
patch_size_list = [(3, 3), (3, 3), (2, 2), (3, 3), (3, 3),(2, 2)]#卷积或池化核的尺寸,对应list的个数,这里是6
text_list = ['Conv','Conv', 'pool', 'Conv','Conv', 'pool']#结构图的说明,这里是6个
label(top_left_list[ind], text_list[ind], xy_off=[30, -65])#通过图上比较相对位置来修改坐标

运行如下:

用于说明卷积神经网络(ConvNet)的Python脚本相关推荐

  1. keras构建卷积神经网络_在python中使用tensorflow s keras api构建卷积神经网络的初学者指南...

    keras构建卷积神经网络 初学者的深度学习 (DEEP LEARNING FOR BEGINNERS) Welcome to Part 2 of the Neural Network series! ...

  2. python 卷积神经网络 应用_卷积神经网络概述及python实现

    摘要:本文概括地介绍CNN的基本原理 ,并通过阿拉伯字母分类例子具体介绍其实现过程,理论与实践的结合体. 对于卷积神经网络(CNN)而言,相信很多读者并不陌生,该网络近年来在大多数领域都表现优异,尤其 ...

  3. 卷积神经网络算法python实现_卷积神经网络概述及python实现-阿里云开发者社区...

    对于卷积神经网络(CNN)而言,相信很多读者并不陌生,该网络近年来在大多数领域都表现优异,尤其是在计算机视觉领域中.但是很多工作人员可能直接调用相关的深度学习工具箱搭建卷积神经网络模型,并不清楚其中具 ...

  4. 卷积神经网络pytorch_使用PyTorch和卷积神经网络进行动物分类

    卷积神经网络pytorch 介绍 (Introduction) PyTorch is a deep learning framework developed by Facebook's AI Rese ...

  5. 干货 | 如何入手卷积神经网络

    点击上方"视学算法",选择"星标"公众号 重磅干货,第一时间送达 来自 | medium    作者丨Tirmidzi Faizal Aflahi 来源丨机器之 ...

  6. 如何入手卷积神经网络

    选自medium 作者:Tirmidzi Faizal Aflahi 参与:韩放.王淑婷 卷积神经网络可以算是深度神经网络中很流行的网络了.本文从基础入手,介绍了卷积网络的基本原理以及相关的其它技术, ...

  7. 文本分类(下) | 卷积神经网络(CNN)在文本分类上的应用

    正文共3758张图,4张图,预计阅读时间18分钟. 1.简介 原先写过两篇文章,分别介绍了传统机器学习方法在文本分类上的应用以及CNN原理,然后本篇文章结合两篇论文展开,主要讲述下CNN在文本分类上的 ...

  8. 卷积神经网络看见了什么

    NVIDIA DLI 深度学习入门培训 | 特设三场!! 4月28日/5月19日/5月26日一天密集式学习  快速带你入门阅读全文> 正文共1859个字,2张图,预计阅读时间5分钟. 这是众多卷 ...

  9. 透析 | 卷积神经网络CNN究竟是怎样一步一步工作的?

    北京 | 深度学习与人工智能研修 12月23-24日 再设经典课程 重温深度学习阅读全文> 正文共6018个字109张图,预计阅读时间16分钟. 视频地址:https://www.youtube ...

最新文章

  1. 如何判断Java中两个Class对象是否相同
  2. 内卷时代,互联网人相亲有多难?|漫画
  3. asp.net 中chartlet 统计图表的的使用
  4. AngularJs学习笔记--unit-testing
  5. ALS爱立思脚本调用参考
  6. 计算天数java_Java,计算两个日期之间的天数
  7. java实现人脸识别源码【含测试效果图】——Service层(IUserService)
  8. 解决vue里iscroll(better-scroll)点击触发两次和初始化无法滚动问题!
  9. AMD5470显卡Ubuntu下的U盘的使用
  10. 判断sem信号量为零_kernel.sem信号量调优
  11. TypeError: softmax() got an unexpected keyword argument 'axis'
  12. 你能相信这些逼真的油画是前端小姐姐只用HTML+CSS画出来的吗?精细到毛发,让美术设计也惊叹丨GitHub热榜...
  13. Ensemble Learning方法总结
  14. 常用传感器讲解九--雨滴传感器
  15. 蒙特卡罗(Monte Carlo) 模拟
  16. 搜狗拼音输入法居然输入不了半角的人民币符号¥¥¥¥¥¥¥¥!!!
  17. MySQL分库分表后聚合查询_MySQL订单分库分表多维度查询
  18. 大数据(7f)比较Python和Scala面向对象
  19. pytohn 单下划线与双下划线的区别
  20. ODPS上下文参数的使用

热门文章

  1. 从拼产品到拼营销,头条是不是走偏了?
  2. 智禾教育:现在入局电商行业算晚吗,行业前景将会如何发展
  3. python 通达信公式函数_通达信zig函数的python实现
  4. 数学建模-层次分析法
  5. 宽带换了新的账号怎么连接服务器地址,宽带换了路由器设置步骤图解
  6. 多目标跟踪(MOT,Multiple Object Tracking)评价指标
  7. Pocket英语语法---六、感官动词接不同的动词表示什么意思
  8. C# 将打印机临时缓存文件SPL转为图片文件EMF
  9. 你是否在Microsoft Edge上测试你的网站?
  10. PPT内常用的五个插件