再战FGM!Tensorflow2.0 自定义模型训练实现NLP中的FGM对抗训练 代码实现
TF版本2.2及以上
def creat_FGM(epsilon=1.0):@tf.function def train_step(self, data):'''计算在embedding上的gradient计算扰动 在embedding上加上扰动重新计算loss和gradient删除embedding上的扰动,并更新参数'''data = data_adapter.expand_1d(data)x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)with tf.GradientTape() as tape:y_pred = model(x,training=True)loss = loss_func(y,y_pred)embedding = model.trainable_variables[0]embedding_gradients = tape.gradient(loss,[model.trainable_variables[0]])[0]embedding_gradients = tf.zeros_like(embedding) + embedding_gradientsdelta = 0.2 * embedding_gradients / (tf.math.sqrt(tf.reduce_sum(embedding_gradients**2)) + 1e-8) # 计算扰动model.trainable_variables[0].assign_add(delta)with tf.GradientTape() as tape2:y_pred = model(x,training=True)new_loss = loss_func(y,y_pred)gradients = tape2.gradient(new_loss,model.trainable_variables)model.trainable_variables[0].assign_sub(delta)optimizer.apply_gradients(zip(gradients,model.trainable_variables))train_loss.update_state(loss)return {m.name: m.result() for m in self.metrics}return train_step
使用方法
TF2.2 及以上的方法比较简单
model.compile(loss='sparse_categorical_crossentropy',optimizer=tf.keras.optimizers.Adam(0.001),metrics=['acc'],)#替换model.train_step 方法即可,并且删除原有的 train_function方法
train_step = creat_FGM()
model.train_step = functools.partial(train_step, model)
model.train_function = Nonehistory = model.fit(X_train,y_train,epochs=5,validation_data=(X_test,y_test),verbose=1,batch_size=32)
TF版本2.2以下,适用于2.0GPU版本
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss_func = tf.losses.SparseCategoricalCrossentropy()
train_loss = tf.metrics.Mean(name='train_loss')ds_train = tf.data.Dataset.from_tensor_slices((X_train,y_train)) \.shuffle(buffer_size = 1000).batch(32) \.prefetch(tf.data.experimental.AUTOTUNE).cache()@tf.function
def train_step(model,x,y,loss_func,optimizer,train_loss):with tf.GradientTape() as tape:y_pred = model(x,training=True)loss = loss_func(y,y_pred)embedding = model.trainable_variables[0]embedding_gradients = tape.gradient(loss,[model.trainable_variables[0]])[0]embedding_gradients = tf.zeros_like(embedding) + embedding_gradientsdelta = 0.2 * embedding_gradients / (tf.math.sqrt(tf.reduce_sum(embedding_gradients**2)) + 1e-8) # 计算扰动model.trainable_variables[0].assign_add(delta)with tf.GradientTape() as tape2:y_pred = model(x,training=True)new_loss = loss_func(y,y_pred)gradients = tape2.gradient(new_loss,model.trainable_variables)model.trainable_variables[0].assign_sub(delta)optimizer.apply_gradients(zip(gradients,model.trainable_variables))train_loss.update_state(loss)@tf.function
def printbar():ts = tf.timestamp()today_ts = ts%(24*60*60)hour = tf.cast(today_ts//3600+8,tf.int32)%tf.constant(24)minite = tf.cast((today_ts%3600)//60,tf.int32)second = tf.cast(tf.floor(today_ts%60),tf.int32)def timeformat(m):if tf.strings.length(tf.strings.format("{}",m))==1:return(tf.strings.format("0{}",m))else:return(tf.strings.format("{}",m))timestring = tf.strings.join([timeformat(hour),timeformat(minite),timeformat(second)],separator = ":")tf.print("=========="*8,end = "")tf.print(timestring)
训练代码
def train_model(model,ds_train,epochs):for epoch in tf.range(1,epochs+1):for x, y in ds_train:train_step(model,x,y,loss_func,optimizer,train_loss)logs = 'Epoch={},Loss:{}'if epoch%1 ==0:printbar()tf.print(tf.strings.format(logs,(epoch,train_loss.result())))tf.print("")train_loss.reset_states()train_model(model,ds_train,10)
以上方法均在小模型上测试完成,由于本人的GPU显存不足,导致无法给出一个BERTbase模型的效果分析,各位可以自己搬运后尝试一下。
对于FGM的介绍可以参考苏神文章:
苏剑林. (2020, Mar 01). 《对抗训练浅谈:意义、方法和思考(附Keras实现) 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/7234
再战FGM!Tensorflow2.0 自定义模型训练实现NLP中的FGM对抗训练 代码实现相关推荐
- TensorFlow 2.0 - 自定义模型、训练过程
文章目录 1. 自定义模型 2. 学习流程 学习于:简单粗暴 TensorFlow 2 1. 自定义模型 重载 call() 方法,pytorch 是重载 forward() 方法 import te ...
- Tensorflow2.0 自定义网络
自定义网络 keras.Sequential 容器 keras.layers.Layer keras.Model Keras.Sequential 容器 网络层的搭建 model = keras.Se ...
- Google 最强开源模型 BERT 在 NLP 中的应用 | 技术头条
作者 | 董文涛 责编 | 唐小引 出品 | CSDN(ID:CSDNnews) [CSDN 编者按]Google 的 BERT 模型一经发布便点燃了 NLP 各界的欢腾,Google Brain 的 ...
- NLP 中的语言模型预训练微调
1 引言 语言模型(Language Model),语言模型简单来说就是一串词序列的概率分布.具体来说,语言模型的作用是为一个长度为m的文本确定一个概率分布P,表示这段文本存在的可能性.在实践中,如果 ...
- 从0到1,了解NLP中的文本相似度
本文由云+社区发表 作者:netkiddy 导语 AI在2018年应该是互联网界最火的名词,没有之一.时间来到了9102年,也是项目相关,涉及到了一些AI写作相关的功能,为客户生成一些素材文章.但是, ...
- 从0到1,了解NLP中的文本相似度 1
导语 AI在2018年应该是互联网界最火的名词,没有之一.时间来到了9102年,也是项目相关,涉及到了一些AI写作相关的功能,为客户生成一些素材文章.但是,AI并不一定最懂你,客户对于AI写出来的文章 ...
- sqlite 0转换为bit_Cisco Talos在SQLite中发现了一个远程代码执行漏洞
思科Talos的研究人员在SQLite中发现了一个use-after-free() 的漏洞,攻击者可利用该漏洞在受影响设备上远程执行代码. 攻击者可以通过向受影响的SQLite安装发送恶意SQL命令来 ...
- TensorFlow2.0保存模型
介绍 模型保存有5种:1.整体保存:2.网络架构保存:3.权重保存:4.回调保存:5.自定义训练模型的保存 1.整体保存:权重值,模型配置(架构),优化器配置 整个模型可以保存到一个文件中,其中包含权 ...
- TensorFlow2.0:自定义层与自定义网络
自定义层函数需要继承layers.Layer,自定义网络需要继承keras.Model. 其内部需要定义两个函数: 1.__init__初始化函数,内部需要定义构造形式: 2.call函数,内部需要定 ...
最新文章
- 算法复习——bitset(bzoj3687简单题)
- mysql100个优化技巧_MySQL 调优/优化的 100 个建议
- cisco 系列时间修改
- 给Source Insight做个外挂系列之二--将本地代码注入到Source Insight进程
- 使用开源的驰骋表单设计器设计表单案例演示
- python中thread的setDaemon、join的用法
- ssh oracle id native,hibernate解决oracle的id自增?
- mysql jdbc 多数据源_springboot jdbc连接多个数据源
- 攻防世界reverse新手练习
- kafka 拉取的数据排序_Kafka 源码解析之 Consumer Poll 模型(七)
- java毕业实习日志_java毕业实习日记.doc
- pythonsort参数_Python sort()函数有哪些参数?
- 自定义View 仿QQ运动步数进度效果
- Android Studio中Cannot resolve symbol XXX的解决方法
- zookeeper 数据节点的增删改查
- oneinstack申请免费的R3 域名证书
- 51单片机c语言编程100,51单片机C语言编程100例.doc
- 数字IC笔记-scan chain 压缩和解压缩
- Python金融大数据分析——第11章 统计学(1)正态性检验 笔记
- 米思齐Mixly图形化编程---遥控灯
热门文章
- hdu_2243_考研路茫茫——单词情结(AC自动机+矩阵)
- 对信号集操作函数的使用方法和顺序
- request设置请求头_收藏 Scrapy框架各组件详细设置
- 电脑故障检测_检测电脑故障的简单方法
- 图像处理--线line 提取
- ROS | ROS2安装(Ubuntu 16.04版本:通过Debian包安装)
- TensorFlow | 使用Tensorflow带你实现MNIST手写字体识别
- Verilog | HDL LCD显示(代码类)
- 如何在VC中创建动态数组
- python汉诺塔问题输入层数输出整个移动流程_python实现汉诺塔方法汇总