Running a GAN demo fails at runtime.
Error message:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 1]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
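
The hint at the end of the message refers to autograd's anomaly-detection mode; a minimal way to use it (a debugging aid only, not the fix itself) is to switch it on at the top of the script, so the error also carries a second traceback pointing at the forward operation whose output was later modified in place:

import torch

# Enable anomaly detection before building any graph; backward errors will then
# include a traceback to the offending forward operation.
torch.autograd.set_detect_anomaly(True)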

Cause:
The error comes from a PyTorch version difference (I am running 1.8.1, while the source code was written for 1.4.1): since PyTorch 1.5, autograd checks tensor version counters more strictly during backward(). In the original loop, opt_D.step() updates the discriminator's weights in place, and the subsequent G_loss.backward() still needs the old weight values through the retained graph, so the check fails. The [128, 1] tensor named in the message corresponds to the transposed weight of D's final Linear(128, 1) layer (TBackward is the transpose used in that layer's backward).
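
The snippet below is a minimal sketch of the failure mode with toy layer sizes (not the demo's real networks): both losses share D's forward graph, opt_D.step() updates D's weights in place, and the later G_loss.backward() still needs the old weight values, so the version check raises the error above.

import torch
import torch.nn as nn

G = nn.Linear(2, 3)   # stand-in generator
D = nn.Linear(3, 1)   # stand-in discriminator
opt_D = torch.optim.Adam(D.parameters(), lr=1e-3)

fake = G(torch.randn(4, 2))
score = D(fake)                      # one forward pass shared by both losses
D_loss = -score.mean()
G_loss = score.mean()

D_loss.backward(retain_graph=True)
opt_D.step()        # in-place update of D.weight bumps its version counter
G_loss.backward()   # still needs the old D.weight -> RuntimeError on PyTorch >= 1.5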

Original code:

# Use a GAN to generate a sine-like curve
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# torch.manual_seed(1)       # reproducible
# np.random.seed(1)

# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001       # learning rate for generator
LR_D = 0.0001       # learning rate for discriminator
N_IDEAS = 5         # think of this as the number of ideas for generating an artwork (input noise dimension of the Generator)
ART_COMPONENTS = 15 # the total number of points G can draw on the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])
# show our beautiful painting range
plt.plot(PAINT_POINTS[0], np.sin(PAINT_POINTS[0] * np.pi), c='#74BCFF', lw=3, label='standard curve')
plt.legend(loc='best')
plt.show()

def artist_works():     # painting from the famous artist (real target)
    # a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    r = 0.02 * np.random.randn(1, ART_COMPONENTS)
    paintings = np.sin(PAINT_POINTS * np.pi) + r
    paintings = torch.from_numpy(paintings).float()
    return paintings
r = 0.02 * np.random.randn(1, ART_COMPONENTS)
paintings = np.sin(PAINT_POINTS * np.pi) + r
plt.plot(PAINT_POINTS[0],paintings[0])
plt.show()

G = nn.Sequential(                      # Generator
    nn.Linear(N_IDEAS, 128),            # random ideas (could come from a normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),     # making a painting from these random ideas
)

D = nn.Sequential(                      # Discriminator
    nn.Linear(ART_COMPONENTS, 128),     # receives an artwork either from the famous artist or from a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),                       # tells the probability that the artwork was made by the artist
)
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

plt.ion()               # continuous (interactive) plotting

D_loss_history = []
G_loss_history = []
for step in range(10000):
    artist_paintings = artist_works()           # real paintings from the artist, shape [64, 15]
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas, shape [64, 5]
    G_paintings = G(G_ideas)                    # fake paintings from G (random ideas), shape [64, 15]

    prob_artist0 = D(artist_paintings)          # D tries to increase this prob, shape [64, 1]
    prob_artist1 = D(G_paintings)               # D tries to reduce this prob, shape [64, 1]

    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    G_loss = torch.mean(torch.log(1. - prob_artist1))

    D_loss_history.append(D_loss)
    G_loss_history.append(G_loss)

    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)          # reusing the computational graph for G_loss.backward()
    opt_D.step()

    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    if step % 50 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting')
        plt.plot(PAINT_POINTS[0], np.sin(PAINT_POINTS[0] * np.pi), c='#74BCFF', lw=3, label='standard curve')
        plt.text(-1, 0.75, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 8})
        plt.text(-1, 0.5, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 8})
        plt.ylim((-1, 1))
        plt.legend(loc='lower right', fontsize=10)
        plt.draw()
        plt.pause(0.01)

plt.ioff()
plt.show()

Modified code (update G first, then feed a detached G_paintings to D for the D update):

# Use a GAN to generate a sine-like curve
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# torch.manual_seed(1)       # reproducible
# np.random.seed(1)

# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001       # learning rate for generator
LR_D = 0.0001       # learning rate for discriminator
N_IDEAS = 5         # think of this as the number of ideas for generating an artwork (input noise dimension of the Generator)
ART_COMPONENTS = 15 # the total number of points G can draw on the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])
# show our beautiful painting range
plt.plot(PAINT_POINTS[0], np.sin(PAINT_POINTS[0] * np.pi), c='#74BCFF', lw=3, label='standard curve')
plt.legend(loc='best')
plt.show()

def artist_works():     # painting from the famous artist (real target)
    # a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    r = 0.02 * np.random.randn(1, ART_COMPONENTS)
    paintings = np.sin(PAINT_POINTS * np.pi) + r
    paintings = torch.from_numpy(paintings).float()
    return paintings
r = 0.02 * np.random.randn(1, ART_COMPONENTS)
paintings = np.sin(PAINT_POINTS * np.pi) + r
plt.plot(PAINT_POINTS[0],paintings[0])
plt.show()

G = nn.Sequential(                      # Generator
    nn.Linear(N_IDEAS, 128),            # random ideas (could come from a normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),     # making a painting from these random ideas
)

D = nn.Sequential(                      # Discriminator
    nn.Linear(ART_COMPONENTS, 128),     # receives an artwork either from the famous artist or from a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),                       # tells the probability that the artwork was made by the artist
)
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

plt.ion()               # continuous (interactive) plotting

D_loss_history = []
G_loss_history = []
for step in range(10000):
    artist_paintings = artist_works()           # real paintings from the artist, shape [64, 15]
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas, shape [64, 5]
    G_paintings = G(G_ideas)                    # fake paintings from G (random ideas), shape [64, 15]

    # update G first
    prob_artist1 = D(G_paintings)               # D tries to reduce this prob, shape [64, 1]
    G_loss = torch.mean(torch.log(1. - prob_artist1))
    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    # then update D
    prob_artist0 = D(artist_paintings)          # D tries to increase this prob, shape [64, 1]
    # detach here to make sure we don't backprop into G, which has already been updated
    prob_artist1 = D(G_paintings.detach())      # D tries to reduce this prob
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))

    D_loss_history.append(D_loss)
    G_loss_history.append(G_loss)

    opt_D.zero_grad()
    D_loss.backward()                           # no retain_graph needed: this graph is not reused
    opt_D.step()

    if step % 50 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting')
        plt.plot(PAINT_POINTS[0], np.sin(PAINT_POINTS[0] * np.pi), c='#74BCFF', lw=3, label='standard curve')
        plt.text(-1, 0.75, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 8})
        plt.text(-1, 0.5, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 8})
        plt.ylim((-1, 1))
        plt.legend(loc='lower right', fontsize=10)
        plt.draw()
        plt.pause(0.01)

plt.ioff()
plt.show()
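
For reference, the error can also be avoided while keeping the original D-then-G ordering: give D a detached fake for its own update, then run a fresh forward through D for the G update, so no backward pass ever reuses a graph whose weights have changed. The snippet below is a sketch with toy layer sizes under those assumptions, not the author's code.

import torch
import torch.nn as nn

G = nn.Linear(2, 3)                               # stand-in generator
D = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())  # stand-in discriminator
opt_D = torch.optim.Adam(D.parameters(), lr=1e-3)
opt_G = torch.optim.Adam(G.parameters(), lr=1e-3)

real = torch.randn(4, 3)
fake = G(torch.randn(4, 2))

# D update: detach the fake so this backward never enters G's graph
D_loss = -torch.mean(torch.log(D(real)) + torch.log(1. - D(fake.detach())))
opt_D.zero_grad()
D_loss.backward()
opt_D.step()

# G update: a fresh forward through the (already updated) D, so no stale graph is reused
G_loss = torch.mean(torch.log(1. - D(fake)))
opt_G.zero_grad()
G_loss.backward()
opt_G.step()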

Run result:


……
