TF版本2.2及以上

def creat_FGM(epsilon=1.0):@tf.function def train_step(self, data):'''计算在embedding上的gradient计算扰动 在embedding上加上扰动重新计算loss和gradient删除embedding上的扰动,并更新参数'''data = data_adapter.expand_1d(data)x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)with tf.GradientTape() as tape:y_pred = model(x,training=True)loss = loss_func(y,y_pred)embedding = model.trainable_variables[0]embedding_gradients = tape.gradient(loss,[model.trainable_variables[0]])[0]embedding_gradients = tf.zeros_like(embedding) + embedding_gradientsdelta = 0.2 * embedding_gradients / (tf.math.sqrt(tf.reduce_sum(embedding_gradients**2)) + 1e-8)  # 计算扰动model.trainable_variables[0].assign_add(delta)with tf.GradientTape() as tape2:y_pred = model(x,training=True)new_loss = loss_func(y,y_pred)gradients = tape2.gradient(new_loss,model.trainable_variables)model.trainable_variables[0].assign_sub(delta)optimizer.apply_gradients(zip(gradients,model.trainable_variables))train_loss.update_state(loss)return {m.name: m.result() for m in self.metrics}return train_step

使用方法

TF2.2 及以上的方法比较简单

model.compile(loss='sparse_categorical_crossentropy',optimizer=tf.keras.optimizers.Adam(0.001),metrics=['acc'],)#替换model.train_step 方法即可,并且删除原有的 train_function方法
train_step = creat_FGM()
model.train_step = functools.partial(train_step, model)
model.train_function = Nonehistory = model.fit(X_train,y_train,epochs=5,validation_data=(X_test,y_test),verbose=1,batch_size=32)

