作者 | 王嘉宁@华师数据学院

整理 | NewBeeNLP

https://blog.csdn.net/qq_36426650/article/details/122807916

大家好,这里是NewBeeNLP。

对抗训练本质是为了提高模型的鲁棒性,一般情况下在传统训练的基础上,添加了对抗训练是可以进一步提升效果的,在比赛打榜、调参时是非常重要的一个trick。对抗训练在CV领域内非常常用,那么在NLP领域如何使用呢?本文简单总结几种常用的对抗训练方法。

对抗训练旨在对原始输入样本 上施加扰动 ,得到对抗样本后用其进行训练:

公式理解:

  • 最大化扰动:挑选一个能使得模型产生更大损失(梯度较大)的扰动量,作为攻击;

  • 最小化损失:根据最大的扰动量,添加到输入样本后,朝着最小化含有扰动的损失(梯度下降)方向更新参数;

这个被构造出来的“对抗样本”并不能具体对应到某个单词,因此,反过来在推理阶段是没有办法通过修改原始输入得到这样的对抗样本。

对抗训练有两个作用,一是 提高模型对恶意攻击的鲁棒性 ,二是 提高模型的泛化能力

在CV任务,根据经验性的结论,对抗训练往往会使得模型在非对抗样本上的表现变差,然而神奇的是,在NLP任务中,模型的泛化能力反而变强了。

常用的几种对抗训练方法有FGSM、FGM、PGD、FreeAT、YOPO、FreeLB、SMART。本文暂时只介绍博主常用的3个方法,分别是 FGMPGDFreeLB 。具体实现时,不同的对抗方法会有差异,但是 从训练速度和代码编辑难易程度的角度考虑,推荐使用FGM和迭代次数较少的PGD

一、FGM算法

  • 首先计算输入样本 (通常为word embedding)的损失函数以及在 处的梯度:;

  • 计算在输入样本的扰动量:,其中 为超参数,默认取1.0;

  • 得到对抗样本:;

  • 根据得到的对抗样本,再次喂入模型中,计算损失,并累积梯度;

  • 恢复原始的word embedding,接着下一个batch。

FGM的代码量很少,只需要自行实现简单的类即可:

import torch
class FGM():def __init__(self, model):self.model = modelself.backup = {} # 用于保存模型扰动前的参数def attack(self, epsilon=1., emb_name='word_embeddings' # emb_name表示模型中embedding的参数名):'''生成扰动和对抗样本'''for name, param in self.model.named_parameters(): # 遍历模型的所有参数 if param.requires_grad and emb_name in name: # 只取word embedding层的参数self.backup[name] = param.data.clone() # 保存参数值norm = torch.norm(param.grad) # 对参数梯度进行二范式归一化if norm != 0 and not torch.isnan(norm): # 计算扰动,并在输入参数值上添加扰动r_at = epsilon * param.grad / normparam.data.add_(r_at)def restore(self, emb_name='word_embeddings' # emb_name表示模型中embedding的参数名):'''恢复添加扰动的参数'''for name, param in self.model.named_parameters(): # 遍历模型的所有参数if param.requires_grad and emb_name in name:  # 只取word embedding层的参数assert name in self.backupparam.data = self.backup[name] # 重新加载保存的参数值self.backup = {}

在训练时,只需要额外添加5行代码:

fgm = FGM(model) # (#1)初始化
for batch_input, batch_label in data:loss = model(batch_input, batch_label) # 正常训练loss.backward() # 反向传播,得到正常的grad# 对抗训练fgm.attack() # (#2)在embedding上添加对抗扰动loss_adv = model(batch_input, batch_label) # (#3)计算含有扰动的对抗样本的lossloss_adv.backward() # (#4)反向传播,并在正常的grad基础上,累加对抗训练的梯度fgm.restore() # (#5)恢复embedding参数# 梯度下降,更新参数optimizer.step()model.zero_grad()

二、PGD算法

Project Gradient Descent(PGD)是一种迭代攻击算法,相比于普通的FGM 仅做一次迭代,PGD是做多次迭代,每次走一小步,每次迭代都会将扰动投射到规定范围内。形式化描述为:

其中 为扰动约束空间(一个半径为 的球体),原始的输入样本对应的初识点为球心,避免扰动超过球面。迭代多次后,保证扰动在一定范围内,如下图所示:

代码实现如下所示:

