在完成VITS论文学习后,对github上的官方仓库进行学习,帮助理解算法实现过程中的一些细节;仓库代码基于pytorch实现,链接为https://github.com/jaywalnut310/vits。论文和代码中都针对单speaker的数据集LJSpeech和多speaker的数据集VCTK进行了训练,本笔记主要针对多speaker设置下的训练代码进行注释解析,主要涉及仓库项目中的train_ms.py文件。

train_ms.py

VITS训练时,使用了混合精度训练,并且设置了对抗训练模式;其中判别器使用了多周期判别器,由多个子判别器组成,并且生成过程损失中还加上了feature_map损失。训练过程中,不是对完整的音频文件进行训练,而是提取一部分音频数据进行训练,进而在计算损失时,也要从ground truth中提取对应部分的数值进行计算。具体的训练代码及注释如下:

import os
import json
import argparse
import itertools
import math
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScalerimport commons
import utils
from data_utils import (TextAudioSpeakerLoader,TextAudioSpeakerCollate,DistributedBucketSampler
)
from models import (SynthesizerTrn,MultiPeriodDiscriminator,
)
from losses import (generator_loss,discriminator_loss,feature_loss,kl_loss
)
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from text.symbols import symbolstorch.backends.cudnn.benchmark = True
global_step = 0def main():"""Assume Single Node Multi GPUs Training Only;只考虑单机多卡训练"""assert torch.cuda.is_available(), "CPU training is not allowed."n_gpus = torch.cuda.device_count()os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '80000'hps = utils.get_hparams()  # 获取参数超参数mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))# run函数中是实际训练代码
def run(rank, n_gpus, hps):global global_stepif rank == 0:logger = utils.get_logger(hps.model_dir)logger.info(hps)utils.check_git_hash(hps.model_dir)writer = SummaryWriter(log_dir=hps.model_dir)writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank)torch.manual_seed(hps.train.seed)torch.cuda.set_device(rank)train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)  # 加载数据集# 分布式的基于桶的samplertrain_sampler = DistributedBucketSampler(train_dataset,hps.train.batch_size,[32, 300, 400, 500, 600, 700, 800, 900, 1000],  # 桶排序的边界num_replicas=n_gpus,rank=rank,shuffle=True)collate_fn = TextAudioSpeakerCollate()# 构建训练数据train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True,collate_fn=collate_fn, batch_sampler=train_sampler)if rank == 0:  # 在主机上进行验证,即此处是在主机上加载验证数据集eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False,batch_size=hps.train.batch_size, pin_memory=True,drop_last=False, collate_fn=collate_fn)# 生成器,表示文本到音频的整个模型net_g = SynthesizerTrn(len(symbols),hps.data.filter_length // 2 + 1,hps.train.segment_size // hps.data.hop_length,n_speakers=hps.data.n_speakers,**hps.model).cuda(rank)# 多周期的判别器net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)# 生成器的优化器optim_g = torch.optim.AdamW(net_g.parameters(),hps.train.learning_rate,betas=hps.train.betas,eps=hps.train.eps)# 判别器的优化器optim_d = torch.optim.AdamW(net_d.parameters(),hps.train.learning_rate,betas=hps.train.betas,eps=hps.train.eps)# 多卡分布式训练,使用DDP把生成器和判别器包裹起来net_g = DDP(net_g, device_ids=[rank])net_d = DDP(net_d, device_ids=[rank])try:  # 尝试加载可能存在的通过训练已经保存的模型参数_, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,optim_g)_, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,optim_d)global_step = (epoch_str - 1) * len(train_loader)except:epoch_str = 1global_step = 0# 定义生成器和判别器的学习率schedulescheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)scaler = GradScaler(enabled=hps.train.fp16_run)  # 混合精度训练for epoch in range(epoch_str, hps.train.epochs + 1):if rank == 0:  # 如果为主机,除了参入正常训练参数,还需要传入验证数据集、logger等其他参数train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,[train_loader, eval_loader], logger, [writer, writer_eval])else:train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,[train_loader, None], None, None)# 更新学习率scheduler_g.step()scheduler_d.step()# 训练和验证函数
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):net_g, net_d = nets  # 生成器和判别器optim_g, optim_d = optimsscheduler_g, scheduler_d = schedulerstrain_loader, eval_loader = loadersif writers is not None:writer, writer_eval = writerstrain_loader.batch_sampler.set_epoch(epoch)  # 设置train_loader中桶排序的随机种子,随机种子是每次的epoch,用于打乱数据,但也可以复现global global_stepnet_g.train()net_d.train()for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers) in enumerate(train_loader):x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True)spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)speakers = speakers.cuda(rank, non_blocking=True)with autocast(enabled=hps.train.fp16_run):  # 模型计算部分进行半精度计算# 对整个音频序列采样进行训练,不是把整个音频序列送入进行训练,降低训练所需资源,ids_slice就对应采样后频谱的id# y_hat是预测的音频波形,l_length是时长预测器的损失,attn是对齐矩阵或时长信息y_hat, l_length, attn, ids_slice, x_mask, z_mask, \(z, z_p, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, spec, spec_lengths, speakers)# 将线性谱转为mel谱图,便于后续计算L_reconmel = spec_to_mel_torch(spec,hps.data.filter_length,hps.data.n_mel_channels,hps.data.sampling_rate,hps.data.mel_fmin,hps.data.mel_fmax)# 以ids_slice作为指导,采样对应窗口的mel谱图作为targety_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)# 从生成的音频波形y_hat中提取对应的mel谱图y_hat_mel = mel_spectrogram_torch(y_hat.squeeze(1),hps.data.filter_length,hps.data.n_mel_channels,hps.data.sampling_rate,hps.data.hop_length,hps.data.win_length,hps.data.mel_fmin,hps.data.mel_fmax)# 从完整的音频数据中以ids_slice获取对应窗口部分的音频数据;判别器判别时需要真实波形数据y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size)  # slice# Discriminator;y_d_hat_r, y_d_hat_g记录所有子判别器对batch中真实波形y和生成波形y_hat的判别结果y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())with autocast(enabled=False):  # 损失的计算不进行半精度计算loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)  # 判别器的损失loss_disc_all = loss_disc# 判别器更新optim_d.zero_grad()scaler.scale(loss_disc_all).backward()scaler.unscale_(optim_d)  # 梯度剪裁前先进行unscalegrad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)  # 梯度剪裁scaler.step(optim_d)with autocast(enabled=hps.train.fp16_run):# Generator# 将生成的波形和真实波形分别送入到判别器中,希望两者在判别器的中间特征尽可能保持一致,即论文中的L_{fm},需要fmap_r, fmap_g进行计算y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)with autocast(enabled=False):loss_dur = torch.sum(l_length.float())  # 时间预测器loss,直接求和loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel  # 重构loss,论文中系数c_mel为45# 计算模型基于文本学习到的先验分布和从音频线性谱图中学习到的后验分布之间的KL散度,系数c_kl为1loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_klloss_fm = feature_loss(fmap_r, fmap_g)  # feature map 的lossloss_gen, losses_gen = generator_loss(y_d_hat_g)  # 生成器的对抗lossloss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl# 生成器更新optim_g.zero_grad()scaler.scale(loss_gen_all).backward()scaler.unscale_(optim_g)grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)scaler.step(optim_g)scaler.update()# 主卡上进行loss打印、记录和模型验证、保存if rank == 0:if global_step % hps.train.log_interval == 0:lr = optim_g.param_groups[0]['lr']losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]logger.info('Train Epoch: {} [{:.0f}%]'.format(epoch,100. * batch_idx / len(train_loader)))logger.info([x.item() for x in losses] + [global_step, lr])scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr,"grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}  # 记录损失和梯度scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/dur": loss_dur, "loss/g/kl": loss_kl})scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})# 以图像的形式记录mel谱图和对齐信息image_dict = {"slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),"slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),"all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),"all/attn": utils.plot_alignment_to_numpy(attn[0, 0].data.cpu().numpy())}# 调用定义的tensorboard的writer记录上述信息utils.summarize(writer=writer,global_step=global_step,images=image_dict,scalars=scalar_dict)if global_step % hps.train.eval_interval == 0:evaluate(hps, net_g, eval_loader, writer_eval)  # 验证# 保存生成器和判别器的参数utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))global_step += 1if rank == 0:logger.info('====> Epoch: {}'.format(epoch))# 验证
def evaluate(hps, generator, eval_loader, writer_eval):generator.eval()  # 验证模式with torch.no_grad():for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers) in enumerate(eval_loader):x, x_lengths = x.cuda(0), x_lengths.cuda(0)spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0)y, y_lengths = y.cuda(0), y_lengths.cuda(0)speakers = speakers.cuda(0)# remove elsex = x[:1]x_lengths = x_lengths[:1]spec = spec[:1]spec_lengths = spec_lengths[:1]y = y[:1]y_lengths = y_lengths[:1]speakers = speakers[:1]breaky_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, max_len=1000)  # 基于文本生成音频y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length# 提取真实的mel谱图mel = spec_to_mel_torch(spec,hps.data.filter_length,hps.data.n_mel_channels,hps.data.sampling_rate,hps.data.mel_fmin,hps.data.mel_fmax)# 从预测的音频的提取mel谱图y_hat_mel = mel_spectrogram_torch(y_hat.squeeze(1).float(),hps.data.filter_length,hps.data.n_mel_channels,hps.data.sampling_rate,hps.data.hop_length,hps.data.win_length,hps.data.mel_fmin,hps.data.mel_fmax)image_dict = {"gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())}audio_dict = {"gen/audio": y_hat[0, :, :y_hat_lengths[0]]}if global_step == 0:image_dict.update({"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})audio_dict.update({"gt/audio": y[0, :, :y_lengths[0]]})# 记录信息utils.summarize(writer=writer_eval,global_step=global_step,images=image_dict,audios=audio_dict,audio_sampling_rate=hps.data.sampling_rate)generator.train()if __name__ == "__main__":main()

