由于bert模型参数很大,在用到生产环境中推理效率难以满足要求,因此经常需要将模型进行压缩。常用的模型压缩的方法有剪枝、蒸馏和量化等方法。比较容易实现的方法为知识蒸馏,下面便介绍如何将bert模型进行蒸馏。

一、知识蒸馏原理

模型蒸馏的目的是用一个小模型去学习大模型的知识,让小模型的效果接近大模型的效果,小模型被称为student,大模型被称为teacher。

知识蒸馏的实现可以根据teacher和student的网络结构的不同设计不同的蒸馏步骤,基本结构如下所示:

损失函数需要计算两个部分,cross entropy loss和mse loss,计算的时候需要注意有soft target和hard target。有两个参数需要定义,通过这两个参数对student和teacher进行拟合。其中一个是温度(T),对logits进行缩放。另一个是权重\alpha,用来计算加权损失。hard target就是原始的标注标签。soft target计算公式如下:

加权损失计算如下:

二、将simBert模型蒸馏到simase孪生网络上

蒸馏的步骤示意图可以参考下图:

核心代码如下:

class Distill_model(tf.keras.Model):'''使用dssm进行知识蒸馏'''def __init__(self,config,teacher_network,vocab_size,word_vectors,**kwargs):self.config = configself.vocab_size = vocab_sizeself.word_vectors = word_vectors#冻结teacher network的参数for layer in teacher_network.layers:layer.trainable = False#定义学生模型输入query = tf.keras.layers.Input(shape=(None,), dtype=tf.int64, name='input_x_ids')sim_query = tf.keras.layers.Input(shape=(None,), dtype=tf.int64, name='input_y_ids')#定义老师模型输入word_ids_a = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_word_ids_a')mask_a = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_mask_a')type_ids_a = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_type_ids_a')word_ids_b = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_word_ids_b')mask_b = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_mask_b')type_ids_b = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_type_ids_b')input_a = [word_ids_a, mask_a, type_ids_a]input_b = [word_ids_b, mask_b, type_ids_b]teacher_input = [input_a, input_b]#teacher_softlabelteacher_output = teacher_network(teacher_input)teacher_soft_label = softmax_t(self.config['t'], teacher_output['logits'])# embedding层# 利用词嵌入矩阵将输入数据转成词向量,shape=[batch_size, seq_len, embedding_size]class GatherLayer(tf.keras.layers.Layer):def __init__(self, config, vocab_size, word_vectors):super(GatherLayer, self).__init__()self.config = configself.vocab_size = vocab_sizeself.word_vectors = word_vectorsdef build(self, input_shape):with tf.name_scope('embedding'):if not self.config['use_word2vec']:self.embedding_w = tf.Variable(tf.keras.initializers.glorot_normal()(shape=[self.vocab_size, self.config['embedding_size']],dtype=tf.float32), trainable=True, name='embedding_w')else:self.embedding_w = tf.Variable(tf.cast(self.word_vectors, tf.float32), trainable=True,name='embedding_w')self.build = Truedef call(self, inputs, **kwargs):return tf.gather(self.embedding_w, inputs, name='embedded_words')def get_config(self):config = super(GatherLayer, self).get_config()return configshared_net = tf.keras.Sequential([GatherLayer(config, vocab_size, word_vectors),shared_lstm_layer(config)])query_embedding_output = shared_net.predict_step(query)sim_query_embedding_output = shared_net.predict_step(sim_query)# 余弦函数计算相似度# cos_similarity余弦相似度[batch_size, similarity]query_norm = tf.sqrt(tf.reduce_sum(tf.square(query_embedding_output), axis=-1), name='query_norm')sim_query_norm = tf.sqrt(tf.reduce_sum(tf.square(sim_query_embedding_output), axis=-1), name='sim_query_norm')dot = tf.reduce_sum(tf.multiply(query_embedding_output, sim_query_embedding_output), axis=-1)cos_similarity = tf.divide(dot, (query_norm * sim_query_norm), name='cos_similarity')self.similarity = cos_similarity# 预测为正例的概率cond = (self.similarity > self.config["neg_threshold"])pos = tf.where(cond, tf.square(self.similarity), 1 - tf.square(self.similarity))neg = tf.where(cond, 1 - tf.square(self.similarity), tf.square(self.similarity))predictions = [[neg[i], pos[i]] for i in range(self.config['batch_size'])]self.logits = self.similaritystudent_soft_label = softmax_t(self.config['t'], self.logits)student_hard_label = self.logitsif self.config['is_training']:#训练时候蒸馏outputs = dict(student_soft_label=student_soft_label, student_hard_label=student_hard_label, teacher_soft_label=teacher_soft_label, predictions=predictions)super(Distill_model, self).__init__(inputs=[query, sim_query, teacher_input], outputs=outputs, **kwargs)else:#预测时候只加载学生模型outputs = dict(predictions=predictions)super(Distill_model, self).__init__(inputs=[query, sim_query], outputs=outputs, **kwargs)

其中比较重要的步骤就是先冻结teacher模型的参数使其不参与训练:

#冻结teacher network的参数
for layer in teacher_network.layers:layer.trainable = False

然后在预测阶段只加载student模型:

#预测时候只加载学生模型
outputs = dict(predictions=predictions)
super(Distill_model, self).__init__(inputs=[query, sim_query], outputs=outputs, **kwargs)

然后是loss的计算:

