BN folding for CNN inference, based on the KWS project
Original author: Lebhoryi@rt-thread.com
Date: 2020/06/18

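  • Project: ML-KWS-for-MCU
  • Reference source code: DSCNN bn folded
  • PyTorch implementation
  • Benefit of BN folding: faster inference
  • The handwritten PDF notes below have been uploaded to CSDN and can be downloaded without points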



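For inference with the modified CNN, the BN layers are folded: the trained weights of each layer are multiplied by a per-channel coefficient and its biases are adjusted accordingly, so the BN layer can be dropped at inference time.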
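Written out per output channel (a standard derivation; the symbols are introduced here rather than taken from the post), for a layer $z = Wx + b$ followed by inference-mode BN with moving mean $\mu$, moving variance $\sigma^2$, scale $\gamma$, offset $\beta$ and stability constant $\epsilon$:

$$
y = \gamma\,\frac{z - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
  = \underbrace{\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\,W}_{W'}\,x
  \;+\; \underbrace{\frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta}_{b'}
$$

For a convolution, $Wx$ is the convolution with the kernel, which is linear in the kernel, so the same per-channel factor applies to every kernel element feeding that output channel. With $\gamma = 1$ and $\epsilon = 10^{-3}$, as the script assumes, this reduces to $W' = W / \sqrt{\sigma^2 + \epsilon}$ and $b' = (b - \mu) / \sqrt{\sigma^2 + \epsilon} + \beta$, which are the weight and bias updates the script writes back to the checkpoint.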

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import os
import numpy as np
import tensorflow as tf

path = os.path.dirname(__file__)
sys.path.append(os.path.join(path, '../'))
import models
import input_data

FLAGS = None


def fold_batch_norm(wanted_words, sample_rate, clip_duration_ms,
                    window_size_ms, window_stride_ms,
                    dct_coefficient_count, model_architecture, model_size_info):
  """Folds batch-norm parameters into the preceding layers' weights/biases.

  Uses the supplied arguments to rebuild the inference graph, loads the
  trained checkpoint, rewrites the weights and biases of every layer that is
  followed by batch norm, and saves the result as a new checkpoint.

  Args:
    wanted_words: Comma-separated list of the words we're trying to recognize.
    sample_rate: How many samples per second are in the input audio files.
    clip_duration_ms: Length of each audio clip in milliseconds.
    window_size_ms: Time slice duration to estimate frequencies from.
    window_stride_ms: How far apart time slices should be.
    dct_coefficient_count: Number of frequency bands to analyze.
    model_architecture: Name of the kind of model to generate.
    model_size_info: Model dimensions, interpreted by models.create_model.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  sess = tf.InteractiveSession()
  words_list = input_data.prepare_words_list(wanted_words.split(','))
  model_settings = models.prepare_model_settings(
      len(words_list), sample_rate, clip_duration_ms, window_size_ms,
      window_stride_ms, dct_coefficient_count)

  fingerprint_input = tf.placeholder(
      tf.float32, [None, model_settings['fingerprint_size']],
      name='fingerprint_input')
  logits = models.create_model(
      fingerprint_input,
      model_settings,
      model_architecture,
      model_size_info,
      is_training=False)

  ground_truth_input = tf.placeholder(
      tf.float32, [None, model_settings['label_count']],
      name='groundtruth_input')
  predicted_indices = tf.argmax(logits, 1)
  expected_indices = tf.argmax(ground_truth_input, 1)
  correct_prediction = tf.equal(predicted_indices, expected_indices)
  confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices)
  evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  models.load_variables_from_checkpoint(sess, FLAGS.checkpoint)
  saver = tf.train.Saver(tf.global_variables())

  tf.logging.info('Folding batch normalization layer parameters '
                  'to preceding layer weights/biases')
  # epsilon added to variance to avoid division by zero
  epsilon = 1e-3  # default epsilon for tf.slim.batch_norm
  global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

  def find_var(name):
    return [v for v in global_vars if v.name == name][0]

  # Unnamed weight/bias variables of the three layers followed by BN in cnn2:
  # weights are Variable:0, Variable_2:0, Variable_4:0 and
  # biases are Variable_1:0, Variable_3:0, Variable_5:0.
  weight_list = ['Variable:0' if i == 0 else 'Variable_' + str(i * 2) + ':0'
                 for i in range(3)]
  biase_list = ['Variable_' + str(2 * i + 1) + ':0' for i in range(3)]

  # Get the batch_norm means; the i-th moving_mean belongs to the i-th layer.
  mean_variables = [v for v in global_vars if 'moving_mean' in v.name]
  for i, mean_var in enumerate(mean_variables):
    mean_name = mean_var.name
    mean_values = sess.run(mean_var)
    variance_name = mean_name.replace('moving_mean', 'moving_variance')
    variance_values = sess.run(find_var(variance_name))
    beta_name = mean_name.replace('moving_mean', 'beta')
    beta_values = sess.run(find_var(beta_name))
    bias_name = biase_list[i]
    bias_var = find_var(bias_name)
    bias_values = sess.run(bias_var)
    wt_name = weight_list[i]
    wt_var = find_var(wt_name)
    wt_values = sess.run(wt_var)

    # Each output channel (last weight axis) has its own BN statistics.
    # gamma (scale factor) is 1.0, so the per-channel factor is 1/sqrt(var+eps).
    scale = 1.0 / np.sqrt(variance_values + epsilon)

    # Update weights: broadcasting scales every weight feeding output channel l
    # by scale[l]; this covers both the 4-D conv kernels and the 2-D FC matrix.
    tf.logging.info('Updating ' + wt_name)
    wt_values *= scale
    wt_values = sess.run(tf.assign(wt_var, wt_values))

    # Update biases: b' = (b - mean) / sqrt(var + eps) + beta
    tf.logging.info('Updating ' + bias_name)
    bias_values = (bias_values - mean_values) * scale + beta_values
    bias_values = sess.run(tf.assign(bias_var, bias_values))

  # Write fused weights to ckpt file
  tf.logging.info('Saving new checkpoint at ' + FLAGS.checkpoint + '_bnfused')
  saver.save(sess, FLAGS.checkpoint + '_bnfused')


def main(_):
  # Create the model, load its weights and fold the BN layers.
  fold_batch_norm(FLAGS.wanted_words, FLAGS.sample_rate,
                  FLAGS.clip_duration_ms, FLAGS.window_size_ms,
                  FLAGS.window_stride_ms, FLAGS.dct_coefficient_count,
                  FLAGS.model_architecture, FLAGS.model_size_info)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_url',
      type=str,
      # pylint: disable=line-too-long
      default='',
      # default='http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz',
      # pylint: enable=line-too-long
      help='Location of speech training data archive on the web.')
  parser.add_argument(
      '--data_dir',
      type=str,
      # default='/tmp/speech_dataset/',
      default='../../data',
      help='Where to download the speech training data to.')
  parser.add_argument('--silence_percentage', type=float, default=10.0,
                      help='How much of the training data should be silence.')
  parser.add_argument('--unknown_percentage', type=float, default=10.0,
                      help='How much of the training data should be unknown words.')
  parser.add_argument('--testing_percentage', type=int, default=10,
                      help='What percentage of wavs to use as a test set.')
  parser.add_argument('--validation_percentage', type=int, default=10,
                      help='What percentage of wavs to use as a validation set.')
  parser.add_argument('--sample_rate', type=int, default=16000,
                      help='Expected sample rate of the wavs')
  parser.add_argument('--clip_duration_ms', type=int, default=1000,
                      help='Expected duration in milliseconds of the wavs')
  parser.add_argument('--window_size_ms', type=float, default=40.0,
                      help='How long each spectrogram timeslice is')
  parser.add_argument('--window_stride_ms', type=float, default=40.0,
                      help='How far apart consecutive spectrogram timeslices are')
  parser.add_argument('--dct_coefficient_count', type=int, default=10,
                      help='How many bins to use for the MFCC fingerprint')
  parser.add_argument('--batch_size', type=int, default=100,
                      help='How many items to train with at once')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go,nihaoxr,xrxr',
      help='Words to use (others will be added to an unknown label)')
  parser.add_argument(
      '--checkpoint',
      type=str,
      default='../train_model/526_cnn/best/cnn_8884.ckpt-13200',
      help='Checkpoint to load the weights from.')
  parser.add_argument('--model_architecture', type=str, default='cnn2',
                      help='What model architecture to use')
  parser.add_argument('--model_size_info', type=int, nargs='+',
                      default=[128, 128, 128],
                      help='Model dimensions - different for various models')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
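To sanity-check the folding arithmetic without a checkpoint, the NumPy sketch below (it needs NumPy 1.20 or later for `sliding_window_view`) compares a layer followed by inference-mode BN against the folded layer, for both a dense layer and a naive stride-1 convolution. It mirrors the script's assumptions (gamma = 1, epsilon = 1e-3, output channels on the last weight axis); `fold_bn`, `conv2d_valid`, `check` and the other helpers are hypothetical names introduced here, and the random statistics stand in for real moving averages.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

EPS = 1e-3  # same epsilon the script uses (tf.slim.batch_norm default)


def fold_bn(weights, bias, mean, variance, beta, eps=EPS):
    """Fold inference-mode BN (gamma = 1) into the preceding layer.

    Output channels sit on the last axis of `weights` (HWIO conv kernels or
    [in, out] dense matrices), so broadcasting scales every weight feeding
    output channel l by 1 / sqrt(variance[l] + eps).
    """
    scale = 1.0 / np.sqrt(variance + eps)
    return weights * scale, (bias - mean) * scale + beta


def batch_norm_inference(z, mean, variance, beta, eps=EPS):
    """BN with frozen statistics and gamma = 1, per last-axis channel."""
    return (z - mean) / np.sqrt(variance + eps) + beta


def dense(x, w, b):
    return x @ w + b


def conv2d_valid(x, k, b):
    """Naive 'VALID', stride-1 conv: x is (H, W, Cin), k is (kh, kw, Cin, Cout)."""
    kh, kw = k.shape[:2]
    patches = sliding_window_view(x, (kh, kw), axis=(0, 1))  # (H', W', Cin, kh, kw)
    return np.einsum('ijcab,abco->ijo', patches, k) + b


def check(layer, x, w, b, rng):
    """True if layer + BN and the folded layer give the same output."""
    cout = w.shape[-1]
    mean = rng.standard_normal(cout)
    variance = rng.uniform(0.5, 2.0, cout)
    beta = rng.standard_normal(cout)
    reference = batch_norm_inference(layer(x, w, b), mean, variance, beta)
    fw, fb = fold_bn(w, b, mean, variance, beta)
    return np.allclose(layer(x, fw, fb), reference)


rng = np.random.default_rng(0)
dense_ok = check(dense, rng.standard_normal((4, 16)),
                 rng.standard_normal((16, 8)), rng.standard_normal(8), rng)
conv_ok = check(conv2d_valid, rng.standard_normal((10, 6, 1)),
                rng.standard_normal((3, 3, 1, 5)), rng.standard_normal(5), rng)
print(dense_ok, conv_ok)  # True True
```

Because the fold is exact algebra rather than an approximation, any mismatch in a check like this would point to a pairing error (for example a moving_mean matched to the wrong layer's weights) rather than to numerical noise.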

Run:

python3 ./utils/fold_batchnorm_cnn.py \
--data_dir ../data \
--dct_coefficient_count 10 \
--window_size_ms 32 \
--window_stride_ms 20 \
--checkpoint ./train_model/615_cnn_with_32_frame/best/cnn2_9127.ckpt-14000 \
--model_architecture cnn2 \
--model_size_info 28 10 4 1 1 30 10 4 2 1 16 128
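The `--window_size_ms`, `--window_stride_ms` and `--dct_coefficient_count` flags should match the training run, because they determine `model_settings['fingerprint_size']`, the width of the `fingerprint_input` placeholder. Assuming `models.prepare_model_settings` follows the TensorFlow speech_commands example that ML-KWS-for-MCU builds on, the arithmetic can be re-stated as below; `fingerprint_settings` is a hypothetical helper, not part of the repo.

```python
def fingerprint_settings(sample_rate, clip_duration_ms, window_size_ms,
                         window_stride_ms, dct_coefficient_count):
    """Hypothetical re-statement of the fingerprint-size arithmetic,
    assumed to match models.prepare_model_settings."""
    desired_samples = int(sample_rate * clip_duration_ms / 1000)
    window_size_samples = int(sample_rate * window_size_ms / 1000)
    window_stride_samples = int(sample_rate * window_stride_ms / 1000)
    length_minus_window = desired_samples - window_size_samples
    if length_minus_window < 0:
        spectrogram_length = 0
    else:
        spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
    return {
        'spectrogram_length': spectrogram_length,
        'fingerprint_size': dct_coefficient_count * spectrogram_length,
    }


# Flags from the command above (sample_rate and clip_duration_ms keep their defaults).
cmd = fingerprint_settings(16000, 1000, 32, 20, 10)
# Script defaults (window_size_ms = window_stride_ms = 40).
default = fingerprint_settings(16000, 1000, 40.0, 40.0, 10)
print(cmd)      # {'spectrogram_length': 49, 'fingerprint_size': 490}
print(default)  # {'spectrogram_length': 25, 'fingerprint_size': 250}
```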