losses.py

从论文中可知,本模型训练过程中涉及很多的损失,对抗训练过程中,判别器是常规的判别器损失结构,但是使用的是多周期判别器,由多个子判别器组成;生成器的损失,包括mel重建损失、KL散度、时长预测器损失、对抗训练生成损失以及特征图损失,其中时长预测器损失在模型forward函数中直接计算、mel重建损失是直接计算L1损失,剩下的四种损失在losses.py文件中定义,代码如下:

import torch
from torch.nn import functional as Fimport commons# 计算对抗训练中生成波形和真实波形在判别器中间特征之间的距离损失
def feature_loss(fmap_r, fmap_g):loss = 0for dr, dg in zip(fmap_r, fmap_g):  # 遍历真实波形和预测波形在判别器每层的特征图for rl, gl in zip(dr, dg):rl = rl.float().detach()gl = gl.float()loss += torch.mean(torch.abs(rl - gl))  # 计算L1损失return loss * 2# 判别器损失
def discriminator_loss(disc_real_outputs, disc_generated_outputs):loss = 0r_losses = []g_losses = []for dr, dg in zip(disc_real_outputs, disc_generated_outputs):  # 遍历多个子判别器的判别结果dr = dr.float()  # 一个子判别器对真实波形的判别结果dg = dg.float()  # 一个子判别器对生成波形的判别结果r_loss = torch.mean((1 - dr) ** 2)  # 真实波形的判别结果越接近于1越好g_loss = torch.mean(dg ** 2)  # 生成波形的判别结果越接近于0越好loss += (r_loss + g_loss)  # 累加当前子判别器的损失r_losses.append(r_loss.item())g_losses.append(g_loss.item())return loss, r_losses, g_losses# 生成器的对抗损失,就是将生成器生成的波形经过判别器后的输出与1计算距离损失,L2损失
def generator_loss(disc_outputs):loss = 0gen_losses = []for dg in disc_outputs:dg = dg.float()l = torch.mean((1 - dg) ** 2)gen_losses.append(l)loss += lreturn loss, gen_losses# 先验分布和后验分布之间的KL散度
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):"""z_p, logs_q: [b, h, t_t]m_p, logs_p: [b, h, t_t]"""z_p = z_p.float()logs_q = logs_q.float()m_p = m_p.float()logs_p = logs_p.float()z_mask = z_mask.float()kl = logs_p - logs_q - 0.5kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2. * logs_p)kl = torch.sum(kl * z_mask)l = kl / torch.sum(z_mask)return l

