关注公众号,发现CV技术之美

 1 引言

该论文出自于牛津大学,主要是关于对抗训练的研究。目前已经有研究表明使用单步进行对抗训练会导致一种严重的过拟合现象,在该论文中作者经过理论分析和实验验证重新审视了对抗噪声和梯度剪切在单步对抗训练中的作用。

作者发现对于大的对抗扰动半径可有效避免过拟合现象。基于该观察结果,作者提出了一种随机噪声对抗训练,实验表明该方法不仅提供了单步对抗训练的减少计算开销的好处,而且也不会受到过拟合现象的影响。

论文里没有提供相关源代码,本文最后一节是关于该论文算法的一个简单实现。

论文链接:https://arxiv.org/abs/2202.01181

 2 预备知识

给定一个参数为的分类器,一个对抗扰动集合。如果对于任意的对抗扰动,有,则可以说在点关于对抗扰动集合是鲁棒的。对抗扰动集合的定义为:

为了使得神经网络模型能够在范数具有鲁棒性。对抗训练在数据集上修正类别训练进程并最小化损失函数,其中对抗训练的目标为:

其中是图片分类器的交叉熵损失函数。由于找到内部最大化的最优解是非常困难的,对抗训练最常见的方法就是通过来近似最坏情况下的对抗扰动。虽然这已经被证明可以产生鲁棒性模型,但是计算开销随着迭代数量而线性增加。

因此,当前的工作专注于通过一步逼近内部最大化最优解来降低对抗训练的成本。假设损失函数对于输入的变化是局部线性的,那么可以知道对抗训练内部最大化具有封闭形式的解。

利用这一点提出了,其中对抗扰动遵循梯度符号的方向,等人建议在之前添加一个随机初始化。然而,这两种方法后来都被证明容易受到多步攻击,具体公式表示为:

402 Payment Required

其中,服从概率分布。当是投影到操作,并且是均匀分布,是输入空间的维数。

 3 N-FGSM对抗训练

在进行对抗性训练时,一种常见的做法是将训练期间使用的干扰限制在范围。其背后原理是,在训练期间增加扰动的幅度可能不必要地降低分类精度,因为在测试时不会评估约束球外的扰动。

虽然通过剪裁或限制噪声大小来限制训练期间使用的扰动是一种常见做法,但是由于梯度剪切是在采取梯度上升步骤后执行的,所以剪切点可能不再进行有效的对抗训练。

基于上述动机,作者主要探索梯度剪裁操作和随机步长中噪声的大小在单步方法中获得的鲁棒性的作用。作者本文中提出了一种简单有效的单步对抗训练方法,具体的计算公式如下所示:

其中是从均分布中采样得来。由于不涉及梯度剪裁,可以发它扰动的期望平方范数大于。相关算法流程图,引理和定理的证明如下所示。

引理1(对抗扰动的期望): 已知的对抗扰动如下定义:

其中,分布是均匀分布,并且对抗扰动步长为,则有:

证明:由不等式可知,当时函数)是凹函数,则有:

则以下不等式成立:

以下主要计算期望并将缩写为,具体证明步骤如下所示:

进而则有:

402 Payment Required

证毕。

定理1 令是方法生成的对抗扰动,是方法生成的对抗扰动,是方法生成的对抗扰动,对于任意的,则有以下不等式成立:

证明:由引理1可知:

又因为

402 Payment Required

如果令超参数,,,则有:

402 Payment Required

证毕。

 4 实验结果 

下图表示的是在数据集(左)和(右)上比较和的多步方法在不同的扰动半径下使用神经网络的分类准确率。

可以发现尽管所有方法都达到干净样本的分类精度(虚线),但和单步法之间在鲁棒精度方面存在差距,而且,最重要的是是的计算开销的10倍。

下图表示的是在数据在(左)和(右)上的单步方法与网络在不同扰动半径上的比较。可以发现该论文的方法可以匹配或超过现有技术的结果,同时将计算成本降低3倍。

下图表示的是在训练开始(顶部)和结束(底部)的几个时期,对抗扰动和梯度平均值的可视化图。可以发现当过拟合之后,和无法对对抗扰动进行解释,其梯度也是如此,但是和却可以避免这种情况的发生。

 5 论文代码 

该论文并没有提供源码,以下是在数据集中对论文中代码进行的实现。