import torch
class PGD():def __init__(self, model):self.model = modelself.emb_backup = {}self.grad_backup = {}def attack(self, epsilon=1., alpha=0.3, emb_name='word_embeddings', is_first_attack=False):for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:if is_first_attack:self.emb_backup[name] = param.data.clone()norm = torch.norm(param.grad)if norm != 0 and not torch.isnan(norm):r_at = alpha * param.grad / normparam.data.add_(r_at)param.data = self.project(name, param.data, epsilon)def restore(self, emb_name='word_embeddings'):for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name: assert name in self.emb_backupparam.data = self.emb_backup[name]self.emb_backup = {}def project(self, param_name, param_data, epsilon):r = param_data - self.emb_backup[param_name]if torch.norm(r) > epsilon:r = epsilon * r / torch.norm(r)return self.emb_backup[param_name] + rdef backup_grad(self):for name, param in self.model.named_parameters():if param.requires_grad:self.grad_backup[name] = param.grad.clone()def restore_grad(self):for name, param in self.model.named_parameters():if param.requires_grad:param.grad = self.grad_backup[name]
pgd = PGD(model)
K = 3
for batch_input, batch_label in data:# 正常训练loss = model(batch_input, batch_label)loss.backward() # 反向传播,得到正常的gradpgd.backup_grad()# 累积多次对抗训练——每次生成对抗样本后,进行一次对抗训练,并不断累积梯度for t in range(K):pgd.attack(is_first_attack=(t==0)) # 在embedding上添加对抗扰动, first attack时备份param.dataif t != K-1:model.zero_grad()else:pgd.restore_grad()loss_adv = model(batch_input, batch_label)loss_adv.backward() # 反向传播,并在正常的grad基础上,累加对抗训练的梯度pgd.restore() # 恢复embedding参数# 梯度下降,更新参数optimizer.step()model.zero_grad()

三、FreeLB算法

FreeLB针对PGD的多次迭代训练的问题进行了改进:

  • PGD是迭代 次扰动后取最后一次扰动的梯度更新参数,FreeLB是取 次迭代中的平均梯度(将 次迭代转换为类似一个虚拟的batch)。

  • 对抗训练和dropout不能同时使用;

具体的算法流程为:

很明显找到FreeLB与PGD的区别在于累积的方式:

  • FreeLB:通过对 K K K 次梯度的平均累积作为扰动更新

  • PGD:只取最后一次的梯度进行更新

实现流程如下图所示:

其他对抗训练方法,以及更为详细的理论讲解,可参考文末参考文献。

一起交流

想和你一起学习进步!『NewBeeNLP』目前已经建立了多个不同方向交流群(机器学习 / 深度学习 / 自然语言处理 / 搜索推荐 / 图网络 / 面试交流 / 等),名额有限,赶紧添加下方微信加入一起讨论交流吧!(注意一定o要备注信息才能通过)

本文参考资料

[1]

一文搞懂NLP中的对抗训练FGSM/FGM/PGD/FreeAT/YOPO/FreeLB/SMART: https://zhuanlan.zhihu.com/p/103593948

[2]

NLP --- >对抗学习:从FGM, PGD到FreeLB: https://blog.csdn.net/chencas/article/details/103551852/

[3]

【炼丹技巧】功守道:NLP中的对抗训练 + PyTorch实现: https://zhuanlan.zhihu.com/p/91269728

[4]

对抗学习总结:FGSM->FGM->PGD->FreeAT, YOPO ->FreeLb->SMART->LookAhead->VAT: https://blog.csdn.net/weixin_36378508/article/details/116131036

