I found a WGAN implementation online, ran it locally, and the results looked reasonable, so I wrapped it up as a single function. Anyone interested is welcome to use it.

Note that this GAN generates one-dimensional (tabular) data; for image data the code will need some changes (see the sketch after the full listing).

import numpy as np
import pandas as pd
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import warnings
warnings.filterwarnings("ignore")
import os, sys

torch.manual_seed(1)


def train_model_save_gen(data, ITERS=600, iter_ctrl=200, use_cuda=False, name_save='', file_name='./save_gen/'):
    """Train a WGAN-GP on `data` (2-D array, one sample per row), saving
    generated samples and model checkpoints whenever quality improves."""
    if not os.path.exists(file_name):
        os.mkdir(file_name)
    if not os.path.exists('./model_file/'):
        os.mkdir('./model_file/')

    FIXED_GENERATOR = False  # if True, the "generator" just shifts the noise by the real data
    LAMBDA = .1              # gradient penalty coefficient
    CRITIC_ITERS = 5         # critic updates per iteration
    CRITIC_ITERG = 1         # generator updates per iteration
    BATCH_SIZE = len(data)   # full-batch training

    class Generator(nn.Module):
        def __init__(self, shape1):
            super(Generator, self).__init__()
            main = nn.Sequential(
                nn.Linear(shape1, 1024),
                nn.ReLU(True),
                nn.Linear(1024, 512),
                nn.ReLU(True),
                nn.Linear(512, 256),
                nn.ReLU(True),
                nn.Linear(256, 512),
                nn.ReLU(True),
                nn.Linear(512, 1024),
                nn.Tanh(),
                nn.Linear(1024, shape1),
            )
            self.main = main

        def forward(self, noise, real_data):
            if FIXED_GENERATOR:
                return noise + real_data
            else:
                return self.main(noise)

    class Discriminator(nn.Module):
        def __init__(self, shape1):
            super(Discriminator, self).__init__()
            self.fc1 = nn.Linear(shape1, 512)
            self.relu1 = nn.LeakyReLU(0.2)
            self.fc2 = nn.Linear(512, 256)
            self.relu2 = nn.LeakyReLU(0.2)
            self.fc3 = nn.Linear(256, 256)
            self.relu3 = nn.LeakyReLU(0.2)
            self.fc4 = nn.Linear(256, 128)
            self.relu4 = nn.LeakyReLU(0.2)
            self.fc5 = nn.Linear(128, 1)

        def forward(self, inputs):
            out = self.fc1(inputs)
            out = self.relu1(out)
            out = self.fc2(out)
            out = self.relu2(out)
            out = self.fc3(out)
            out = self.relu3(out)
            out = self.fc4(out)
            out = self.relu4(out)
            out = self.fc5(out)
            return out.view(-1)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Linear') != -1:
            m.weight.data.normal_(0.0, 0.02)
            m.bias.data.fill_(0)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def calc_gradient_penalty(netD, real_data, fake_data):
        # WGAN-GP penalty: push the critic's gradient norm towards 1
        # on random interpolations between real and fake samples.
        alpha = torch.rand(BATCH_SIZE, 1)
        alpha = alpha.expand(real_data.size())
        alpha = alpha.cuda() if use_cuda else alpha
        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        if use_cuda:
            interpolates = interpolates.cuda()
        interpolates = autograd.Variable(interpolates, requires_grad=True)
        disc_interpolates = netD(interpolates)
        gradients = autograd.grad(
            outputs=disc_interpolates, inputs=interpolates,
            grad_outputs=torch.ones(disc_interpolates.size()).cuda() if use_cuda
            else torch.ones(disc_interpolates.size()),
            create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
        return gradient_penalty

    netG = Generator(data.shape[1])
    netD = Discriminator(data.shape[1])
    netD.apply(weights_init)
    netG.apply(weights_init)
    if use_cuda:
        netD = netD.cuda()
        netG = netG.cuda()

    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))

    one = torch.tensor(1, dtype=torch.float)  # equivalent to torch.FloatTensor([1])
    mone = one * -1
    if use_cuda:
        one = one.cuda()
        mone = mone.cuda()

    one_list = np.ones((data.shape[0]))    # labels for real samples
    zero_list = np.zeros((data.shape[0]))  # labels for generated samples
    opt_diff_accuracy_05 = 0.5
    best_item = 0
    opt_accuracy = 0
    all_result = []
    loss_list = {'D_loss': [], 'G_loss': []}

    for iteration in range(ITERS):
        # \r moves the cursor back to the start of the line, so the progress text overwrites itself
        sys.stdout.write(f'\rProgress: {iteration}/{ITERS}')
        sys.stdout.flush()

        for p in netD.parameters():
            p.requires_grad = True

        real_data = torch.FloatTensor(data)
        if use_cuda:
            real_data = real_data.cuda()
        real_data_v = autograd.Variable(real_data)

        noise = torch.randn(BATCH_SIZE, data.shape[1])
        if use_cuda:
            noise = noise.cuda()
        # volatile=True was the pre-0.4 way of saying requires_grad=False;
        # modern PyTorch ignores it (use torch.no_grad() instead)
        noisev = autograd.Variable(noise, volatile=True)
        fake = autograd.Variable(netG(noisev, real_data_v).data)
        fake_output = fake.data.cpu().numpy()
        data = real_data.data.cpu().numpy()

        # (1) update the critic CRITIC_ITERS times
        for iter_d in range(CRITIC_ITERS):
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)

            noise = torch.randn(BATCH_SIZE, data.shape[1])
            if use_cuda:
                noise = noise.cuda()
            noisev = autograd.Variable(noise, volatile=True)
            fake = autograd.Variable(netG(noisev, real_data_v).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)

            gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
            gradient_penalty.backward()

            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            loss_list['D_loss'].append(D_cost.item())
            optimizerD.step()

        # (2) update the generator, with the critic frozen
        if not FIXED_GENERATOR:
            for p in netD.parameters():
                p.requires_grad = False
            for iter_g in range(CRITIC_ITERG):
                netG.zero_grad()
                real_data = torch.Tensor(data)
                if use_cuda:
                    real_data = real_data.cuda()
                real_data_v = autograd.Variable(real_data)
                noise = torch.randn(BATCH_SIZE, data.shape[1])
                if use_cuda:
                    noise = noise.cuda()
                noisev = autograd.Variable(noise)
                fake = netG(noisev, real_data_v)
                G = netD(fake)
                G = G.mean()
                G.backward(mone)
                G_cost = -G
                loss_list['G_loss'].append(G_cost.item())
                optimizerG.step()

        # (3) every iter_ctrl iterations, score the generated samples with a
        # 1-NN two-sample test: a classifier that cannot separate real from
        # generated rows (accuracy close to 0.5) means the fake data looks
        # real, so the checkpoint whose accuracy is closest to 0.5 is kept.
        if iteration % iter_ctrl == 0:
            print()
            print(f'Iteration {iteration}..')
            x = np.concatenate((data, fake_output), axis=0)
            y = np.concatenate((one_list, zero_list), axis=0)
            kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
            real_label = np.zeros((x.shape[0]))
            pred_label = np.zeros((x.shape[0]))
            for train_index, test_index in kfold.split(x, y):
                x_train, x_test = x[train_index], x[test_index]
                y_train, y_test = y[train_index], y[test_index]
                knn = KNeighborsClassifier(n_neighbors=1).fit(x_train, y_train)
                predicted_y = knn.predict(x_test)
                pred_label[test_index] = predicted_y
                real_label[test_index] = y_test
            accuracy = accuracy_score(real_label, pred_label)
            all_result.append(str(iteration) + "," + str(accuracy))
            print(f'Iteration {iteration}: acc = {accuracy}')
            diff_accuracy_05 = abs(accuracy - 0.5)
            if diff_accuracy_05 < opt_diff_accuracy_05:
                opt_diff_accuracy_05 = diff_accuracy_05
                best_item = iteration
                opt_accuracy = accuracy
                save_temp = pd.DataFrame(fake_output)
                save_temp.to_csv(file_name + "/Iteration3_" + str(name_save) + '_' + str(iteration) + ".csv", index=None)
                torch.save(netG.state_dict(), './model_file/netG' + str(iteration) + '.dict')
                torch.save(netD.state_dict(), './model_file/netD' + str(iteration) + '.dict')
                save_loss = pd.DataFrame(loss_list['G_loss'])
                save_loss.to_csv(file_name + "/Gloss_" + str(iteration) + '.csv', index=None)
                save_loss = pd.DataFrame(loss_list['D_loss'])
                save_loss.to_csv(file_name + "/Dloss_" + str(iteration) + '.csv', index=None)

    return best_item, opt_diff_accuracy_05


if __name__ == '__main__':
    d = np.array(pd.read_csv('./data.txt', header=None))
    train_model_save_gen(d)
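The networks above are plain MLPs over feature vectors, which is why the function only suits tabular data. For images you would typically swap the Linear stacks for (de)convolutions. Below is a minimal sketch of what the generator could look like, assuming 28x28 single-channel images and a 100-dimensional noise vector; the ConvGenerator name and layer sizes are illustrative, not part of the original code:

# Hypothetical convolutional generator for 28x28 grayscale images (illustrative only).
import torch
import torch.nn as nn

class ConvGenerator(nn.Module):
    def __init__(self, z_dim=100):
        super(ConvGenerator, self).__init__()
        self.fc = nn.Linear(z_dim, 128 * 7 * 7)  # project noise to a 7x7 feature map
        self.main = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1),    # 14x14 -> 28x28
            nn.Tanh(),  # outputs in [-1, 1]
        )

    def forward(self, noise):
        out = self.fc(noise).view(-1, 128, 7, 7)
        return self.main(out)

# quick shape check
g = ConvGenerator()
print(g(torch.randn(8, 100)).shape)  # torch.Size([8, 1, 28, 28])

The discriminator would mirror this with nn.Conv2d layers (for WGAN-GP, without batch norm in the critic), and the gradient penalty would need gradients.view(BATCH_SIZE, -1) before taking the norm, since image gradients are 4-D.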

Create a file named data.txt in the same folder and paste in the following data (9 rows, 4 columns):

1,1,1,2
1,1,1,2
1.1,1.1,1.1,2.1
1,1,1,2
1,1,1,2
1.1,1.1,1.1,2.1
1,1,1,2
1,1,1,2
1.1,1.1,1.1,2.1

Then just call the function above.
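Because the function checkpoints netG's state_dict under ./model_file/ whenever the 1-NN score improves, you can reload the best generator afterwards and draw new samples without retraining. Here is a minimal sketch, assuming a checkpoint such as ./model_file/netG0.dict exists (use the first value returned by train_model_save_gen as best_item). Since the Generator class is nested inside the function, it has to be redefined at module level with the same architecture; the unused real_data argument of its forward is dropped here because FIXED_GENERATOR is False:

# Minimal sketch: reload a saved generator and draw new samples (assumptions as above).
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

class Generator(nn.Module):
    # must match the architecture inside train_model_save_gen so the state_dict keys line up
    def __init__(self, shape1):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(shape1, 1024), nn.ReLU(True),
            nn.Linear(1024, 512), nn.ReLU(True),
            nn.Linear(512, 256), nn.ReLU(True),
            nn.Linear(256, 512), nn.ReLU(True),
            nn.Linear(512, 1024), nn.Tanh(),
            nn.Linear(1024, shape1),
        )

    def forward(self, noise):
        return self.main(noise)

data = np.array(pd.read_csv('./data.txt', header=None))
best_item = 0  # replace with the first value returned by train_model_save_gen
netG = Generator(data.shape[1])
netG.load_state_dict(torch.load('./model_file/netG' + str(best_item) + '.dict'))
netG.eval()
with torch.no_grad():
    samples = netG(torch.randn(100, data.shape[1]))  # 100 new synthetic rows
print(samples.shape)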
