GAN之生成对抗网络(Matlab)
代码来源
代码全文
clear all; close all; clc;
%% Basic Generative Adversarial Network
%% Load Data
load('mnistAll.mat')
trainX = preprocess(mnist.train_images);
trainY = mnist.train_labels;
testX = preprocess(mnist.test_images);
testY = mnist.test_labels;
%% Settings
settings.latent_dim = 100;
settings.batch_size = 32; settings.image_size = [28,28,1];
settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;
settings.beta2 = 0.999; settings.maxepochs = 50;%% Initialization
%% Generator
paramsGen.FCW1 = dlarray(...initializeGaussian([256,settings.latent_dim],.02));
paramsGen.FCb1 = dlarray(zeros(256,1,'single'));
paramsGen.BNo1 = dlarray(zeros(256,1,'single'));
paramsGen.BNs1 = dlarray(ones(256,1,'single'));
paramsGen.FCW2 = dlarray(initializeGaussian([512,256]));
paramsGen.FCb2 = dlarray(zeros(512,1,'single'));
paramsGen.BNo2 = dlarray(zeros(512,1,'single'));
paramsGen.BNs2 = dlarray(ones(512,1,'single'));
paramsGen.FCW3 = dlarray(initializeGaussian([1024,512]));
paramsGen.FCb3 = dlarray(zeros(1024,1,'single'));
paramsGen.BNo3 = dlarray(zeros(1024,1,'single'));
paramsGen.BNs3 = dlarray(ones(1024,1,'single'));
paramsGen.FCW4 = dlarray(initializeGaussian(...[prod(settings.image_size),1024]));
paramsGen.FCb4 = dlarray(zeros(prod(settings.image_size)...,1,'single'));stGen.BN1 = []; stGen.BN2 = []; stGen.BN3 = [];%% Discriminator
paramsDis.FCW1 = dlarray(initializeGaussian([1024,...prod(settings.image_size)],.02));
paramsDis.FCb1 = dlarray(zeros(1024,1,'single'));
paramsDis.BNo1 = dlarray(zeros(1024,1,'single'));
paramsDis.BNs1 = dlarray(ones(1024,1,'single'));
paramsDis.FCW2 = dlarray(initializeGaussian([512,1024]));
paramsDis.FCb2 = dlarray(zeros(512,1,'single'));
paramsDis.BNo2 = dlarray(zeros(512,1,'single'));
paramsDis.BNs2 = dlarray(ones(512,1,'single'));
paramsDis.FCW3 = dlarray(initializeGaussian([256,512]));
paramsDis.FCb3 = dlarray(zeros(256,1,'single'));
paramsDis.FCW4 = dlarray(initializeGaussian([1,256]));
paramsDis.FCb4 = dlarray(zeros(1,1,'single'));stDis.BN1 = []; stDis.BN2 = [];% average Gradient and average Gradient squared holders
avgG.Dis = []; avgGS.Dis = []; avgG.Gen = []; avgGS.Gen = [];
%% Train
numIterations = floor(size(trainX,2)/settings.batch_size);
out = false; epoch = 0; global_iter = 0;
while ~outtic; trainXshuffle = trainX(:,randperm(size(trainX,2)));fprintf('Epoch %d\n',epoch) for i=1:numIterationsglobal_iter = global_iter+1;noise = gpdl(randn([settings.latent_dim,...settings.batch_size]),'CB');idx = (i-1)*settings.batch_size+1:i*settings.batch_size;XBatch=gpdl(single(trainXshuffle(:,idx)),'CB');[GradGen,GradDis,stGen,stDis] = ...dlfeval(@modelGradients,XBatch,noise,...paramsGen,paramsDis,stGen,stDis);% Update Discriminator network parameters[paramsDis,avgG.Dis,avgGS.Dis] = ...adamupdate(paramsDis, GradDis, ...avgG.Dis, avgGS.Dis, global_iter, ...settings.lrD, settings.beta1, settings.beta2);% Update Generator network parameters[paramsGen,avgG.Gen,avgGS.Gen] = ...adamupdate(paramsGen, GradGen, ...avgG.Gen, avgGS.Gen, global_iter, ...settings.lrG, settings.beta1, settings.beta2);if i==1 || rem(i,20)==0progressplot(paramsGen,stGen,settings);
% if i==1 || (epoch>=0 && i==1)
% h = gcf;
% % Capture the plot as an image
% frame = getframe(h);
% im = frame2im(frame);
% [imind,cm] = rgb2ind(im,256);
% % Write to the GIF File
% if epoch == 0
% imwrite(imind,cm,'GANmnist.gif','gif', 'Loopcount',inf);
% else
% imwrite(imind,cm,'GANmnist.gif','gif','WriteMode','append');
% end
% endendendelapsedTime = toc;disp("Epoch "+epoch+". Time taken for epoch = "+elapsedTime + "s")epoch = epoch+1;if epoch == settings.maxepochsout = true;end
end
%% Helper Functions
%% preprocess
function x = preprocess(x)
x = double(x)/255;
x = (x-.5)/.5;
x = reshape(x,28*28,[]);
end
%% extract data
function x = gatext(x)
x = gather(extractdata(x));
end
%% gpu dl array wrapper
function dlx = gpdl(x,labels)
dlx = gpuArray(dlarray(x,labels));
end
%% Weight initialization
function parameter = initializeGaussian(parameterSize,sigma)
if nargin < 2sigma = 0.05;
end
parameter = randn(parameterSize, 'single') .* sigma;
end
%% Generator
function [dly,st] = Generator(dlx,params,st)
% fully connected
%1
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,0.2);
% if isempty(st.BN1)
% [dly,st.BN1.mu,st.BN1.sig] = batchnorm(dly,params.BNo1,params.BNs1);
% else
% [dly,st.BN1.mu,st.BN1.sig] = batchnorm(dly,params.BNo1,...
% params.BNs1,st.BN1.mu,st.BN1.sig);
% end
%2
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,0.2);
% if isempty(st.BN2)
% [dly,st.BN2.mu,st.BN2.sig] = batchnorm(dly,params.BNo2,params.BNs2);
% else
% [dly,st.BN2.mu,st.BN2.sig] = batchnorm(dly,params.BNo2,...
% params.BNs2,st.BN2.mu,st.BN2.sig);
% end
%3
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,0.2);
% if isempty(st.BN3)
% [dly,st.BN3.mu,st.BN3.sig] = batchnorm(dly,params.BNo3,params.BNs3);
% else
% [dly,st.BN3.mu,st.BN3.sig] = batchnorm(dly,params.BNo3,...
% params.BNs3,st.BN3.mu,st.BN3.sig);
% end
%4
dly = fullyconnect(dly,params.FCW4,params.FCb4);
% tanh
dly = tanh(dly);
end
%% Discriminator
function [dly,st] = Discriminator(dlx,params,st)
% fully connected
%1
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,0.2);
dly = dropout(dly);
% if isempty(st.BN1)
% [dly,st.BN1.mu,st.BN1.sig] = batchnorm(dly,params.BNo1,params.BNs1);
% else
% [dly,st.BN1.mu,st.BN1.sig] = batchnorm(dly,params.BNo1,...
% params.BNs1,st.BN1.mu,st.BN1.sig);
% end
%2
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,0.2);
dly = dropout(dly);
% if isempty(st.BN2)
% [dly,st.BN2.mu,st.BN2.sig] = batchnorm(dly,params.BNo2,params.BNs2);
% else
% [dly,st.BN2.mu,st.BN2.sig] = batchnorm(dly,params.BNo2,...
% params.BNs2,st.BN2.mu,st.BN2.sig);
% end
%3
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,0.2);
dly = dropout(dly);
%4
dly = fullyconnect(dly,params.FCW4,params.FCb4);
% sigmoid
dly = sigmoid(dly);
end
%% modelGradients
function [GradGen,GradDis,stGen,stDis]=modelGradients(x,z,paramsGen,...paramsDis,stGen,stDis)
[fake_images,stGen] = Generator(z,paramsGen,stGen);
d_output_real = Discriminator(x,paramsDis,stDis);
[d_output_fake,stDis] = Discriminator(fake_images,paramsDis,stDis);% Loss due to true or not
d_loss = -mean(.9*log(d_output_real+eps)+log(1-d_output_fake+eps));
g_loss = -mean(log(d_output_fake+eps));% For each network, calculate the gradients with respect to the loss.
GradGen = dlgradient(g_loss,paramsGen,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);
end
%% progressplot
function progressplot(paramsGen,stGen,settings)
r = 5; c = 5;
noise = gpdl(randn([settings.latent_dim,r*c]),'CB');
gen_imgs = Generator(noise,paramsGen,stGen);
gen_imgs = reshape(gen_imgs,28,28,[]);fig = gcf;
if ~isempty(fig.Children)delete(fig.Children)
endI = imtile(gatext(gen_imgs));
I = rescale(I);
imagesc(I)
title("Generated Images")
colormap graydrawnow;
end
%% dropout
function dly = dropout(dlx,p)
if nargin < 2p = .3;
end
n = p*10;
mask = randi([1,10],size(dlx));
mask(mask<=n)=0;
mask(mask>n)=1;
dly = dlx.*
代码展示
本代码采用MNIST手写数字数据集(训练集60000个,测试集10000个,本例中采用训练集数据),可实现数据集自动下载,最大的epoch次数为50,单个epoch中有1875个batch,batch_size为32,结果迭代如下:
第0次迭代
第10次迭代
第20次迭代
第30次迭代
第40次迭代
第50迭代
代码以及相关资料附件
ggg9
内容简介:
1、上述网址源代码压缩包(本文代码在...\github_repo\GAN下,GAN.m文件)
2、生成对抗网络的源文献
关注公众号“故障诊断与寿命预测工具箱”,每天进步一点点。
GAN之生成对抗网络(Matlab)相关推荐
- DL之GAN:生成对抗网络GAN的简介、应用、经典案例之详细攻略
DL之GAN:生成对抗网络GAN的简介.应用.经典案例之详细攻略 目录 生成对抗网络GAN的简介 1.生成对抗网络的重要进展 1.1.1986年的RBM→2006年的DBN
- 什么是GAN(生成对抗网络)?
GAN是一种深度学习模型,全称为生成对抗网络(Generative Adversarial Networks).它由两个神经网络组成:一个生成器网络和一个判别器网络. 什么是GAN(生成对抗网络)? ...
- GAN(生成对抗网络)在合成时间序列数据中的应用(第二部分——利用GAN生成时间序列数据)
GAN(生成对抗网络)在合成时间序列数据中的应用(第二部分–TimeGAN 与合成金融输入) (本文基本是对Jasen 的<Machine Learning for Algorithmic Tr ...
- GAN(生成对抗网络) and CGAN(条件生成对抗网络)
前言 GAN(生成对抗网络)是2014年由Goodfellow大佬提出的一种深度生成模型,适用于无监督学习.监督学习.但是GAN进行生成时是不可控的,所以后来又有人提出可控的CGAN(条件生成对抗网络 ...
- 【万物皆可 GAN】生成对抗网络生成手写数字 Part 1
[万物皆可 GAN]生成对抗网络生成手写数字 Part 1 概述 GAN 网络结构 GAN 训练流程 模型详解 生成器 判别器 概述 GAN (Generative Adversarial Netwo ...
- GAN(1)-生成对抗网络的开山之作
生成对抗网络的开山之作-GAN 1.有监督到无监督 图上方表示监督学习,我们将标记好的数据对传入网络,在标签的作用下监督训练.而很多时候我们提供不了训练数据,这时候神经网络就应该学会自己给数据打 ...
- GAN(生成对抗网络) 解释
GAN (生成对抗网络)是近几年深度学习中一个比较热门的研究方向,它的变种有上千种. 1.什么是GAN GAN的英文全称是Generative Adversarial Network,中文名是生成对抗 ...
- pytorch实现GAN(生成对抗网络)生成二次元头像(附代码)
目录 GAN基本概念 GAN算法流程 代码实现与讲解 1.准备数据集 代码实现 定义鉴别器 定义生成器 训练 补充 附完整代码 参考链接及书目 GAN基本概念 GAN, 全称Generative Ad ...
- 机器学习:Gan(生成对抗网络)
版权声明:本文为CSDN博主「意念回复」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明. 原文链接:https://blog.csdn.net/weixin_3991 ...
最新文章
- Esfog_UnityShader教程_漫反射DiffuseReflection
- dubbo 常见错误
- cv::imread导致段错误_网络诊断举例LSO导致的网络性能问题
- 从一道面试题说起—js隐式转换踩坑合集
- linux之一些比较新但是常用的命令(expr ag tree cloc stat tmux axel)
- udf提权 udf.php,UDF提权
- 【*项目调研+论文阅读】SVM-BILSTM-CRF模型SVM-BILSTM-CRF模型 | day7
- 人生一知己,足以慰风尘吗?
- LeetCode 116/117 填充每个节点下一个右侧指针
- vs2013 CodeLens
- 第3章 Kafka API
- 线性代数之矩阵逆的求法
- 288388D-EnterCAT调试
- 拍视频到底用手机还是相机好?
- VC++ 6.0之MSComm控件安装、使用
- 2017南宁(重温经典)
- 【P1889 士兵站队】(洛谷)
- 新驾考指南---[C1-图文全程指导篇]
- 做人的六原则 40条心计 共勉
- 附代码 | OpenCV实现银行卡号识别,字符识别算法你知多少?