import argparse
import logging
import timeimport numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import os
import argparsedef get_args():parser = argparse.ArgumentParser()parser.add_argument('--batch-size', default=100, type=int)parser.add_argument('--data-dir', default='mnist-data', type=str)parser.add_argument('--epochs', default=10, type=int)parser.add_argument('--epsilon', default=0.3, type=float)parser.add_argument('--alpha', default=0.375, type=float)parser.add_argument('--lr-max', default=5e-3, type=float)parser.add_argument('--lr-type', default='cyclic')parser.add_argument('--fname', default='mnist_model', type=str)parser.add_argument('--seed', default=0, type=int)return parser.parse_args()class Flatten(nn.Module):def forward(self, x):return x.view(x.size(0), -1)def mnist_net():model = nn.Sequential(nn.Conv2d(1, 16, 4, stride=2, padding=1),nn.ReLU(),nn.Conv2d(16, 32, 4, stride=2, padding=1),nn.ReLU(),Flatten(),nn.Linear(32*7*7,100),nn.ReLU(),nn.Linear(100, 10))return modelclass Attack_methods(object):def __init__(self, model, X, Y, epsilon, alpha):self.model = modelself.epsilon = epsilonself.X = Xself.Y = Yself.epsilon = epsilonself.alpha = alphadef nfgsm(self):eta = torch.zeros_like(self.X).uniform_(-self.epsilon, self.epsilon)delta = torch.zeros_like(self.X)eta.requires_grad = Trueoutput = self.model(self.X + eta)loss = nn.CrossEntropyLoss()(output, self.Y)loss.backward()grad = eta.grad.detach()delta.data = eta + self.alpha * torch.sign(grad)return deltaclass Adversarial_Trainings(object):def __init__(self, epochs, train_loader, model, opt, epsilon, alpha, iter_num, lr_max, lr_schedule,fname, logger):self.epochs = epochsself.train_loader = train_loaderself.model = modelself.opt = optself.epsilon = epsilonself.alpha = alphaself.iter_num = iter_numself.lr_max = lr_maxself.lr_schedule = lr_scheduleself.fname = fnameself.logger = loggerdef fast_training(self):for epoch in range(self.epochs):start_time = time.time()train_loss = 0train_acc = 0train_n = 0for i, (X, y) in enumerate(self.train_loader):X, y = X.cuda(), y.cuda()lr = self.lr_schedule(epoch + (i + 1) / len(self.train_loader))self.opt.param_groups[0].update(lr=lr)# Generating adversarial exampleadversarial_attack = Attack_methods(self.model, X, y, self.epsilon, self.alpha)delta = adversarial_attack.nfgsm()# Update network parametersoutput = self.model(torch.clamp(X + delta, 0, 1))loss = nn.CrossEntropyLoss()(output, y)self.opt.zero_grad()loss.backward()self.opt.step()train_loss += loss.item() * y.size(0)train_acc += (output.max(1)[1] == y).sum().item()train_n += y.size(0)train_time = time.time()self.logger.info('%d \t %.1f \t %.4f \t %.4f \t %.4f', epoch, train_time - start_time, lr, train_loss/train_n, train_acc/train_n)torch.save(self.model.state_dict(), self.fname)logger = logging.getLogger(__name__)
logging.basicConfig(format='[%(asctime)s] - %(message)s',datefmt='%Y/%m/%d %H:%M:%S',level=logging.DEBUG)def main():args = get_args()logger.info(args)np.random.seed(args.seed)torch.manual_seed(args.seed)torch.cuda.manual_seed(args.seed)mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=args.batch_size, shuffle=True)model = mnist_net().cuda()model.train()opt = torch.optim.Adam(model.parameters(), lr=args.lr_max)if args.lr_type == 'cyclic':lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs], [0, args.lr_max, 0])[0]elif args.lr_type == 'flat':lr_schedule = lambda t: args.lr_maxelse:raise ValueError('Unknown lr_type')logger.info('Epoch \t Time \t LR \t \t Train Loss \t Train Acc')adversarial_training = Adversarial_Trainings(args.epochs, train_loader, model, opt, args.epsilon, args.alpha, 40,args.lr_max, lr_schedule, args.fname, logger)adversarial_training.fast_training()if __name__ == "__main__":main()

运行的实验结果如下所示

END

欢迎加入「对抗训练交流群

