Core references

[Code] Pretrained Anime StyleGAN2 — convert to pytorch and editing images by encoder
[Theory] http://www.seeprettyface.com/research_notes.html#step6

Code repository for this post

  • https://gitee.com/zengxy2020/csdn_stylegan2_edit_latent

0 Overview

Using the stylegan2-pytorch anime pretrained model together with labels and a classifier, certain attributes of the generated anime portraits can be edited.
Edited images: the middle picture (the 4th one) is the original; moving left generates a closed mouth, moving right an open mouth.

1. Download the project

stylegan2-pytorch (unofficial)

git clone https://github.com/viuts/stylegan2_pytorch.git

1.1 Download the anime-face model

2020-01-11-skylion-stylegan2-animeportraits-networksnapshot-024664.pkl

  • Google Drive
  • Baidu Netdisk - extraction code: 6b84
  • other download links

1.2 Convert the TensorFlow model to PyTorch

cd your_path/stylegan2_pytorch
python run_convert_from_tf.py --input=2020-01-11-skylion-stylegan2-animeportraits-networksnapshot-024664.pkl --output checkpoint

1.3 Generate images (inference)

python run_generator.py generate_images --network=checkpoint/Gs.pth \
--seeds=1-5 --truncation_psi=1.0
  • Generated images are saved to stylegan2_pytorch/results

2. Obtain anime attribute labels

  • Core steps:
  • web API
  • obtain labels with the pretrained tagger model

Example output of the web version

2.1 Test local tag-classifier inference

  • python edit_convert_cntk_2_onnx.py
  • Code is given in the appendix

2.2 Tagging

  • Generate 4 × 5000 images and obtain the labels for each one (a condensed sketch of the tagging step follows this list)
  • Save the dlatents, labels, and tags
    • where tags holds the tag names corresponding to the label indices
  • Time cost
    • about 2 hours 40 minutes
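A condensed sketch of the per-image tagging call, pulled from the appendix script edit_generate-labeled-anime-data.py (the 0.5 threshold and the checkpoint/model.onnx path follow that script; the helper name tag_image is illustrative):

import onnxruntime
import numpy as np

session = onnxruntime.InferenceSession('checkpoint/model.onnx')

def tag_image(image_chw, tags, threshold=0.5):
    # image_chw: a single 3x299x299 float32 array with values in [0, 1]
    inputs = {session.get_inputs()[0].name: image_chw.reshape(1, 3, 299, 299).astype(np.float32)}
    [scores] = session.run(None, inputs)
    # keep the tags whose predicted score exceeds the threshold
    return [tags[i] for i, s in enumerate(scores[0]) if s > threshold]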

3 Train the latent-code directions used for editing

3.0 Theoretical basis

  • http://www.seeprettyface.com/research_notes.html#step6

As shown in the figure below, take the median of a tag's scores as the dividing line: scores below it are set to 0 and scores above it are set to 1. Then construct the objective w·x + b = y and solve this binary classification problem with logistic regression; the learned w can serve as an approximation of the direction vector we need.
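A minimal sketch of this idea, assuming scikit-learn is available (the appendix script instead trains an equivalent one-layer PyTorch model; dlatents and scores are illustrative names for the flattened w codes and the tagger scores of one attribute):

import numpy as np
from sklearn.linear_model import LogisticRegression

def find_direction(dlatents, scores):
    # dlatents: (N, 16*512) flattened w codes, scores: (N,) classifier scores for one tag
    y = (scores > np.median(scores)).astype(np.int64)  # binarize at the median
    clf = LogisticRegression(max_iter=1000).fit(dlatents, y)
    # the learned weight vector w approximates the editing direction
    direction = clf.coef_.reshape(16, 512)
    return direction / np.linalg.norm(direction)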

3.1 Example: editing hair color

Code: edit_find_latend_direction.py (see the appendix)

  • Training results

3.1.1 Edit result 1: pink hair

  • (the middle image is the original)
  • the left side interpolates toward non-pink, the right side toward pink hair (a sketch of the interpolation follows)
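A minimal sketch of the interpolation, following move_and_show in the appendix (torch tensors of shape (16, 512) are assumed; only the first 8 style layers are shifted, and the coefficient values are illustrative):

def edit_dlatent(dlatent, direction, coeff):
    # dlatent: (16, 512) w+ code, direction: (16, 512) learned direction
    edited = dlatent.clone()
    edited[:8] = (dlatent + coeff * direction)[:8]
    return edited

# negative coefficients move away from the attribute, positive ones toward it
frames = [edit_dlatent(w, direction, c) for c in (-10, -5, -2, 0, 2, 5, 10)]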

3.1.2 Edit result 2: black hair

3.1.3 Edit result 3: brown hair

3.2 Possible issues

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Fix: move all tensors (and the model) onto a single device, either all on the CPU or all on the GPU, for example as shown below.
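A generic PyTorch fix (variable names are illustrative): pick one device and move the model, the latent codes, and the direction tensor onto it before the forward pass.

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
dlatents = dlatents.to(device)
direction = direction.to(device)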

one more thing

Editing other pretrained models (not open-sourced)

Short hair to long hair

Other references

  • https://www.gwern.net/Faces#stylegan-2

Appendix

edit_convert_cntk_2_onnx.py


import cntk as C
import onnxruntime
import numpy as np


def convert():
    model_path = 'checkpoint/danbooru-resnet_custom_v1-p3/model.cntk'
    cntk_model = C.load_model(model_path)
    cntk_model.save('checkpoint/model.onnx', format=C.ModelFormat.ONNX)
    return cntk_model


def test_cntk(cntk_model):
    ort_session = onnxruntime.InferenceSession("checkpoint/model.onnx")
    x = np.random.rand(1, 3, 299, 299).astype(np.float32)
    # compute the ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: x}
    ort_outs = ort_session.run(None, ort_inputs)
    # compute the CNTK output
    cntk_out = cntk_model.eval(x)
    # the two outputs should agree within tolerance
    np.testing.assert_allclose(np.array(cntk_out), ort_outs[0], rtol=1e-03, atol=1e-05)


if __name__ == '__main__':
    cntk_model = convert()
    test_cntk(cntk_model)

edit_generate-labeled-anime-data.py

import os
import onnxruntime
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

import stylegan2
from stylegan2 import utils

input_path = 'checkpoint'
tags_path = os.path.join(input_path, 'tags.txt')
model_path = os.path.join(input_path, 'model.onnx')
generator_path = os.path.join(input_path, 'Gs.pth')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_size = 4
seed = 0

# let's run one image to check that it works
C = onnxruntime.InferenceSession(model_path)
with open(tags_path, 'r') as tags_stream:
    tags = np.array([tag for tag in (tag.strip() for tag in tags_stream) if tag])

G = stylegan2.models.load(generator_path, map_location=device)
G.to(device)


def to_image_tensor(image_tensor, pixel_min=-1, pixel_max=1):
    if pixel_min != 0 or pixel_max != 1:
        image_tensor = (image_tensor - pixel_min) / (pixel_max - pixel_min)
    return image_tensor.clamp(min=0, max=1)


torch.manual_seed(seed)
qlatents = torch.randn(1, G.latent_size).to(device=device, dtype=torch.float32)
generated = G(qlatents)
images = to_image_tensor(generated)
# 299 is the input size of the tagger model
images = F.interpolate(images, size=(299, 299), mode='bilinear')
ort_inputs = {C.get_inputs()[0].name: images.detach().cpu().numpy()}
[predicted_labels] = C.run(None, ort_inputs)
# print out some tags
plt.imshow(images[0].detach().cpu().permute(1, 2, 0))
labels = [tags[i] for i, score in enumerate(predicted_labels[0]) if score > 0.5]
print(labels)

# reset seed
torch.manual_seed(seed)
iteration = 5000

progress = utils.ProgressWriter(iteration)
progress.write('Generating images...', step=False)

qlatents_data = torch.Tensor(0, G.latent_size).to(device=device, dtype=torch.float32)
dlatents_data = torch.Tensor(0, 16, G.latent_size).to(device=device, dtype=torch.float32)
labels_data = torch.Tensor(0, len(tags)).to(device=device, dtype=torch.float32)

for i in range(iteration):
    qlatents = torch.randn(batch_size, G.latent_size).to(device=device, dtype=torch.float32)
    with torch.no_grad():
        generated, dlatents = G(latents=qlatents, return_dlatents=True)
    # in-place to save memory
    generated = to_image_tensor(generated)
    # resize the images to 299 x 299, the input size of the tagger model
    images = F.interpolate(generated, size=(299, 299), mode='bilinear')
    labels = []
    # the tagger does not take batched input, so feed images one by one
    for image in images:
        ort_inputs = {C.get_inputs()[0].name: image.reshape(1, 3, 299, 299).detach().cpu().numpy()}
        [[predicted_labels]] = C.run(None, ort_inputs)
        labels.append(predicted_labels)
    # store the result
    labels_tensor = torch.Tensor(labels).to(device=device, dtype=torch.float32)
    qlatents_data = torch.cat((qlatents_data, qlatents))
    dlatents_data = torch.cat((dlatents_data, dlatents))
    labels_data = torch.cat((labels_data, labels_tensor))
    progress.step()

progress.write('Done!', step=False)
progress.close()

torch.save({
    'qlatents_data': qlatents_data.cpu(),
    'dlatents_data': dlatents_data.cpu(),
    'labels_data': labels_data.cpu(),
    'tags': tags,
}, 'latents.pth')

edit_find_latend_direction.py

import torch
import matplotlib.pylab as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os

import stylegan2

input_path = 'checkpoint'
latents_path = os.path.join(input_path, 'latents.pth')
generator_path = os.path.join(input_path, 'Gs.pth')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # set according to your hardware
batch_size = 16
seed = 0

state = torch.load(latents_path, map_location=device)
qlatents_data = state['qlatents_data']
dlatents_data = state['dlatents_data']
labels_data = state['labels_data']
tags = state['tags']

G = stylegan2.models.load(generator_path).to(device)

dlatents_data = dlatents_data.to(device=device, dtype=torch.float32)
labels_data = labels_data.to(device=device, dtype=torch.float32)
print("dlatents_data.size()", dlatents_data.size())
print("labels_data.size()", labels_data.size())

zipped = list(zip(dlatents_data, labels_data))
train_size = int(0.7 * len(zipped))
valid_size = int(len(zipped) * 0.2)
test_size = len(zipped) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(zipped, [train_size, valid_size, test_size])

# the reference code uses num_workers=4, which may raise errors; adjust to your setup
datasets = dict(
    train=torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0),
    valid=torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=0),
    test=test_dataset,
)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # a single linear layer over the flattened (16, 512) dlatent
        self.main = nn.Sequential(
            nn.Linear(in_features=16 * 512, out_features=1),
            nn.LeakyReLU(0.2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.main(x)


def train_coeff(tag, total=5):
    model = Net()
    model = model.to(device)
    # create the optimizer
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.BCELoss()
    [tag_index], = np.where(tags == tag)
    epoch = 0
    epoch_val_loss_min = 100
    while True:
        epoch += 1
        for phase in ['train', 'valid']:
            dataset = datasets[phase]
            if phase == 'train':
                model.train()  # set model to training mode
            else:
                model.eval()  # set model to evaluation mode
            running_loss = 0.0
            for (dlatents, labels) in dataset:
                optimizer.zero_grad()  # zero the gradient buffers
                with torch.set_grad_enabled(phase == 'train'):
                    inputs = dlatents.reshape(-1, 16 * 512)
                    inputs = inputs.to(device)
                    output = model(inputs)
                    targets = torch.Tensor(0, 1)
                    targets = targets.to(device)
                    for label in labels:
                        value = label[tag_index]
                        # value = 1.0 if value > 0.5 else 0.0
                        new_label = torch.Tensor([[value]])
                        new_label = new_label.to(device)
                        targets = torch.cat((targets, new_label))
                    loss = criterion(output, targets)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()  # does the update
                # statistics
                running_loss += loss.item() * inputs.size(0)
            epoch_loss = running_loss / (len(dataset) * batch_size)
            print(f'Epoch:{epoch}/{total}, {phase} Loss: {epoch_loss:.4f}')
            # keep the best weight according to the validation loss
            if phase == 'valid':
                if epoch_loss < epoch_val_loss_min:
                    epoch_val_loss_min = epoch_loss
                    weight_val_min = model.state_dict()['main.0.weight']  # not the bias
                    direction_path = f'checkpoint/directions_{tag}_val{epoch_loss}.pth'
                    # torch.save(weight_val_min, direction_path)
                    print(f" the best val is:{epoch_loss}")
        if epoch == total:
            break
    weight = weight_val_min
    return weight.detach().cpu().reshape(16, 512)


def generate_image(dlatents, pixel_min=-1, pixel_max=1):
    generated = G(dlatents=dlatents)
    if pixel_min != 0 or pixel_max != 1:
        generated = (generated - pixel_min) / (pixel_max - pixel_min)
    generated.clamp_(min=0, max=1)
    return generated.detach().cpu().reshape(3, 512, 512).permute(1, 2, 0)


def move_and_show(latent_vector, direction, coeffs):
    img_list = []
    for i, coeff in enumerate(coeffs):
        new_latent_vector = latent_vector.clone()
        new_latent_vector[:8] = (latent_vector + coeff * direction)[:8]
        img = generate_image(new_latent_vector)
        img_list.append(img)
    return img_list


# core
def move_and_show_samples(direction, direction_name, sample=3, coeffs=[-10, -5, -2, 0, 2, 5, 10]):
    fig, ax = plt.subplots(sample, 1, figsize=(50, 50), dpi=80)
    for i, (latents, labels) in enumerate(list(datasets['test'])[:sample]):
        inputs = latents.clone().reshape(1, 16, 512)
        direction = direction.to(device)
        img_list = move_and_show(inputs, direction, coeffs)
        ax[i].imshow(np.hstack(img_list))
    plt.suptitle(f'Edit: {direction_name}', size=16)
    [x.axis('off') for x in ax]  # hide the axes
    plt.tight_layout()  # let the panels fill the figure
    save_folder = './edit/'
    os.makedirs(save_folder, exist_ok=True)
    plt.savefig(f"{save_folder}/{direction_name}combine_{sample}.png")


if __name__ == '__main__':
    result = {}
    # training: 3 epochs are usually enough
    flag_train = 1
    direction_save_path = 'checkpoint/direction.pth'
    if flag_train:
        picked_tags = ['black_hair', 'pink_hair', 'open_mouth', 'brown_hair']
        # keep only the tags that actually exist
        picked_tags = [tag for tag in picked_tags if tag in tags]
        print(picked_tags)
        for tag in picked_tags:
            print(f'training {tag}')
            result[tag] = train_coeff(tag, 3)
        # save the directions for all trained tags
        torch.save(result, direction_save_path)
    else:
        result = torch.load(direction_save_path, map_location=device)
    # visualize the edit results
    for name in result.keys():
        move_and_show_samples(result[name], name, sample=5)

'''
## Let's pick some tags and train them!
colors = ['aqua', 'black', 'blue', 'brown', 'green', 'grey', 'lavender', 'light_brown', 'multicolored',
          'orange', 'pink', 'purple', 'red', 'silver', 'white', 'yellow']
switches = ['open', 'closed', 'covered']
# generate compositions of elements
components = ['eyes', 'hair', 'mouth']
picked_tags = []
for component in components:
    picked_tags = picked_tags + [f'{color}_{component}' for color in colors]
    picked_tags = picked_tags + [f'{switch}_{component}' for switch in switches]
# filter out the real tags
picked_tags = [tag for tag in picked_tags if tag in tags]
print(picked_tags)

## Train all these tags!
for tag in picked_tags:
    print(f'training {tag}')
    result[tag] = train_coeff(tag, 3)

# try some of them
move_and_show_samples(result['open_mouth'])

# play a bit more: train character-specific directions?
charas = ['hakurei_reimu', 'kirisame_marisa']
# Let's check how many samples we got
for chara in charas:
    [chara_index], = np.where(tags == chara)
    count = [x[chara_index] for x in labels_data if x[chara_index] > 0.5]
    print(f'{chara}: {len(count)}, {(len(count) / len(labels_data)) * 100}%')
    result[chara] = train_coeff(chara, 3)

# too rare, probably won't work
move_and_show_samples(result['hakurei_reimu'])
move_and_show_samples(result['kirisame_marisa'])

# store the result
torch.save(result, 'checkpoint/directions.pth')
'''
