Tensorflow2.0 + Transformers 实现Bert FGM对抗训练惩罚梯度损失函数

  • 前言
  • 变种实现
  • Transformers中的word_embeddings
  • 代码修改
  • 实验效果
  • 总结

前言

之前看了很多关于NLP中应用对抗训练的文章,测试结果都很香,所以想在自己在用的模型上试一试看看能不能提升效果,参考了一些代码找到了pytroch和keras实现,但发现对于tensorflows来说更改训练过程非常繁琐,而且容易出错,如果要配合transformers实现则更难
——transformers的bert类会被打包成一个layers,在调用tf.layers时拿不到里面的embedding_layers,如果有大神知道怎么拿还望留言告知~


变种实现

但最近看了一篇文章对抗训练浅谈:意义、方法和思考(附Keras实现)
文中讲到了FGM的效果等价转化式:

***这里有个小技巧使用了一阶泰勒展开:

并且文章也给出了keras的实现代码:

def sparse_categorical_crossentropy(y_true, y_pred):"""自定义稀疏交叉熵这主要是因为keras自带的sparse_categorical_crossentropy不支持求二阶梯度。"""y_true = K.reshape(y_true, K.shape(y_pred)[:-1])y_true = K.cast(y_true, 'int32')y_true = K.one_hot(y_true, K.shape(y_pred)[-1])return K.categorical_crossentropy(y_true, y_pred)def loss_with_gradient_penalty(y_true, y_pred, epsilon=1):"""带梯度惩罚的loss"""loss = K.mean(sparse_categorical_crossentropy(y_true, y_pred))embeddings = search_layer(y_pred, 'Embedding-Token').embeddingsgp = K.sum(K.gradients(loss, [embeddings])[0].values**2)return loss + 0.5 * epsilon * gpmodel.compile(loss=loss_with_gradient_penalty,optimizer=Adam(2e-5),metrics=['sparse_categorical_accuracy'],
)

但可惜的是在loss_with_gradient_penalty部分仍然需要调用kerasbert的方法search_layer,但transformers中的get_input_embeddings在调用时一直报错(可能我用法不对)


Transformers中的word_embeddings

但是发现tf除了通过tf.layers追踪到相应的参数,还能通过model.variables这个方法,
调用 model.variables (model为你搭建好的模型名称)会返回一个list,里面是按顺序排列好的tensor参数,在这里通过 model.variables[0] 即可找到模型的第一层参数:word_embeddings


可以看到前三组参数分别为word_embeddings, position_embeddings, token_type_embeddings, 这里我们只要取word_embeddings(NLP中对抗训练的扰动对象)即可。


代码修改

因此我们可以简单的修改原来的代码:

def sparse_categorical_crossentropy(y_true, y_pred):y_true = tf.reshape(y_true, tf.shape(y_pred)[:-1])y_true = tf.cast(y_true, tf.int32)y_true = tf.one_hot(y_true, K.shape(y_pred)[-1])return tf.keras.losses.categorical_crossentropy(y_true, y_pred)def loss_with_gradient_penalty(model,epsilon=1):def loss_with_gradient_penalty_2(y_true, y_pred):loss = tf.math.reduce_mean(sparse_categorical_crossentropy(y_true, y_pred))embeddings = model.variables[0]gp = tf.math.reduce_sum(tf.gradients(loss, [embeddings])[0].values**2)return loss + 0.5 * epsilon * gpreturn loss_with_gradient_penalty_2

调用方法:

bert_ner_model.compile(optimizer=optimizer, loss=[loss_with_gradient_penalty(bert_ner_model,1.0)],
metrics=['sparse_categorical_accuracy'])

实验效果

原sparse_cross_entropy结果:

加入惩罚项(epsilon = 1)结果:

加入惩罚项(epsilon = 0.5)结果:


总结

可以看到使用该惩罚梯度损失函数,以为要计算两次梯度,训练时间增加了2倍之多,但模型效果有了一个点左右的提升,而且不容易过拟合。所以看出,尽管无法复原FGM的方法,该用效果差不多的惩罚梯度损失函数,还是可以获得一定的提升(前提是epsilon这个超参要调好)


参考链接:https://spaces.ac.cn/archives/7234