牛津大学出品:随机噪声对抗训练相关推荐

  1. 资源 | 《GAN实战:生成对抗网络深度学习》牛津大学Jakub著作(附下载)

    来源:专知 本文共1000字,建议阅读5分钟. 本书囊括了关于GAN的定义.训练.变体等,是关于GAN的最好的书籍之一. [ 导读 ]生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可 ...

  2. 泛化性乱弹:从随机噪声、梯度惩罚到虚拟对抗训练

    ©PaperWeekly 原创 · 作者|苏剑林 单位|追一科技 研究方向|NLP.神经网络 提高模型的泛化性能是机器学习致力追求的目标之一.常见的提高泛化性的方法主要有两种:第一种是添加噪声,比如往 ...

  3. CV算法复现(分类算法3/6):VGG(2014年 牛津大学)

    致谢:霹雳吧啦Wz:https://space.bilibili.com/18161609 目录 致谢:霹雳吧啦Wz:https://space.bilibili.com/18161609 1 本次要 ...

  4. 神经网络“炼丹炉”内部构造长啥样?牛津大学博士小姐姐用论文解读

    萧箫 发自 凹非寺 量子位 报道 | 公众号 QbitAI 神经网络就像"炼丹炉"一样,投喂大量数据,或许能获得神奇的效果. "炼丹"成功后,神经网络也能对没见 ...

  5. 最新综述:关于自动驾驶的可解释性(牛津大学)

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨黄浴@知乎 来源丨https://zhuanlan.zhihu.com/p/426573034 ...

  6. 2021首期Nature封面:牛津大学ML算法实现10万高压非晶硅原子的模拟​ | AI日报

    2021首期Nature封面:牛津大学ML算法实现10万高压非晶硅原子的模拟 为了对一般无序结构材料有更深的理解,人们广泛研究了非晶硅在高压条件下的富相行为.然而在和原子打交道的层面上,人们一直需要借 ...

  7. 被誉为「教科书」,牛津大学231页博士论文全面阐述神经微分方程,Jeff Dean点赞...

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 来自:机器之心 牛津大学的这篇博士论文对神经微分方程(NDE)展开了全面细致的研究.论 ...

  8. 牛津大学计算机系主任:人工智能立法重在抓机遇、防危害

    来源:科技日报  作者:郑焕斌 "人工智能立法的重点应在于充分利用AI技术所提供的各种机遇,构建适宜的环境以激励.培育大量AI初创公司和新服务的发展,防范和应对AI技术所带来的各种潜在危害. ...

  9. 牛津大学的研究人员首次在人体植入“闭环”生物电子研究系统

    牛津大学的研究人员植入了一个新型的闭环研究平台,用于研究脚桥核(pedunculopontine nucleus, PPN)在帕金森氏型多系统萎缩(Multiple Systems Atrophy, ...

最新文章

  1. R语言时间序列(time series)分析实战:HoltWinters平滑法预测
  2. java中使用MD5进行计算摘要
  3. 初识ABP vNext(4):vue用户登录菜单权限
  4. 酱油和gbt酱油哪个好_酱油可不是越贵越好?看清瓶身上的5个字,教你1分钟买到好酱油...
  5. php代码里加图片,php如何添加图片
  6. php设置html全局路径_PHPCMS V9 URL去掉或修改/html路径的方法
  7. 实现光晕效果_马自达6车灯升级激光四透镜实现四近四远光
  8. HBase HA完全分布式环境搭建
  9. apache mediawiki 安装_MediaWiki初探:安装及使用入门
  10. 使用 Windows 7 VHD启动计算机
  11. 从零基础入门Tensorflow2.0 ----八、42. 自定义流程
  12. 操作BOM对象的方法
  13. 内核线程、轻量级进程、用户线程三种线程概念解惑(线程≠轻量级进程)
  14. 3600000毫秒等于多少小时_一秒多少毫秒
  15. 简单理解以太网和令牌环网【区别】
  16. 苹果CMSv10系统标签,仿站必备
  17. 计算机表格制作ppt,计算机基础使用excel2003制作表格.ppt
  18. 优启通做服务器系统,系统安装教程1:制作优启通PE启动盘
  19. 计算机管理器鼠标不见了怎么办,电脑的鼠标光标消失了
  20. jzoj6495 死星 (竞赛图五元环)

热门文章

  1. html大作业(含资源)
  2. 程序设计思维 B - 东东学打牌
  3. 自己开发了一个SmartPhone用的手机归属地软件SmartPhone
  4. 论文解析(1)——语义分割(求索ljj解读:A Review on Deep learning Techniques Applied to Semantic Segmentation)(更新中))
  5. 在手机上安装youget_You-get 的安装与使用
  6. nb-lot plc python_基于NB-LOT实现.....
  7. XBOX之Kinect1与2的区别
  8. 向单片机flash中烧录自定义数据的方法
  9. 我推荐亲戚家小孩学编程,差点被打一顿!
  10. 中英双语多语言外贸企业网站源码系统 - HanCMS - 安装部署教程