代码来源

代码全文


clear all; close all; clc;
%% Basic Generative Adversarial Network
%% Load Data
load('mnistAll.mat')
trainX = preprocess(mnist.train_images);
trainY = mnist.train_labels;
testX = preprocess(mnist.test_images);
testY = mnist.test_labels;
%% Settings
settings.latent_dim = 100;
settings.batch_size = 32; settings.image_size = [28,28,1];
settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;
settings.beta2 = 0.999; settings.maxepochs = 50;%% Initialization
%% Generator
paramsGen.FCW1 = dlarray(...initializeGaussian([256,settings.latent_dim],.02));
paramsGen.FCb1 = dlarray(zeros(256,1,'single'));
paramsGen.BNo1 = dlarray(zeros(256,1,'single'));
paramsGen.BNs1 = dlarray(ones(256,1,'single'));
paramsGen.FCW2 = dlarray(initializeGaussian([512,256]));
paramsGen.FCb2 = dlarray(zeros(512,1,'single'));
paramsGen.BNo2 = dlarray(zeros(512,1,'single'));
paramsGen.BNs2 = dlarray(ones(512,1,'single'));
paramsGen.FCW3 = dlarray(initializeGaussian([1024,512]));
paramsGen.FCb3 = dlarray(zeros(1024,1,'single'));
paramsGen.BNo3 = dlarray(zeros(1024,1,'single'));
paramsGen.BNs3 = dlarray(ones(1024,1,'single'));
paramsGen.FCW4 = dlarray(initializeGaussian(...[prod(settings.image_size),1024]));
paramsGen.FCb4 = dlarray(zeros(prod(settings.image_size)...,1,'single'));stGen.BN1 = []; stGen.BN2 = []; stGen.BN3 = [];%% Discriminator
paramsDis.FCW1 = dlarray(initializeGaussian([1024,...prod(settings.image_size)],.02));
paramsDis.FCb1 = dlarray(zeros(1024,1,'single'));
paramsDis.BNo1 = dlarray(zeros(1024,1,'single'));
paramsDis.BNs1 = dlarray(ones(1024,1,'single'));
paramsDis.FCW2 = dlarray(initializeGaussian([512,1024]));
paramsDis.FCb2 = dlarray(zeros(512,1,'single'));
paramsDis.BNo2 = dlarray(zeros(512,1,'single'));
paramsDis.BNs2 = dlarray(ones(512,1,'single'));
paramsDis.FCW3 = dlarray(initializeGaussian([256,512]));
paramsDis.FCb3 = dlarray(zeros(256,1,'single'));
paramsDis.FCW4 = dlarray(initializeGaussian([1,256]));
paramsDis.FCb4 = dlarray(zeros(1,1,'single'));stDis.BN1 = []; stDis.BN2 = [];% average Gradient and average Gradient squared holders
avgG.Dis = []; avgGS.Dis = []; avgG.Gen = []; avgGS.Gen = [];
%% Train
numIterations = floor(size(trainX,2)/settings.batch_size);
out = false; epoch = 0; global_iter = 0;
while ~outtic; trainXshuffle = trainX(:,randperm(size(trainX,2)));fprintf('Epoch %d\n',epoch) for i=1:numIterationsglobal_iter = global_iter+1;noise = gpdl(randn([settings.latent_dim,...settings.batch_size]),'CB');idx = (i-1)*settings.batch_size+1:i*settings.batch_size;XBatch=gpdl(single(trainXshuffle(:,idx)),'CB');[GradGen,GradDis,stGen,stDis] = ...dlfeval(@modelGradients,XBatch,noise,...paramsGen,paramsDis,stGen,stDis);% Update Discriminator network parameters[paramsDis,avgG.Dis,avgGS.Dis] = ...adamupdate(paramsDis, GradDis, ...avgG.Dis, avgGS.Dis, global_iter, ...settings.lrD, settings.beta1, settings.beta2);% Update Generator network parameters[paramsGen,avgG.Gen,avgGS.Gen] = ...adamupdate(paramsGen, GradGen, ...avgG.Gen, avgGS.Gen, global_iter, ...settings.lrG, settings.beta1, settings.beta2);if i==1 || rem(i,20)==0progressplot(paramsGen,stGen,settings);
%             if i==1 || (epoch>=0 && i==1)
%                 h = gcf;
%                 % Capture the plot as an image
%                 frame = getframe(h);
%                 im = frame2im(frame);
%                 [imind,cm] = rgb2ind(im,256);
%                 % Write to the GIF File
%                 if epoch == 0
%                   imwrite(imind,cm,'GANmnist.gif','gif', 'Loopcount',inf);
%                 else
%                   imwrite(imind,cm,'GANmnist.gif','gif','WriteMode','append');
%                 end
%             endendendelapsedTime = toc;disp("Epoch "+epoch+". Time taken for epoch = "+elapsedTime + "s")epoch = epoch+1;if epoch == settings.maxepochsout = true;end
end
%% Helper Functions
%% preprocess
function x = preprocess(x)
x = double(x)/255;
x = (x-.5)/.5;
x = reshape(x,28*28,[]);
end
%% extract data
function x = gatext(x)
x = gather(extractdata(x));
end
%% gpu dl array wrapper
function dlx = gpdl(x,labels)
dlx = gpuArray(dlarray(x,labels));
end
%% Weight initialization
function parameter = initializeGaussian(parameterSize,sigma)
if nargin < 2sigma = 0.05;
end
parameter = randn(parameterSize, 'single') .* sigma;
end
%% Generator
function [dly,st] = Generator(dlx,params,st)
% fully connected
%1
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,0.2);
% if isempty(st.BN1)
%     [dly,st.BN1.mu,st.BN1.sig] = batchnorm(dly,params.BNo1,params.BNs1);
% else
%     [dly,st.BN1.mu,st.BN1.sig] = batchnorm(dly,params.BNo1,...
%         params.BNs1,st.BN1.mu,st.BN1.sig);
% end
%2
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,0.2);
% if isempty(st.BN2)
%     [dly,st.BN2.mu,st.BN2.sig] = batchnorm(dly,params.BNo2,params.BNs2);
% else
%     [dly,st.BN2.mu,st.BN2.sig] = batchnorm(dly,params.BNo2,...
%         params.BNs2,st.BN2.mu,st.BN2.sig);
% end
%3
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,0.2);
% if isempty(st.BN3)
%     [dly,st.BN3.mu,st.BN3.sig] = batchnorm(dly,params.BNo3,params.BNs3);
% else
%     [dly,st.BN3.mu,st.BN3.sig] = batchnorm(dly,params.BNo3,...
%         params.BNs3,st.BN3.mu,st.BN3.sig);
% end
%4
dly = fullyconnect(dly,params.FCW4,params.FCb4);
% tanh
dly = tanh(dly);
end
%% Discriminator
function [dly,st] = Discriminator(dlx,params,st)
% fully connected
%1
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,0.2);
dly = dropout(dly);
% if isempty(st.BN1)
%     [dly,st.BN1.mu,st.BN1.sig] = batchnorm(dly,params.BNo1,params.BNs1);
% else
%     [dly,st.BN1.mu,st.BN1.sig] = batchnorm(dly,params.BNo1,...
%         params.BNs1,st.BN1.mu,st.BN1.sig);
% end
%2
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,0.2);
dly = dropout(dly);
% if isempty(st.BN2)
%     [dly,st.BN2.mu,st.BN2.sig] = batchnorm(dly,params.BNo2,params.BNs2);
% else
%     [dly,st.BN2.mu,st.BN2.sig] = batchnorm(dly,params.BNo2,...
%         params.BNs2,st.BN2.mu,st.BN2.sig);
% end
%3
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,0.2);
dly = dropout(dly);
%4
dly = fullyconnect(dly,params.FCW4,params.FCb4);
% sigmoid
dly = sigmoid(dly);
end
%% modelGradients
function [GradGen,GradDis,stGen,stDis]=modelGradients(x,z,paramsGen,...paramsDis,stGen,stDis)
[fake_images,stGen] = Generator(z,paramsGen,stGen);
d_output_real = Discriminator(x,paramsDis,stDis);
[d_output_fake,stDis] = Discriminator(fake_images,paramsDis,stDis);% Loss due to true or not
d_loss = -mean(.9*log(d_output_real+eps)+log(1-d_output_fake+eps));
g_loss = -mean(log(d_output_fake+eps));% For each network, calculate the gradients with respect to the loss.
GradGen = dlgradient(g_loss,paramsGen,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);
end
%% progressplot
function progressplot(paramsGen,stGen,settings)
r = 5; c = 5;
noise = gpdl(randn([settings.latent_dim,r*c]),'CB');
gen_imgs = Generator(noise,paramsGen,stGen);
gen_imgs = reshape(gen_imgs,28,28,[]);fig = gcf;
if ~isempty(fig.Children)delete(fig.Children)
endI = imtile(gatext(gen_imgs));
I = rescale(I);
imagesc(I)
title("Generated Images")
colormap graydrawnow;
end
%% dropout
function dly = dropout(dlx,p)
if nargin < 2p = .3;
end
n = p*10;
mask = randi([1,10],size(dlx));
mask(mask<=n)=0;
mask(mask>n)=1;
dly = dlx.*