【炼丹之道】NLP中的对抗训练相关推荐

  1. pytorch 对抗样本_【炼丹技巧】功守道:NLP中的对抗训练 + PyTorch实现

    本文分享一个"万物皆可盘"的NLP对抗训练实现,只需要四行代码即可调用.盘他. 最近,微软的FreeLB-Roberta [1] 靠着对抗训练 (Adversarial Train ...

  2. 训练技巧 | 功守道:NLP中的对抗训练 + PyTorch实现

    本文分享一个"万物皆可盘"的 NLP 对抗训练实现,只需要四行代码即可调用.盘他. 作者丨Nicolas 单位丨追一科技AI Lab研究员 研究方向丨信息抽取.机器阅读理解 最近, ...

  3. 【NLP】一文搞懂NLP中的对抗训练

    本文主要串烧了FGSM, FGM, PGD, FreeAT, YOPO, FreeLB, SMART这几种对抗训练方法,希望能使各位大佬炼出的丹药更加圆润有光泽,一颗永流传 简介 对抗训练是一种引入噪 ...

  4. 【NLP】NLP中的对抗训练

    作者 | 王嘉宁@华师数据学院 整理 | NewBeeNLP https://blog.csdn.net/qq_36426650/article/details/122807916 对抗训练本质是为了 ...

  5. 浅谈NLP中的对抗训练方式

    ©作者 | 林远平 单位 | QTrade AI研发中心 研究方向 | 自然语言处理 前言 什么是对抗训练呢?说起"对抗",我们就想起了计算机视觉领域的对抗生成网络(GAN).在计 ...

  6. NLP中的对抗训练(附PyTorch实现)

    对抗样本的基本概念 要认识对抗训练,首先要了解"对抗样本",它首先出现在论文Intriguing properties of neural networks之中.简单来说,它是指对 ...

  7. 一文读懂文本处理中的对抗训练

    作者丨WenZe.Leo 单位丨追一科技AI Lab研究员 背景与研究意义 深度学习技术的快速发展,大幅提升了众多自然语言处理任务(比如文本分类,机器翻译等)的效果,越来越多的深度学习模型被用于现实生 ...

  8. 【综述】NLP 对抗训练(FGM、PGD、FreeAT、YOPO、FreeLB、SMART)

    在对抗训练中关键的是需要找到对抗样本,通常是对原始的输入添加一定的扰动来构造,然后放给模型训练,这样模型就有了识别对抗样本的能力.其中的关键技术在于如果构造扰动,使得模型在不同的攻击样本中均能够具备较 ...

  9. 对抗训练浅谈:意义、方法和思考(附Keras实现)

    ©PaperWeekly 原创 · 作者|苏剑林 单位|追一科技 研究方向|NLP.神经网络 当前,说到深度学习中的对抗,一般会有两个含义:一个是生成对抗网络(Generative Adversari ...

最新文章

  1. Forrester:全球供应商在中国处于领导地位 但本土供应商却在私有云市场蒸蒸日上...
  2. python培训好学吗-python难学吗?为什么上了python培训还是学不会?
  3. [二分查找] 一:子区间界限应当如何确定
  4. 计算机办公应用适合什么工作,有什么软件堪称办公神器,让你每天的工作轻松不累?...
  5. hdu4549 M斐波那契数列
  6. Java工作笔记-使用IDEA开始我的第一个Spring项目
  7. 关于SWT中的Combo类和List类
  8. CSDN博客不能正常发布的问题
  9. linux lamp实验报告,我的LAMP过程
  10. 软件模拟PWM——呼吸灯小程序的理解
  11. 第二章:在HTML中使用JavaScript
  12. [转载] json.dumps()和json.dump()的用法和区别
  13. (转)python3之模块io使用流的核心工具
  14. Ubuntu Software Center has closed unexpectly解决方案
  15. 湖北工业大学计算机导论考试试题,湖北工业大学计算机二级考试时间
  16. 计算机一级c类题库及答案解析,全国计算机一级考试试题题库及答案
  17. Chpater 5 大规模MIMO信道估计与导频设计
  18. matlab算kdj指标,通达信带注释的KDJ指标公式
  19. 白鹭php源码,看源码系列之从运行流程开始-Egret社区-教程文档-白鹭引擎-Egret Engine-免费开源HTML5游戏引擎 - Powered by Discuz!...
  20. 大数据可视化常用图表--简单说

热门文章

  1. 什么活动需要媒体邀约
  2. 利用jquery动态添加和删除表格的一行,并且保存单行数据
  3. LR11 无法弹出ie浏览器,或者ie已停止工作问题的解决方法汇总
  4. python 安装Hugging Face
  5. [经验教程]2022淘宝天猫618定金可以退吗及2022年淘宝天猫618超级红包活动时间是什么时候开始到几月几号结束活动优惠力度大吗?
  6. 王同学的科技周刊(第一期):七夕疯狂搞钱的年轻人,一周赚14万
  7. Android开发_备份短信
  8. MySql条件查询及连接
  9. JAVA数组编程教程,Java入门超经典内部教程-数组
  10. perl注释快捷键_Jupyter Notebook的秘诀,技巧和快捷键