TF版本2.2以下,适用于2.0GPU版本

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss_func = tf.losses.SparseCategoricalCrossentropy()
train_loss = tf.metrics.Mean(name='train_loss')ds_train = tf.data.Dataset.from_tensor_slices((X_train,y_train)) \.shuffle(buffer_size = 1000).batch(32) \.prefetch(tf.data.experimental.AUTOTUNE).cache()@tf.function
def train_step(model,x,y,loss_func,optimizer,train_loss):with tf.GradientTape() as tape:y_pred = model(x,training=True)loss = loss_func(y,y_pred)embedding = model.trainable_variables[0]embedding_gradients = tape.gradient(loss,[model.trainable_variables[0]])[0]embedding_gradients = tf.zeros_like(embedding) + embedding_gradientsdelta = 0.2 * embedding_gradients / (tf.math.sqrt(tf.reduce_sum(embedding_gradients**2)) + 1e-8)  # 计算扰动model.trainable_variables[0].assign_add(delta)with tf.GradientTape() as tape2:y_pred = model(x,training=True)new_loss = loss_func(y,y_pred)gradients = tape2.gradient(new_loss,model.trainable_variables)model.trainable_variables[0].assign_sub(delta)optimizer.apply_gradients(zip(gradients,model.trainable_variables))train_loss.update_state(loss)@tf.function
def printbar():ts = tf.timestamp()today_ts = ts%(24*60*60)hour = tf.cast(today_ts//3600+8,tf.int32)%tf.constant(24)minite = tf.cast((today_ts%3600)//60,tf.int32)second = tf.cast(tf.floor(today_ts%60),tf.int32)def timeformat(m):if tf.strings.length(tf.strings.format("{}",m))==1:return(tf.strings.format("0{}",m))else:return(tf.strings.format("{}",m))timestring = tf.strings.join([timeformat(hour),timeformat(minite),timeformat(second)],separator = ":")tf.print("=========="*8,end = "")tf.print(timestring)

训练代码

def train_model(model,ds_train,epochs):for epoch in tf.range(1,epochs+1):for x, y in ds_train:train_step(model,x,y,loss_func,optimizer,train_loss)logs = 'Epoch={},Loss:{}'if epoch%1 ==0:printbar()tf.print(tf.strings.format(logs,(epoch,train_loss.result())))tf.print("")train_loss.reset_states()train_model(model,ds_train,10)

以上方法均在小模型上测试完成,由于本人的GPU显存不足,导致无法给出一个BERTbase模型的效果分析,各位可以自己搬运后尝试一下。

对于FGM的介绍可以参考苏神文章:
苏剑林. (2020, Mar 01). 《对抗训练浅谈:意义、方法和思考(附Keras实现) 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/7234

再战FGM!Tensorflow2.0 自定义模型训练实现NLP中的FGM对抗训练 代码实现相关推荐

  1. TensorFlow 2.0 - 自定义模型、训练过程

    文章目录 1. 自定义模型 2. 学习流程 学习于:简单粗暴 TensorFlow 2 1. 自定义模型 重载 call() 方法,pytorch 是重载 forward() 方法 import te ...

  2. Tensorflow2.0 自定义网络

    自定义网络 keras.Sequential 容器 keras.layers.Layer keras.Model Keras.Sequential 容器 网络层的搭建 model = keras.Se ...

  3. Google 最强开源模型 BERT 在 NLP 中的应用 | 技术头条

    作者 | 董文涛 责编 | 唐小引 出品 | CSDN(ID:CSDNnews) [CSDN 编者按]Google 的 BERT 模型一经发布便点燃了 NLP 各界的欢腾,Google Brain 的 ...

  4. NLP 中的语言模型预训练微调

    1 引言 语言模型(Language Model),语言模型简单来说就是一串词序列的概率分布.具体来说,语言模型的作用是为一个长度为m的文本确定一个概率分布P,表示这段文本存在的可能性.在实践中,如果 ...

  5. 从0到1,了解NLP中的文本相似度

    本文由云+社区发表 作者:netkiddy 导语 AI在2018年应该是互联网界最火的名词,没有之一.时间来到了9102年,也是项目相关,涉及到了一些AI写作相关的功能,为客户生成一些素材文章.但是, ...

  6. 从0到1,了解NLP中的文本相似度 1

    导语 AI在2018年应该是互联网界最火的名词,没有之一.时间来到了9102年,也是项目相关,涉及到了一些AI写作相关的功能,为客户生成一些素材文章.但是,AI并不一定最懂你,客户对于AI写出来的文章 ...

  7. sqlite 0转换为bit_Cisco Talos在SQLite中发现了一个远程代码执行漏洞

    思科Talos的研究人员在SQLite中发现了一个use-after-free() 的漏洞,攻击者可利用该漏洞在受影响设备上远程执行代码. 攻击者可以通过向受影响的SQLite安装发送恶意SQL命令来 ...

  8. TensorFlow2.0保存模型

    介绍 模型保存有5种:1.整体保存:2.网络架构保存:3.权重保存:4.回调保存:5.自定义训练模型的保存 1.整体保存:权重值,模型配置(架构),优化器配置 整个模型可以保存到一个文件中,其中包含权 ...

  9. TensorFlow2.0:自定义层与自定义网络

    自定义层函数需要继承layers.Layer,自定义网络需要继承keras.Model. 其内部需要定义两个函数: 1.__init__初始化函数,内部需要定义构造形式: 2.call函数,内部需要定 ...

最新文章

  1. 算法复习——bitset(bzoj3687简单题)
  2. mysql100个优化技巧_MySQL 调优/优化的 100 个建议
  3. cisco 系列时间修改
  4. 给Source Insight做个外挂系列之二--将本地代码注入到Source Insight进程
  5. 使用开源的驰骋表单设计器设计表单案例演示
  6. python中thread的setDaemon、join的用法
  7. ssh oracle id native,hibernate解决oracle的id自增?
  8. mysql jdbc 多数据源_springboot jdbc连接多个数据源
  9. 攻防世界reverse新手练习
  10. kafka 拉取的数据排序_Kafka 源码解析之 Consumer Poll 模型(七)
  11. java毕业实习日志_java毕业实习日记.doc
  12. pythonsort参数_Python sort()函数有哪些参数?
  13. 自定义View 仿QQ运动步数进度效果
  14. Android Studio中Cannot resolve symbol XXX的解决方法
  15. zookeeper 数据节点的增删改查
  16. oneinstack申请免费的R3 域名证书
  17. 51单片机c语言编程100,51单片机C语言编程100例.doc
  18. 数字IC笔记-scan chain 压缩和解压缩
  19. Python金融大数据分析——第11章 统计学(1)正态性检验 笔记
  20. 米思齐Mixly图形化编程---遥控灯

热门文章

  1. hdu_2243_考研路茫茫——单词情结(AC自动机+矩阵)
  2. 对信号集操作函数的使用方法和顺序
  3. request设置请求头_收藏 Scrapy框架各组件详细设置
  4. 电脑故障检测_检测电脑故障的简单方法
  5. 图像处理--线line 提取
  6. ROS | ROS2安装(Ubuntu 16.04版本:通过Debian包安装)
  7. TensorFlow | 使用Tensorflow带你实现MNIST手写字体识别
  8. Verilog | HDL LCD显示(代码类)
  9. 如何在VC中创建动态数组
  10. python汉诺塔问题输入层数输出整个移动流程_python实现汉诺塔方法汇总