本笔记主要记录vits官方仓库中模型训练相关代码,其中涉及到的一些辅助函数,如果有必要后续会进行补充。本笔记主要是对代码进行详细的注释,读者若发现问题或错误,请评论指出,互相学习。

vits官方gituhb项目--模型训练相关推荐

  1. IDDPM官方gituhb项目--模型构建

    在完成IDDPM论文学习后,对github上的官方仓库进行学习,通过具体的代码理解算法实现过程中的一些细节:官方仓库代码基于pytorch实现,链接为https://github.com/openai ...

  2. 基于yolov5的目标检测和模型训练(Miniconda3+PyTorch+Pycharm+实战项目——装甲板识别)

    目录 一.环境配置和源码获取 1.Miniconda 2.MIniconda虚拟环境配置PyTorch 3.yolov5项目源码 4.pycharm 二.目标检测 三.模型训练 1.数据集 1.ima ...

  3. 手把手教你洞悉 PyTorch 模型训练过程,彻底掌握 PyTorch 项目实战!(文末重金招聘导师)...

    (文末重金招募导师) 在CVPR 2020会议接收中,PyTorch 使用了405次,TensorFlow 使用了102次,PyTorch使用数是TensorFlow的近4倍. 自2019年开始,越来 ...

  4. 可以在手机里运行的Detectron2来了:Facebook官方出品,支持端到端模型训练、量化和部署...

    鱼羊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 做目标检测.语义分割,你一定听说过Detectron2. 作为一个基于PyTorch实现的模块化目标检测库,Detectron2当年刚一开源 ...

  5. Kaggle经典数据分析项目:泰坦尼克号生存预测!1. 数据概述与可视化2. 数据预处理3. 模型训练4. 模型优化(调参)

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习 ,不错过 Datawhale干货 作者:陈锴,中山大学,Datawhale成员 最近有很多读者留言,希望 ...

  6. yolov3(一:模型训练)

    第一部分:训练已有的voc datasets 搞清楚该算法的模型训练流程 Darknet是Joseph维护的开源的神经网络框架,使用C语言编写:https://pjreddie.com/darknet ...

  7. 轻松学Pytorch – 行人检测Mask-RCNN模型训练与使用

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 大家好,这个是轻松学Pytorch的第20篇的文章分享,主要是给大 ...

  8. 微软nni_实践空间站 | 为微软官方开源项目贡献代码,你准备好了吗?

    亟需一个契机重新驱动你在冬日沉睡的大脑? 2020 年春季学期微软学生俱乐部实践空间站项目正等待你大展身手! 实践空间站是微软学生俱乐部打造的全学年持续性活动,通过项目导师指导与自主创新结合的方式,帮 ...

  9. alexeyab darknet 编译_【目标检测实战】Darknet—yolov3模型训练(VOC数据集)

    原文发表在:语雀文档 0.前言 本文为Darknet框架下,利用官方VOC数据集的yolov3模型训练,训练环境为:Ubuntu18.04下的GPU训练,cuda版本10.0:cudnn版本7.6.5 ...