代码展示

本代码采用MNIST手写数字数据集(训练集60000个,测试集10000个,本例中采用训练集数据),可实现数据集自动下载,最大的epoch次数为50,单个epoch中有1875个batch,batch_size为32,结果迭代如下:

第0次迭代

第10次迭代

第20次迭代

第30次迭代

第40次迭代

第50迭代

代码以及相关资料附件

ggg9

内容简介:

1、上述网址源代码压缩包(本文代码在...\github_repo\GAN下,GAN.m文件)

2、生成对抗网络的源文献

关注公众号“故障诊断与寿命预测工具箱”,每天进步一点点。

GAN之生成对抗网络(Matlab)相关推荐

  1. DL之GAN:生成对抗网络GAN的简介、应用、经典案例之详细攻略

    DL之GAN:生成对抗网络GAN的简介.应用.经典案例之详细攻略 目录 生成对抗网络GAN的简介 1.生成对抗网络的重要进展 1.1.1986年的RBM→2006年的DBN

  2. 什么是GAN(生成对抗网络)?

    GAN是一种深度学习模型,全称为生成对抗网络(Generative Adversarial Networks).它由两个神经网络组成:一个生成器网络和一个判别器网络. 什么是GAN(生成对抗网络)? ...

  3. GAN(生成对抗网络)在合成时间序列数据中的应用(第二部分——利用GAN生成时间序列数据)

    GAN(生成对抗网络)在合成时间序列数据中的应用(第二部分–TimeGAN 与合成金融输入) (本文基本是对Jasen 的<Machine Learning for Algorithmic Tr ...

  4. GAN(生成对抗网络) and CGAN(条件生成对抗网络)

    前言 GAN(生成对抗网络)是2014年由Goodfellow大佬提出的一种深度生成模型,适用于无监督学习.监督学习.但是GAN进行生成时是不可控的,所以后来又有人提出可控的CGAN(条件生成对抗网络 ...

  5. 【万物皆可 GAN】生成对抗网络生成手写数字 Part 1

    [万物皆可 GAN]生成对抗网络生成手写数字 Part 1 概述 GAN 网络结构 GAN 训练流程 模型详解 生成器 判别器 概述 GAN (Generative Adversarial Netwo ...

  6. GAN(1)-生成对抗网络的开山之作

    生成对抗网络的开山之作-GAN 1.有监督到无监督  ​ 图上方表示监督学习,我们将标记好的数据对传入网络,在标签的作用下监督训练.而很多时候我们提供不了训练数据,这时候神经网络就应该学会自己给数据打 ...

  7. GAN(生成对抗网络) 解释

    GAN (生成对抗网络)是近几年深度学习中一个比较热门的研究方向,它的变种有上千种. 1.什么是GAN GAN的英文全称是Generative Adversarial Network,中文名是生成对抗 ...

  8. pytorch实现GAN(生成对抗网络)生成二次元头像(附代码)

    目录 GAN基本概念 GAN算法流程 代码实现与讲解 1.准备数据集 代码实现 定义鉴别器 定义生成器 训练 补充 附完整代码 参考链接及书目 GAN基本概念 GAN, 全称Generative Ad ...

  9. 机器学习:Gan(生成对抗网络)

    版权声明:本文为CSDN博主「意念回复」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明. 原文链接:https://blog.csdn.net/weixin_3991 ...

最新文章

  1. Esfog_UnityShader教程_漫反射DiffuseReflection
  2. dubbo 常见错误
  3. cv::imread导致段错误_网络诊断举例LSO导致的网络性能问题
  4. 从一道面试题说起—js隐式转换踩坑合集
  5. linux之一些比较新但是常用的命令(expr ag tree cloc stat tmux axel)
  6. udf提权 udf.php,UDF提权
  7. 【*项目调研+论文阅读】SVM-BILSTM-CRF模型SVM-BILSTM-CRF模型 | day7
  8. 人生一知己,足以慰风尘吗?
  9. LeetCode 116/117 填充每个节点下一个右侧指针
  10. vs2013 CodeLens
  11. 第3章 Kafka API
  12. 线性代数之矩阵逆的求法
  13. 288388D-EnterCAT调试
  14. 拍视频到底用手机还是相机好?
  15. VC++ 6.0之MSComm控件安装、使用
  16. 2017南宁(重温经典)
  17. 【P1889 士兵站队】(洛谷)
  18. 新驾考指南---[C1-图文全程指导篇]
  19. 做人的六原则 40条心计 共勉
  20. 附代码 | OpenCV实现银行卡号识别,字符识别算法你知多少?

热门文章

  1. 【期末考试不挂科】计算机网络必刷题
  2. 【STM32】HAL库PWM实现呼吸灯实验
  3. android短彩信数据库设计源码解析(二)
  4. bootstrap datetimepicker
  5. 想学IT的必看!如何才能通过一线互联网公司面试?面试必问
  6. linux ping时丢包怎么解决办法,ping丢包故障处理方法
  7. 春季心动款外套搭配 彩色小西装令气质飙升 - 七丽女性网
  8. zzulioj1049:平方和与立方和
  9. 消息队列、RabbitMQ原理、消息队列保证幂等性,消息丢失,消息顺序性,以及处理消息队列消息积压问题
  10. ext4文件系统综述