[GANs] Attribute editing on a pretrained StyleGAN2 anime-portrait model
Core references
[Code] Pretrained Anime StyleGAN2 - convert to pytorch and editing images by encoder
[Theory] http://www.seeprettyface.com/research_notes.html#step6
Code for this post
- https://gitee.com/zengxy2020/csdn_stylegan2_edit_latent
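0 Overview
Starting from the stylegan2-pytorch anime pretrained model and combining it with tag labels and a classifier, we can edit certain features of the generated anime portraits.
Edited images:
The middle image (the 4th) is the original. Moving to the left closes the mouth; moving to the right opens it.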
1. Download the project and related files
stylegan2-pytorch (unofficial)
git clone https://github.com/viuts/stylegan2_pytorch.git
1.1 Download the anime-face model
2020-01-11-skylion-stylegan2-animeportraits-networksnapshot-024664.pkl
- Google Drive
- Baidu Netdisk (extraction code: 6b84)
- Other download links
1.2 Convert the TF model to PyTorch
cd your_path/stylegan2_pytorch
python run_convert_from_tf.py --input=2020-01-11-skylion-stylegan2-animeportraits-networksnapshot-024664.pkl --output checkpoint
1.3 Run inference to generate images
python run_generator.py generate_images --network=checkpoint/Gs.pth \
--seeds=1-5 --truncation_psi=1.0
The generated images are saved under stylegan2_pytorch/results.
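The --truncation_psi flag refers to the truncation trick from the StyleGAN papers: an intermediate latent is pulled toward the average latent by a factor psi, so psi=1.0 leaves it unchanged and smaller values trade variety for more typical images. The following is a minimal NumPy sketch of that formula only. The names truncate and w_avg are illustrative and the latents are random stand-ins; this is not the repository's implementation.

```python
import numpy as np


def truncate(w, w_avg, psi):
    """Pull the latent w toward the average latent w_avg by the factor psi."""
    return w_avg + psi * (w - w_avg)


rng = np.random.default_rng(0)
w_avg = np.zeros(512, dtype=np.float32)            # stand-in for the generator's average latent
w = rng.standard_normal(512).astype(np.float32)    # stand-in for one mapped latent

w_full = truncate(w, w_avg, psi=1.0)   # psi=1.0: no truncation, w is returned unchanged
w_half = truncate(w, w_avg, psi=0.5)   # psi=0.5: halfway between w_avg and w
```

With psi=1.0, as in the command above, the generator samples from the full latent distribution.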
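2. Obtain anime attribute labels
Key steps
- Web API
- Get the labels from a pretrained model
Illustration of the web version
2.1 Test inference with the local label classifier
python edit_convert_cntk_2_onnx.py
- The code is in the appendix.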
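2.2 Label the images
- Generate 4*5000 images and obtain the labels for each one.
- Save dlatents, labels, and tags.
- Here tags holds the name corresponding to each index of the output labels.
- Time taken: 2 hours 40 minutes.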
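The relationship between labels and tags can be illustrated with a small sketch: the classifier returns one score per label index, and the tag names at the indices whose score exceeds a threshold are the predicted tags (the appendix script uses 0.5 for its quick check). The tag subset and scores below are made up for illustration, and scores_to_tags is a hypothetical helper.

```python
import numpy as np

# Hypothetical subset of tag names; the real list is read from tags.txt.
tags = np.array(['black_hair', 'pink_hair', 'open_mouth', 'brown_hair'])
# Made-up classifier output: one score per tag index.
scores = np.array([0.91, 0.03, 0.62, 0.40], dtype=np.float32)


def scores_to_tags(scores, tags, threshold=0.5):
    """Return the names of all tags whose score is above the threshold."""
    return [str(tag) for tag in tags[scores > threshold]]


predicted = scores_to_tags(scores, tags)
print(predicted)
```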
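3 Train the latent directions to edit
3.0 Theoretical basis
- http://www.seeprettyface.com/research_notes.html#step6
Take the median of a label's values as the dividing line: values below it are relabeled 0 and values above it are relabeled 1. Then set up the objective w·x+b=y and solve this binary classification problem with logistic regression. The resulting w approximates the direction vector we need.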
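The procedure described above can be sketched in plain NumPy on synthetic data. Everything here is an assumption made for illustration: the latents are random 16-dimensional vectors rather than real dlatents, the "classifier score" is generated from a known direction so the result can be checked, and logistic regression is fitted with simple gradient descent.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 2000, 16
x = rng.standard_normal((n, d))             # stand-in latent codes
true_dir = np.zeros(d)
true_dir[0] = 1.0                           # the attribute only depends on axis 0
scores = 1 / (1 + np.exp(-(x @ true_dir)))  # stand-in classifier scores in (0, 1)

# Median split: scores above the median become 1, the rest 0.
y = (scores > np.median(scores)).astype(np.float64)

# Logistic regression for w.x + b = y, fitted by gradient descent.
w = np.zeros(d)
b = 0.0
lr = 0.5
for _ in range(300):
    p = 1 / (1 + np.exp(-(x @ w + b)))      # predicted probability
    w -= lr * (x.T @ (p - y) / n)           # gradient of the mean BCE loss w.r.t. w
    b -= lr * np.mean(p - y)                # gradient w.r.t. b

direction = w / np.linalg.norm(w)           # unit-length editing direction
cosine = float(direction @ true_dir)        # agreement with the known direction
```

On this toy problem the learned w aligns closely with the direction that generated the scores, which is the property the method relies on.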
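3.1 Example: editing hair color
The code, edit_find_latend_direction.py, is in the appendix.
- Training results
3.1.1 Edit result 1: pink hair
- The middle image is the original.
- To the left is interpolation toward non-pink hair; to the right, toward pink hair.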
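Each row of such a result comes from adding a multiple of the learned direction to the image's dlatents: negative coefficients move away from the attribute, positive ones toward it, and 0 reproduces the original. A minimal sketch under stated assumptions: the arrays are random stand-ins with the (16, 512) shape used in the appendix, shift_latent is a hypothetical helper, and the shift is restricted to the first 8 layers on the layer axis.

```python
import numpy as np

rng = np.random.default_rng(0)
dlatent = rng.standard_normal((16, 512)).astype(np.float32)    # stand-in for one image's dlatents
direction = rng.standard_normal((16, 512)).astype(np.float32)  # stand-in for a learned direction


def shift_latent(dlatent, direction, coeff, num_layers=8):
    """Move the first num_layers style layers along direction by coeff."""
    shifted = dlatent.copy()
    shifted[:num_layers] += coeff * direction[:num_layers]
    return shifted


# Symmetric coefficients: the middle entry (0) reproduces the original latent.
coeffs = [-10, -5, -2, 0, 2, 5, 10]
edited = [shift_latent(dlatent, direction, c) for c in coeffs]
```

Feeding each entry of edited to the generator would give one row of images with the original in the middle.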
3.1.2 Edit result 2: black hair
3.1.3 Edit result 3: brown hair
3.2 Possible problems
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Fix: move all tensors to the same device, either all to the CPU or all to the GPU.
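A minimal sketch of that fix, assuming PyTorch is installed. The tensor names are illustrative: one tensor lives on the training device, the other is created on the CPU by default, and moving it with .to() before combining them avoids the error. On a machine without a GPU both tensors are already on the CPU, so the snippet still runs.

```python
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

direction = torch.ones(16, 512, device=device)  # e.g. a tensor that lives on the GPU
latent = torch.zeros(16, 512)                   # created on the CPU by default

# Move the CPU tensor to wherever the other operand lives before combining them.
latent = latent.to(direction.device)
shifted = latent + 2.0 * direction
```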
one more thing
Editing other pretrained models (not open-sourced)
Short hair to long hair
Other references
- https://www.gwern.net/Faces#stylegan-2
Appendix
edit_convert_cntk_2_onnx.py
import cntk as C
import onnxruntime
import numpy as np


def convert():
    # Load the CNTK tagger and export it to ONNX
    model_path = 'checkpoint/danbooru-resnet_custom_v1-p3/model.cntk'
    cntk_model = C.load_model(model_path)
    cntk_model.save('checkpoint/model.onnx', format=C.ModelFormat.ONNX)
    return cntk_model


def test_cntk(cntk_model):
    ort_session = onnxruntime.InferenceSession("checkpoint/model.onnx")
    x = np.random.rand(1, 3, 299, 299).astype(np.float32)
    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: x}
    ort_outs = ort_session.run(None, ort_inputs)
    # compute the cntk output
    cntk_out = cntk_model.eval(x)
    # the two outputs should agree within tolerance
    np.testing.assert_allclose(np.array(cntk_out), ort_outs[0], rtol=1e-03, atol=1e-05)


if __name__ == '__main__':
    cntk_model = convert()
    test_cntk(cntk_model)
edit_generate-labeled-anime-data.py
import os
import onnxruntime
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

import stylegan2
from stylegan2 import utils

input_path = 'checkpoint'
tags_path = os.path.join(input_path, 'tags.txt')
model_path = os.path.join(input_path, 'model.onnx')
generator_path = os.path.join(input_path, 'Gs.pth')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_size = 4
seed = 0

# let's run one image to check that it works
C = onnxruntime.InferenceSession(model_path)

with open(tags_path, 'r') as tags_stream:
    tags = np.array([tag for tag in (tag.strip() for tag in tags_stream) if tag])

G = stylegan2.models.load(generator_path, map_location=device)
G.to(device)


def to_image_tensor(image_tensor, pixel_min=-1, pixel_max=1):
    # rescale from [pixel_min, pixel_max] to [0, 1]
    if pixel_min != 0 or pixel_max != 1:
        image_tensor = (image_tensor - pixel_min) / (pixel_max - pixel_min)
    return image_tensor.clamp(min=0, max=1)


torch.manual_seed(seed)
qlatents = torch.randn(1, G.latent_size).to(device=device, dtype=torch.float32)
generated = G(qlatents)
images = to_image_tensor(generated)
# 299 is the input size of the model
images = F.interpolate(images, size=(299, 299), mode='bilinear')
ort_inputs = {C.get_inputs()[0].name: images.detach().cpu().numpy()}
[predicted_labels] = C.run(None, ort_inputs)
# print out some tags
plt.imshow(images[0].detach().cpu().permute(1, 2, 0))
labels = [tags[i] for i, score in enumerate(predicted_labels[0]) if score > 0.5]
print(labels)

# reset seed
torch.manual_seed(seed)
iteration = 5000

progress = utils.ProgressWriter(iteration)
progress.write('Generating images...', step=False)

qlatents_data = torch.Tensor(0, G.latent_size).to(device=device, dtype=torch.float32)
dlatents_data = torch.Tensor(0, 16, G.latent_size).to(device=device, dtype=torch.float32)
labels_data = torch.Tensor(0, len(tags)).to(device=device, dtype=torch.float32)
for i in range(iteration):
    qlatents = torch.randn(batch_size, G.latent_size).to(device=device, dtype=torch.float32)
    with torch.no_grad():
        generated, dlatents = G(latents=qlatents, return_dlatents=True)
        # inplace to save memory
        generated = to_image_tensor(generated)
        # 299 is the input size of the model
        # resize the image to 299 * 299
        images = F.interpolate(generated, size=(299, 299), mode='bilinear')
        labels = []
        ## tagger does not take input as batch, need to feed one by one
        for image in images:
            ort_inputs = {C.get_inputs()[0].name: image.reshape(1, 3, 299, 299).detach().cpu().numpy()}
            [[predicted_labels]] = C.run(None, ort_inputs)
            labels.append(predicted_labels)
        # store the result
        labels_tensor = torch.Tensor(labels).to(device=device, dtype=torch.float32)
        qlatents_data = torch.cat((qlatents_data, qlatents))
        dlatents_data = torch.cat((dlatents_data, dlatents))
        labels_data = torch.cat((labels_data, labels_tensor))
        progress.step()

progress.write('Done!', step=False)
progress.close()

# saved under checkpoint/ because edit_find_latend_direction.py loads checkpoint/latents.pth
torch.save({
    'qlatents_data': qlatents_data.cpu(),
    'dlatents_data': dlatents_data.cpu(),
    'labels_data': labels_data.cpu(),
    'tags': tags
}, os.path.join(input_path, 'latents.pth'))
edit_find_latend_direction.py
import os

import torch
import matplotlib.pylab as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import stylegan2

input_path = 'checkpoint'
latents_path = os.path.join(input_path, 'latents.pth')
generator_path = os.path.join(input_path, 'Gs.pth')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# set according to your resources
batch_size = 16
seed = 0

state = torch.load(latents_path, map_location=device)

qlatents_data = state['qlatents_data']
dlatents_data = state['dlatents_data']
labels_data = state['labels_data']
tags = state['tags']

G = stylegan2.models.load(generator_path).to(device)

dlatents_data = dlatents_data.to(device=device, dtype=torch.float32)
labels_data = labels_data.to(device=device, dtype=torch.float32)
print("dlatents_data.size()", dlatents_data.size())
print("labels_data.size()", labels_data.size())

zipped = list(zip(dlatents_data, labels_data))

train_size = int(0.7 * len(zipped))
valid_size = int(len(zipped) * 0.2)
test_size = len(zipped) - train_size - valid_size

train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(zipped, [train_size, valid_size, test_size])

# num_workers=4 from the reference code raised an error here; adjust to your own setup
datasets = dict(
    train=torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0),
    valid=torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=0),
    test=test_dataset,
)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # kernel
        self.main = nn.Sequential(
            nn.Linear(in_features=16 * 512, out_features=1),
            nn.LeakyReLU(0.2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.main(x)


def train_coeff(tag, total=5):
    model = Net()
    model = model.to(device)
    # create your optimizer
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.BCELoss()
    [tag_index], = np.where(tags == tag)
    epoch = 0
    epoch_val_loss_min = 100
    while True:
        epoch += 1
        training_loss, valid_loss = 0.0, 0.0
        for phase in ['train', 'valid']:
            dataset = datasets[phase]
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode
            running_loss = 0.0
            running_corrects = 0
            for (dlatents, labels) in dataset:
                # in your training loop:
                optimizer.zero_grad()  # zero the gradient buffers
                with torch.set_grad_enabled(phase == 'train'):
                    inputs = dlatents.reshape(-1, 16 * 512)
                    inputs = inputs.to(device)
                    output = model(inputs)
                    targets = torch.Tensor(0, 1)
                    targets = targets.to(device)
                    for label in labels:
                        value = label[tag_index]
                        # value = 1.0 if value > 0.5 else 0.0
                        new_label = torch.Tensor([[value]])
                        new_label = new_label.to(device)
                        targets = torch.cat((targets, new_label))
                    loss = criterion(output, targets)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()  # Does the update
                # statistics
                running_loss += loss.item() * inputs.size(0)
            epoch_loss = running_loss / (len(dataset) * batch_size)
            print(f'Epoch:{epoch}/{total}, {phase} Loss: {epoch_loss:.4f}')
            # keep the weights with the best validation loss
            if phase == 'valid':
                if epoch_loss < epoch_val_loss_min:
                    epoch_val_loss_min = epoch_loss
                    # clone: state_dict() returns a view that later optimizer steps would overwrite
                    weight_val_min = model.state_dict()['main.0.weight'].clone()  # not bias
                    direction_path = f'checkpoint/directions_{tag}_val{epoch_loss}.pth'
                    # torch.save(weight_val_min, direction_path)
                    print(f" the best val is:{epoch_loss}")
        if epoch == total:
            break
    weight = weight_val_min
    return weight.detach().cpu().reshape(16, 512)


def generate_image(dlatents, pixel_min=-1, pixel_max=1):
    generated = G(dlatents=dlatents)
    if pixel_min != 0 or pixel_max != 1:
        generated = (generated - pixel_min) / (pixel_max - pixel_min)
    generated.clamp_(min=0, max=1)
    return generated.detach().cpu().reshape(3, 512, 512).permute(1, 2, 0)


def move_and_show(latent_vector, direction, coeffs):
    img_list = []
    for i, coeff in enumerate(coeffs):
        new_latent_vector = latent_vector.clone()
        # direction=direction.to(direction)
        # print((new_latent_vector),(latent_vector),(direction),coeff)
        # note: latent_vector has shape (1, 16, 512) here, so [:8] slices the batch axis
        # and the shift reaches all 16 layers; use [:, :8] to limit it to the first 8 layers
        new_latent_vector[:8] = (latent_vector + coeff * direction)[:8]
        img = generate_image(new_latent_vector)
        img_list.append(img)
    # plt.show()
    return img_list


# core
# coeffs are symmetric around 0 so that the middle image is the original
def move_and_show_samples(direction, direction_name, sample=3, coeffs=[-10, -5, -2, 0, 2, 5, 10]):
    fig, ax = plt.subplots(sample, 1, figsize=(50, 50), dpi=80)
    for i, (latents, labels) in enumerate(list(datasets['test'])[:sample]):
        inputs = latents.clone().reshape(1, 16, 512)
        direction = direction.to(device)
        img_list = move_and_show(inputs, direction, coeffs)
        ax[i].imshow(np.hstack(img_list))
    plt.suptitle(f'Edit: {direction_name}', size=16)
    [x.axis('off') for x in ax]  # hide the axes
    plt.tight_layout()  # let the images fill the figure
    save_folder = f'./edit/'
    os.makedirs(save_folder, exist_ok=True)
    plt.savefig(f"{save_folder}/{direction_name}combine_{sample}.png")


if __name__ == '__main__':
    result = {}
    # number of training epochs; 3 is usually enough
    flag_train = 1
    direction_save_path = f'checkpoint/direction.pth'
    if flag_train:
        picked_tags = ['black_hair', 'pink_hair', 'open_mouth', 'brown_hair']
        # filter out the real tags
        picked_tags = [tag for tag in picked_tags if tag in tags]
        print(picked_tags)
        for tag in picked_tags:
            print(f'training {tag}')
            result[tag] = train_coeff(tag, 3)
        # save all directions
        torch.save(result, direction_save_path)
    else:
        result = torch.load(direction_save_path, map_location=device)
    # visualize the edit results
    for name in result.keys():
        move_and_show_samples(result[name], name, sample=5)

    '''
    ## Let's pick some tags and train it!
    colors = ['aqua', 'black', 'blue', 'brown', 'green', 'grey', 'lavender', 'light_brown', 'multicolored', 'orange',
              'pink', 'purple', 'red', 'silver', 'white', 'yellow']
    switches = ['open', 'closed', 'covered']
    # generate composition of elements
    components = ['eyes', 'hair', 'mouth']
    picked_tags = []
    for component in components:
        picked_tags = picked_tags + [f'{color}_{component}' for color in colors]
        picked_tags = picked_tags + [f'{switch}_{component}' for switch in switches]
    # filter out the real tags
    picked_tags = [tag for tag in picked_tags if tag in tags]
    print(picked_tags)
    ## Train all these tags!
    for tag in picked_tags:
        print(f'training {tag}')
        result[tag] = train_coeff(tag, 3)
    # try some of them
    move_and_show_samples(result['open_mouth'])
    # play a bit more, training a character-specific encoder?
    charas = ['hakurei_reimu', 'kirisame_marisa']
    # Let's check out how many samples we got
    for chara in charas:
        [chara_index], = np.where(tags == chara)
        count = [x[chara_index] for x in labels_data if x[chara_index] > 0.5]
        print(f'{chara}: {len(count)}, {(len(count) / len(labels_data)) * 100}%')
        result[chara] = train_coeff(chara, 3)
    # too rare, probably won't work
    move_and_show_samples(result['hakurei_reimu'])
    move_and_show_samples(result['kirisame_marisa'])
    # store the result
    torch.save(result, 'checkpoint/directions.pth')
    '''