文章目录

  • 1 实验目标
  • 2 实验流程
    • 2.1 搭建LeNet训练,测试准确度
    • 2.2 fgsm生成对抗样本
    • 2.3 探究不同epsilon值对分类准确度的影响
  • 3 实验结果
  • 4 完整代码

1 实验目标

  1. pytorch实现fgsm attack
  2. 原始样本、对抗样本与对抗扰动的可视化
  3. 探究不同epsilon值对accuracy的影响

2 实验流程

  1. 搭建LeNet网络训练MNIST分类模型,测试准确率。
  2. 生成不同epsilon值的对抗样本,送入训练好的模型,再次测试准确率,得到结果

2.1 搭建LeNet训练,测试准确度

导入pytorch必要库

import os.pathimport torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoaderimport numpy as np
import matplotlib.pyplot as plt

加载torchvision中的MNIST数据集

train_data = torchvision.datasets.MNIST(root='data',train=True,download=True,transform=transforms.ToTensor()
)
test_data = torchvision.datasets.MNIST(root='data',train=False,download=True,transform=transforms.ToTensor()
)batch_size = 64train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)

matplotlib展示MNIST图像

plt.figure(figsize=(8, 8))
iter_dataloader = iter(test_dataloader)n=1# 取出n*batch_size张图片可视化
for i in range(n):images, labels = next(iter_dataloader)image_grid = torchvision.utils.make_grid(images)plt.subplot(1, n, i+1)plt.imshow(np.transpose(image_grid.numpy(), (1, 2, 0)))


转移到GPU训练

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

搭建LeNet网络

class LeNet(nn.Module):def __init__(self):super(LeNet,self).__init__()self.conv = nn.Sequential(nn.Conv2d(1,6,3,stride=1,padding=1),nn.MaxPool2d(2,2),nn.Conv2d(6,16,5,stride=1,padding=1),nn.MaxPool2d(2,2))self.fc = nn.Sequential(nn.Linear(576,120),nn.Linear(120,84),nn.Linear(84,10))def forward(self,x):out = self.conv(x)out = out.view(out.size(0),-1)out = self.fc(out)return out

定义训练函数

def train(network):losses = []iteration = 0epochs = 10for epoch in range(epochs):loss_sum = 0for i, (X, y) in enumerate(train_dataloader):X, y = X.to(device), y.to(device)pred = network(X)loss = loss_fn(pred, y)loss_sum += loss.item()optimizer.zero_grad()loss.backward()optimizer.step()mean_loss = loss_sum / len(train_dataloader.dataset)losses.append(mean_loss)iteration += 1print(f"Epoch {epoch+1} loss: {mean_loss:>7f}")# 训练完毕保存最后一轮训练的模型torch.save(network.state_dict(), "model.pth")# 绘制损失函数曲线plt.xlabel("Epochs")plt.ylabel("Loss Value")plt.plot(list(range(iteration)), losses)
network = LeNet()
network.to(device)loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=network.parameters(), lr=0.001, momentum=0.9)if os.path.exists('model.pth'):network.load_state_dict(torch.load('model.pth'))
else:train(network)

得到损失值与损失值图像


对模型进行测试,得到准确度

positive = 0
negative = 0
for X, y in test_dataloader:with torch.no_grad():X, y = X.to(device), y.to(device)pred = network(X)for item in zip(pred, y):if torch.argmax(item[0]) == item[1]:positive += 1else:negative += 1
acc = positive / (positive + negative)
print(f"{acc * 100}%")

2.2 fgsm生成对抗样本

# 寻找对抗样本,并可视化
eps = [0.01, 0.05, 0.1, 0.2, 0.5]for X, y in test_dataloader:X, y = X.to(device), y.to(device)X.requires_grad = Truepred = network(X)network.zero_grad()loss = loss_fn(pred, y)loss.backward()plt.figure(figsize=(15, 8))plt.subplot(121)image_grid = torchvision.utils.make_grid(torch.clamp(X.grad.sign(), 0, 1))plt.imshow(np.transpose(image_grid.cpu().numpy(), (1, 2, 0)))X_adv = X + eps[2] * X.grad.sign()X_adv = torch.clamp(X_adv, 0, 1)plt.subplot(122)image_grid = torchvision.utils.make_grid(X_adv)plt.imshow(np.transpose(image_grid.cpu().numpy(), (1, 2, 0)))break

左图为对抗扰动,右图为对抗样本

2.3 探究不同epsilon值对分类准确度的影响

# 用对抗样本替代原始样本,测试准确度
# 探究不同epsilon对LeNet分类准确度的影响
acc_list = []
for epsilon in eps:for X, y in test_dataloader:X, y = X.to(device), y.to(device)X.requires_grad = Truepred = network(X)network.zero_grad()loss = loss_fn(pred, y)loss.backward()X = X + epsilon * X.grad.sign()X_adv = torch.clamp(X, 0, 1)pred = network(X_adv)for item in zip(pred, y):if torch.argmax(item[0]) == item[1]:positive += 1else:negative += 1acc = positive / (positive + negative)print(f"epsilon={epsilon} acc: {acc * 100}%")acc_list.append(acc)plt.xlabel("epsilon")
plt.ylabel("Accuracy")
plt.plot(eps, acc_list, marker='o')