Tensorflow2.0 + Transformers 实现Bert FGM对抗训练惩罚梯度损失函数相关推荐

  1. 再战FGM!Tensorflow2.0 自定义模型训练实现NLP中的FGM对抗训练 代码实现

    TF版本2.2及以上 def creat_FGM(epsilon=1.0):@tf.function def train_step(self, data):'''计算在embedding上的gradi ...

  2. 对抗训练的理解,以及FGM、PGD和FreeLB的详细介绍

    对抗训练基本思想--Min-Max公式 如图所示. 中括号里的含义为,我们要找到一组在样本空间内.使Loss最大的的对抗样本(该对抗样本由原样本x和经过某种手段得到的扰动项r_adv共同组合得到).这 ...

  3. 【综述】NLP 对抗训练(FGM、PGD、FreeAT、YOPO、FreeLB、SMART)

    在对抗训练中关键的是需要找到对抗样本,通常是对原始的输入添加一定的扰动来构造,然后放给模型训练,这样模型就有了识别对抗样本的能力.其中的关键技术在于如果构造扰动,使得模型在不同的攻击样本中均能够具备较 ...

  4. 对抗训练:FGM、FGSM、PGD

    当前,在各大NLP竞赛中,对抗训练已然成为上分神器,尤其是fgm和pgd使用较多,下面来说说吧.对抗训练是一种引入噪声的训练方式,可以对参数进行正则化,提升模型鲁棒性和泛化能力. 一.什么是对抗训练? ...

  5. 对抗训练浅谈:意义、方法和思考(附Keras实现)

    ©PaperWeekly 原创 · 作者|苏剑林 单位|追一科技 研究方向|NLP.神经网络 当前,说到深度学习中的对抗,一般会有两个含义:一个是生成对抗网络(Generative Adversari ...

  6. 训练技巧 | 功守道:NLP中的对抗训练 + PyTorch实现

    本文分享一个"万物皆可盘"的 NLP 对抗训练实现,只需要四行代码即可调用.盘他. 作者丨Nicolas 单位丨追一科技AI Lab研究员 研究方向丨信息抽取.机器阅读理解 最近, ...

  7. NLP中的对抗训练(附PyTorch实现)

    对抗样本的基本概念 要认识对抗训练,首先要了解"对抗样本",它首先出现在论文Intriguing properties of neural networks之中.简单来说,它是指对 ...

  8. 浅谈NLP中的对抗训练方式

    ©作者 | 林远平 单位 | QTrade AI研发中心 研究方向 | 自然语言处理 前言 什么是对抗训练呢?说起"对抗",我们就想起了计算机视觉领域的对抗生成网络(GAN).在计 ...

  9. pytorch 对抗样本_【炼丹技巧】功守道:NLP中的对抗训练 + PyTorch实现

    本文分享一个"万物皆可盘"的NLP对抗训练实现,只需要四行代码即可调用.盘他. 最近,微软的FreeLB-Roberta [1] 靠着对抗训练 (Adversarial Train ...

最新文章

  1. mybatis mysql crud_Mybatis的CRUD操作
  2. eigrp 重分布默认路由
  3. datalist,Repeater和Gridview的区别分析
  4. 移动端banner css3(@keyframes )实现
  5. 微信小程序 与后台服务器交互,微信小程序 与后台交互----传递和回传时间
  6. php获取当前设备,Linux_在Linux系统中使用lsblk和blkid显示设备信息的方法,今天我们将会向你展示如何使 - phpStudy...
  7. html-----020----事件
  8. 12bit灰度图像映射到8bit显示及python 实现
  9. uwsgi和nginx的故事
  10. 【翻译】针对多种设备定制Ext JS 5应用程序
  11. 大熊君大话NodeJS之------Net模块
  12. wps-奇数偶数页眉不同设定方法
  13. 游戏因为音效而变得触动人心
  14. QFile读取移动硬盘文件卡死问题
  15. 抽样中误差的相关概念和种类
  16. 墙面有几种装修方法_装修时墙面处理都有哪几种方式?
  17. CDA1级习题复习(3)
  18. 带你一文通透CAN总线相关知识
  19. easyui Grid 的列合计
  20. vector初始化必须设置大小么_如何将西门子S120变频器设置成为可调电压源

热门文章

  1. 聊聊架构设计做些什么来谈如何成为架构师
  2. Normalization,Regularization 和 standardization
  3. POJ 1449 amp; ZOJ 1036 Enigma(简单枚举)
  4. 部分网站公开数据的汇总(2)
  5. 获取apk的package name 和 Activity
  6. win32/mfc/qt 异常处理与总结
  7. ASP.NET AJAX文档-ASP.NET AJAX 概述[翻译](1)
  8. c语言6大设计原则 控制反转,fun6868备用网址-fun6868备用网址
  9. java 广播地址,根据ip地址跟子网掩码获取广播地址的java实现
  10. conda pip安装在哪里_TensorFlow 2.0 安装指南