最新文章

  1. 【Linux 内核】进程管理 task_struct 结构体 ② ( state 字段 | stack 字段 | pid 字段 | tgid 字段 | pid_links 字段 )
  2. struts 2读书笔记-----struts2的开发流程
  3. 解决:TypeError: Value passed to parameter 'a' has DataType int64 not in list of allowed values: float1
  4. spring数据持久化
  5. 一只老猴子说的话,太经典了!
  6. mysql 安装、建库、导入导出数据
  7. Linux逻辑卷(LVM)技术详解
  8. 【宽度优先搜索】计蒜客:蒜头君回家(带条件的BFS)
  9. Java /Jsp 执行操作系统命令 windows/Linux
  10. fiddler中文乱码解决方案
  11. 基于微信校园二手书交易小程序系统 毕业设计毕设参考
  12. 如何配置百度地图应用访问白名单
  13. Elasticsearch(ES)生产集群健康状况为黄色(yellow)的官方详细解释、原因分析和解决方案(实测可用)
  14. solidity-msg.sender到底是什么?
  15. 安卓移动办公软件_尚朋高科TeeTek云端移动办公系统,云端软件5G时代的趋势
  16. html表格垂直居中的CSS代码,使用3行CSS代码使任何元素垂直居中
  17. matlab feedforward,premnmx(mapminmax) newff (feedforwardnet) tramnmx 如何使用
  18. 360网神奇安信管理地址_360网神桌面云管理系统
  19. php mysql 压力测试_MySQL的性能基线收集及压力测试
  20. 在国内市场《智能家居》的可行性发展

热门文章

  1. 知识普及篇之掌柜日记
  2. RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED
  3. html中数字换行字母换行
  4. 202309读书笔记|《大白鲸原创图画书优秀作品:虾一跳》——蝴蝶效应之最,你值得一读
  5. java.sql.SQLException: 无效的列类型: 错误解析
  6. StringTokenizer使用
  7. mt4交易平台哪个好?不妨试试福瑞斯外汇平台
  8. “自拍产业”的市场还大着呢
  9. php 采集qq头像,自带采集QQ头像表白墙源码
  10. 学习TypeScript数据类型-从零到英雄