[论文笔记随手] Training with Weighted Sum of Denoising Score Matching Objectives
[note] Training with Weighted Sum of Denoising Score Matching Objectives
利用 去噪分数匹配目标的加权和 进行训练,去噪指的是使用sde的方法就不需要自行补充噪声了。
本文的目的是解释如何对原始数据进行扰动。 from https://yang-song.github.io/blog/2021/score/
一、理论
首先,挑选一个随机过程(SDE)对原始数据分布p0p_0p0进行扰动得到扰动后数据的概率密度分布ptp_tpt。
本文选择的随机过程为:
dx=σtdw,t∈[0,1]d{\bf x} = \sigma^td{\bf w}, \ t\in[0,1] dx=σtdw, t∈[0,1]
在这种情况下,扰动后数据的概率密度分布ptp_tpt,在原始数据下的条件概率分布为:
p0t(x(t)∣x(0))=N(x(t);x(0),12logσ(σ2t−1)I)p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0)) = \mathcal{N}\bigg(\mathbf{x}(t); \mathbf{x}(0), \frac{1}{2\log \sigma}(\sigma^{2t} - 1) \mathbf{I}\bigg) p0t(x(t)∣x(0))=N(x(t);x(0),2logσ1(σ2t−1)I)
关于这个函数的解释是,使用参数$ \frac{1}{2\log \sigma}(\sigma^{2t} - 1) 作为我们的权重函数,即作为我们的权重函数,即作为我们的权重函数,即\lambda(t) = \frac{1}{2 \log \sigma}(\sigma^{2t} - 1)$.
当参数σ\sigmaσ变得非常大的时候,其中的先验分布pt=1p_{t=1}pt=1,也就是最终扰动后的数据分布就可以变成一个正太分布:
∫p0(y)N(x;y,12logσ(σ2−1)I)dy≈N(x;0,12logσ(σ2−1)I),\int p_0(\mathbf{y})\mathcal{N}\bigg(\mathbf{x}; \mathbf{y}, \frac{1}{2 \log \sigma}(\sigma^2 - 1)\mathbf{I}\bigg) d \mathbf{y} \approx \mathbf{N}\bigg(\mathbf{x}; \mathbf{0}, \frac{1}{2 \log \sigma}(\sigma^2 - 1)\mathbf{I}\bigg), ∫p0(y)N(x;y,2logσ1(σ2−1)I)dy≈N(x;0,2logσ1(σ2−1)I),
直观地说,这个SDE通过一个变种函数12logσ(σ2t−1)\frac1{2\ log\ \sigma}(\sigma^{2t}-1)2 log σ1(σ2t−1)帮助我们捕获了高斯扰动的数据变量集合(连续统continuum),即x(t)x(t)x(t)。这个数据变量集合可以帮助我们逐渐将原始数据分布p0p_0p0变成了一个简单的高斯分布p1p_1p1,也就是t=1时候的分布。
二、代码实现
1) 对t进行连续采样
# 对时间特征t进行均匀采样random_t = torch.rand(x.x.shape[0]//30, device=device) * (1. - eps) + eps # 防止采样到0
2)定义权重函数
可以看到,这里定义的权重函数就是作者在上面提到的λ(t)\lambda(t)λ(t)函数。
def marginal_prob_std(t, sigma):# t = torch.tensor(t, device=device)return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / np.log(sigma))
3)对数据进行扰动
# 表征时间的特征t, 从0到1上进行均匀采样
random_t = torch.rand(batchsize, device=device) * (1. - eps) + eps # 这里的eps是为了防止采样到t=0# 构造一个与原始数据结构一样的向量,并在[0,1)上进行均匀采样。
z = torch.randn_like(x.x)# 利用前面均匀采样的时间特征t,求得权重函数的值,这个权重函数的目的就是为了使得t=1时的扰动数据达到一个正太分布的结果。重复30遍的目的是因为一轮训练中设置的batch_size = 30
std = marginal_prob_std_func(random_t).repeat(1, 30).view(-1, 1)# 这里将噪声与标准差相乘,
perturbed_x = copy.deepcopy(x)
perturbed_x.x += z * std
4)利用扰动的数据进行训练
需要补充一下,为了训练积分函数模型,目前的目标函数变成了下面这个样子:
Et∈u(0,T)Ept(x)[λ(t)∣∣∇xlogpt(x)−sθ(x,t)∣∣22]\mathbb{E}_{t\in u(0,T)}\mathbb{E}_{p_t(x)}[\lambda(t)||\nabla_xlog\ p_t(x)-s_\theta(x,t)||_2^2] Et∈u(0,T)Ept(x)[λ(t)∣∣∇xlog pt(x)−sθ(x,t)∣∣22]
这里是最基本的目标函数的样子:
Ep(x)[∣∣∇xlogp(x)−sθ(x)∣∣22]=∫p(x)∣∣∇xlogp(x)−sθ(x)∣∣22dx.\mathbb{E}_{p(x)}[{||\nabla_xlog\ p(x)\ -\ s_\theta(x)||}_2^2]\ =\ \int\ p(x)||\nabla_x\ log\ p(x)\ -\ s_\theta(x)||_2^2dx. Ep(x)[∣∣∇xlog p(x) − sθ(x)∣∣22] = ∫ p(x)∣∣∇x log p(x) − sθ(x)∣∣22dx.
为了估计这个目标函数,需要如下估计,即使用Score Matching的方法进行估计(Hyvärinen 2005):
可以看到,去估计如下的目标函数是可以达到的。
Epdata(x)[12∣∣sθ(x)∣∣22+trace(∇xsθ(x))]\mathbb{E}_{p_{data}(x)}[\frac12||s_\theta(x)||_2^2+trace(\nabla_xs_\theta(x))] Epdata(x)[21∣∣sθ(x)∣∣22+trace(∇xsθ(x))]
具体上,体现在代码上,用的是如下的公式:
1N∑i=1N[12∣∣sθ(xi)∣∣22+trace(∇xsθ(xi))]≈1N∑i=1N[12∣∣sθ(xi)∣∣22+trace(∇xsθ(xi))\frac1N\sum^N_{i=1}[\frac12||s_\theta(x_i)||_2^2+trace(\nabla_xs_\theta(x_i))] \\ \approx \frac1N \sum_{i=1}^N [\frac12||s_\theta(x_i)||_2^2+trace(\nabla_xs_\theta(x_i)) N1i=1∑N[21∣∣sθ(xi)∣∣22+trace(∇xsθ(xi))]≈N1i=1∑N[21∣∣sθ(xi)∣∣22+trace(∇xsθ(xi))
# 计算积分函数的值
output = model(perturbed_x, random_t)
# score matching的损失函数,与上式不一致的原因在于,本文的目标函数中还有一个参数\lambda(t),所以表现为如下的形式。
loss_ = torch.mean(torch.sum(((output * std + z)**2).view(batch_size, -1)), dim=-1)
# 一轮训练之后,将score matching的目标函数的结果返回
return loss_
[论文笔记随手] Training with Weighted Sum of Denoising Score Matching Objectives相关推荐
- 论文笔记 OHEM: Training Region-based Object Detectors with Online Hard Example Mining
CVPR2016的文章,CMU与rbg大神的合作.原谅我一直没有对这篇文章做一个笔记~~ 文章提出了一种通过online hard example mining(OHEM)算法训练基于区域的卷积检测算 ...
- 联邦学习论文笔记——FedFair: Training Fair Models In Cross-Silo Fedrated Learning
cross device 跨设备 / cross silo 跨数据井 介绍 ①我们能否轻松地扩展现有的加强公平的方法,以适应协作和隐私保护的需要? NO:现有的公平性增强方法主要是在一个统一的可用训练 ...
- 虚拟换衣 CP-VTON 论文笔记
CP-VTON 介绍 论文笔记 算法目的 主要贡献 CP-VTON 算法框架 Geometric Matching Module(几何匹配模块) Try-on Module(试穿模块) 总结 参考文献 ...
- 论文笔记【A Comprehensive Study of Deep Video Action Recognition】
论文链接:A Comprehensive Study of Deep Video Action Recognition 目录 A Comprehensive Study of Deep Video A ...
- 【论文笔记-NER综述】A Survey on Deep Learning for Named Entity Recognition
本笔记理出来综述中的点,并将大体的论文都列出,方便日后调研使用查找,详细可以看论文. 神经网络的解释: The forward pass com- putes a weighted sum of th ...
- 【论文笔记】半监督的多视图学习:Semi-supervised Multi-view Deep Discriminant Representation Learning
[论文笔记]Semi-supervised Multi-view Deep Discriminant Representation Learning 1. 概念 多视图学习(Multiview Lea ...
- 论文笔记:DeepFuse: A Deep Unsupervised Approach for Exposure Fusion with Extreme Exposure Image Pairs
论文笔记:DeepFuse: A Deep Unsupervised Approach for Exposure Fusion with Extreme Exposure Image Pairs co ...
- 论文笔记目录(ver2.0)
1 时间序列 1.1 时间序列预测 论文名称 来源 主要内容 论文笔记:DCRNN (Diffusion Convolutional Recurrent Neural Network: Data-Dr ...
- 【论文笔记】用循环一致性避免形变场重叠的医学图像配准网络
本文是论文<Cycle-Consistent Training for Reducing Negative Jacobian Determinant in Deep Registration N ...
最新文章
- 在Mac上控制Alt Delete-如何在Macbook上打开任务管理器
- 从思维导图学习操作系统(二)
- 深度学习常见概念解析
- “以毒攻毒”?阿里将上线“二哈”防骚扰电话应用程序
- mysql status uptime_MySQL优化(四) 慢查询的定位及优化
- mysql 查询某个值非空_MySQL查询在单行中计算非空值
- 5加载stm32 keil_【STM32笔记】在SRAM、FLASH中调试代码的配置方法(附详细步骤)...
- 全球国家(和地区)信息JSON数据
- I2S音频接口的理解
- 计算机-计算机发展史
- html设置本地字体文件
- 如何培养孩子的记忆力?猿辅导:这个方法家长一定要知道
- 华为计算机视觉博士,华为视觉计划发布,要做“智能世界的眼睛”
- 蛋糕究竟是怎样做大的
- 一些杂事之后,该收心了
- 基于深度学习和光流的地铁乘客上下车自动检测算法
- 入选31个细分领域丨通付盾荣登嘶吼安全产业研究院《2022网络安全产业图谱》
- DPI、像素与分辨率的区别和联系
- 巨亏超10亿!“汽车金融第一股”易鑫业绩腰斩,上半年却傍上腾讯
- UDP FLood拒绝服务攻击