3 实验结果

对同一个分类模型来说,随着epsilon的增加,fgsm生成的对抗样本使得分类准确度减小

4 完整代码

github: https://github.com/RyanKao2001/FGSM-MNIST

FGSM生成对抗样本(MNIST数据集)Pytorch代码实现与实验分析相关推荐

  1. 论文盘点:GAN生成对抗样本的方法解析

    ©PaperWeekly 原创 · 作者|孙裕道 学校|北京邮电大学博士生 研究方向|GAN图像生成.情绪对抗样本生成 引言 对抗样本的生成方式很多.一般情况下会分成三大类,第一种是基于梯度的生成方式 ...

  2. 利用python实现深度学习生成对抗样本模型,为任一图片加扰动并恢复原像素的全流程记录

    利用python实现深度学习生成对抗样本,为任一图片加扰动并恢复原像素 一.前言 (一)什么是深度学习 (二)什么是样本模型 (三)什么是对抗样本 1.对抗的目的 2.谁来对抗? 3.对抗的敌人是谁? ...

  3. 对抗攻击之利用水印生成对抗样本

    本文为52CV粉丝鬼道投稿,介绍了对抗学习领域最新的工作Adv-watermark. 论文标题:Adv-watermark: A Novel Watermark Perturbation for Ad ...

  4. ACL2020 | 使用强化学习为机器翻译生成对抗样本

    2020-07-12 03:08:49 本文介绍的是 ACL 2020 论文<A Reinforced Generation of Adversarial Examples for Neural ...

  5. GAN生成对抗网络基本概念及基于mnist数据集的代码实现

    本文主要总结了GAN(Generative Adversarial Networks) 生成对抗网络的基本原理并通过mnist数据集展示GAN网络的应用. GAN网络是由两个目标相对立的网络构成的,在 ...

  6. pytorch基于GAN生成对抗网络的数据集扩充

    文章目录 前言 一.GAN基本原理 1.结构图 2.目标函数 二.实现 1.实现流程图 2.实例 2.1采集少量原始数据 2.2GAN模型训练(注意修改图片路径) 2.3用训练好的模型扩充数据集(生成 ...

  7. 原始GAN-pytorch-生成MNIST数据集(代码)

    文章目录 原始GAN生成MNIST数据集 1. Data loading and preparing 2. Dataset and Model parameter 3. Result save pat ...

  8. PyTorch FGSM Attack 对抗样本生成

    要阅读 带有插图的文章版本 请前往 http://studyai.com/pytorch-1.4/beginner/fgsm_tutorial.html 如果你正在阅读这篇文章,希望你能体会到一些机器 ...

  9. 对抗自编码器AAE——pytorch代码解读试验

    AAE网络结构基本框架如论文中所示: 闲话不多说,直接来学习一下加了注释和微调的基本AAE的代码(初始代码链接github): aae_pytorch_basic.py #!/usr/bin/env ...

最新文章

  1. Django --ORM常用的字段和参数 多对多创建形式
  2. 优化一个小时不出结果的SQL
  3. MOSS 2007 User Profile 系列 索引
  4. EventSource
  5. nyoj187 快速查找素数
  6. 如何用unit test测试controller_如何用电缆故障测试仪冲闪测试确定故障点?
  7. IDEA 导入Weka的Maven依赖jar包
  8. 40个良好用户界面设计Tips
  9. 系统测试主要测试类型
  10. baidu经纬度坐标与google经纬度坐标转换
  11. sue的小球 牛客(区间dp)
  12. Euclid辗转相除法c语言,euclid辗转相除法求greatest common divisor
  13. winmerge多个文件夹生成html,功能强大的文件、文件夹比对工具-WinMerge使用教程
  14. Prometheus 监控案例详解
  15. 空间超分辨率(SISR)领域非常不错的blog/论文(长期更新)
  16. WIN10怎么安装SQL server2000数据库
  17. 虚拟机性能监控、故障处理工具
  18. 技术分享 | 详解SQL加密函数:AES_ENCRYPT()
  19. 常见软件---SQLite3的C语言下使用
  20. 一加6android9上手体验,超级夜景,全速旗舰,一加6T上手体验

热门文章

  1. 电机驱动芯片-L298N介绍
  2. 014-Ambari功能介绍
  3. html制作玫瑰代码,玫瑰花小制作分享-JavaScript(七夕专属浪漫)
  4. SCI期刊收不收费也有门道,你知道吗?
  5. jQuery入门第一章(jQuery初体验)
  6. 用python爬取网易云评论10w+的歌曲名_Python3爬取网易云音乐评论
  7. 7-4 部落(25 分)
  8. “笨办法”学Python 3基础篇 - 函数
  9. Java实现归并排序-有图有真相
  10. Python之TCP Socket网络编程