# mse损失计算y = tf.reshape(labels, (-1,))student_soft_label = model_outputs['student_soft_label']teacher_soft_label = model_outputs['teacher_soft_label']mse_loss = tf.keras.losses.mean_squared_error(teacher_soft_label, student_soft_label)#ce损失计算similarity = model_outputs['student_hard_label']cond = (similarity < self.config["neg_threshold"])zeros = tf.zeros_like(similarity, dtype=tf.float32)ones = tf.ones_like(similarity, dtype=tf.float32)squre_similarity = tf.square(similarity)neg_similarity = tf.where(cond, squre_similarity, zeros)pos_loss = y * (tf.square(ones - similarity) / 4)neg_loss = (ones - y) * neg_similarityce_loss = pos_loss+neg_losslosses = self.config['alpha']*mse_loss + (1-self.config['alpha'])*ce_lossloss = tf.reduce_mean(losses)

三、总结

知识蒸馏作为一个模型压缩的方法,优点还是很多的,实现起来方便,也可以在样本数量少的情况下使用。

参考文章:

模型蒸馏原理和bert模型蒸馏以及theseus压缩实战_colourmind的博客-CSDN博客_模型蒸馏

bert模型蒸馏实战相关推荐

  1. BERT模型蒸馏有哪些方法?

    ©PaperWeekly 原创 · 作者|蔡杰 学校|北京大学硕士生 研究方向|问答系统 我们都知道预训练模型的标准范式: pretrain-利用大量的未标记数据通过一些自监督的学习方式学习丰富的语义 ...

  2. BERT知识蒸馏TinyBERT

    1. 概述 诸如BERT等预训练模型的提出显著的提升了自然语言处理任务的效果,但是随着模型的越来越复杂,同样带来了很多的问题,如参数过多,模型过大,推理事件过长,计算资源需求大等.近年来,通过模型压缩 ...

  3. 手写数字识别中多元分类原理_广告行业中那些趣事系列:从理论到实战BERT知识蒸馏...

    导读:本文将介绍在广告行业中自然语言处理和推荐系统实践.本文主要分享从理论到实战知识蒸馏,对知识蒸馏感兴趣的小伙伴可以一起沟通交流. 摘要:本篇主要分享从理论到实战知识蒸馏.首先讲了下为什么要学习知识 ...

  4. 【周末送新书】基于BERT模型的自然语言处理实战

    如果你是一名自然语言处理从业者,那你一定听说过大名鼎鼎的 BERT 模型. BERT(Bidirectional Encoder Representations From Transformers)模 ...

  5. 详解谷歌最强NLP模型BERT(理论+实战)

    作者:李理,环信人工智能研发中心vp,十多年自然语言处理和人工智能研发经验.主持研发过多款智能硬件的问答和对话系统,负责环信中文语义分析开放平台和环信智能机器人的设计与研发. 本文是作者正在编写的&l ...

  6. 从源码到实战:BERT模型训练营

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 开课吧教育 方向:NLP 之 BERT实战 都说BERT模型开启了NLP的新时代,更有" ...

  7. 使用DistilBERT 蒸馏类 BERT 模型的代码实现

    来源:DeepHub IMBA 本文约2700字,建议阅读9分钟 本文带你进入Distil细节,并给出完整的代码实现.本文为你详细介绍DistilBERT,并给出完整的代码实现. 机器学习模型已经变得 ...

  8. 独家 | 用spaCy蒸馏BERT模型

    作者:YVES PEIRSMAN 翻译:詹荣辉 校对:闫晓雨 本文约2800字,建议阅读7分钟. 本文为大家介绍了用spaCy对BERT进行模型蒸馏,其性能也能接近BERT. Photo on Blo ...

  9. 独家 | 基于知识蒸馏的BERT模型压缩

    作者:孙思琦.成宇.甘哲.刘晶晶 本文约1800字,建议阅读5分钟. 本文为你介绍"耐心的知识蒸馏"模型. 数据派THU后台回复"191010",获取论文地址. ...

最新文章

  1. suse linux显示乱码,open suse11.4中文乱码问题
  2. python监控数据库_【Python】NavicatPre查询日志监控并转存数据库
  3. Android之使用HTTP协议的Get/Post方式向服务器提交数据
  4. 使用MSBuild实现完整daily build流程 .
  5. 汉字编码表(五笔编码表)
  6. vue3.0项目服务器部署
  7. html语言标示,HTML语言剖析(二) HTML标记一览
  8. CROSS APPLY 和OUTER APPLY 的区别
  9. [渝粤教育] 西南科技大学 大学物理 在线考试复习资料
  10. EDA软件_Protel99se导出坐标教程
  11. 6183. 字符串的前缀分数和(每日一难phase2--day18)
  12. 思岚S2激光雷达1—初次连接
  13. 计算机应用安装不了软件总被隔离,电脑安装软件时显示此程序被组策略阻止的解决方法...
  14. mysql基于PHP的校园竞赛信息网站 毕业设计源码221230
  15. 联想拯救者笔记本电脑亮度无法调节解决办法
  16. NLP初学-简易聊天机器人
  17. 电商运营到底做什么?说出来你也不信。
  18. 「PAT乙级真题解析」Basic Level 1053 住房空置率 (问题分析+完整步骤+伪代码描述+提交通过代码)
  19. 程序启动,遇到Process finished with exit code 1 解决方法
  20. CSS 文本超出溢出显示省略号...

热门文章

  1. 密码找回安全总结-业务安全测试实操(29)
  2. Springboot项目配置oracle数据库
  3. mysql误删数据后 快速恢复的办法
  4. 外企面试的常见英语表达6
  5. MySQL与Java数据类型对应关系
  6. python 中文转拼音原理_【Python】 汉字转化汉语拼音pinyin
  7. 广州外援斯贝茨被CBA公司停赛4场 罚款10万元
  8. java如何让数字占四位,int占4字节,一数占一字符,为什么int能表示5位以上的数字?...
  9. [转载] STM32的Vcap的问题及解决
  10. webpack中对js进行转译