[note] Training with Weighted Sum of Denoising Score Matching Objectives

利用 去噪分数匹配目标的加权和 进行训练,去噪指的是使用sde的方法就不需要自行补充噪声了。

本文的目的是解释如何对原始数据进行扰动。 from https://yang-song.github.io/blog/2021/score/

一、理论

首先,挑选一个随机过程(SDE)对原始数据分布p0p_0p0​进行扰动得到扰动后数据的概率密度分布ptp_tpt​。

本文选择的随机过程为:
dx=σtdw,t∈[0,1]d{\bf x} = \sigma^td{\bf w}, \ t\in[0,1] dx=σtdw, t∈[0,1]
在这种情况下,扰动后数据的概率密度分布ptp_tpt​,在原始数据下的条件概率分布为:
p0t(x(t)∣x(0))=N(x(t);x(0),12log⁡σ(σ2t−1)I)p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0)) = \mathcal{N}\bigg(\mathbf{x}(t); \mathbf{x}(0), \frac{1}{2\log \sigma}(\sigma^{2t} - 1) \mathbf{I}\bigg) p0t​(x(t)∣x(0))=N(x(t);x(0),2logσ1​(σ2t−1)I)
关于这个函数的解释是,使用参数$ \frac{1}{2\log \sigma}(\sigma^{2t} - 1) 作为我们的权重函数,即作为我们的权重函数,即作为我们的权重函数,即\lambda(t) = \frac{1}{2 \log \sigma}(\sigma^{2t} - 1)$.

当参数σ\sigmaσ变得非常大的时候,其中的先验分布pt=1p_{t=1}pt=1​,也就是最终扰动后的数据分布就可以变成一个正太分布:
∫p0(y)N(x;y,12log⁡σ(σ2−1)I)dy≈N(x;0,12log⁡σ(σ2−1)I),\int p_0(\mathbf{y})\mathcal{N}\bigg(\mathbf{x}; \mathbf{y}, \frac{1}{2 \log \sigma}(\sigma^2 - 1)\mathbf{I}\bigg) d \mathbf{y} \approx \mathbf{N}\bigg(\mathbf{x}; \mathbf{0}, \frac{1}{2 \log \sigma}(\sigma^2 - 1)\mathbf{I}\bigg), ∫p0​(y)N(x;y,2logσ1​(σ2−1)I)dy≈N(x;0,2logσ1​(σ2−1)I),
直观地说,这个SDE通过一个变种函数12logσ(σ2t−1)\frac1{2\ log\ \sigma}(\sigma^{2t}-1)2 log σ1​(σ2t−1)帮助我们捕获了高斯扰动的数据变量集合(连续统continuum),即x(t)x(t)x(t)。这个数据变量集合可以帮助我们逐渐将原始数据分布p0p_0p0​变成了一个简单的高斯分布p1p_1p1​,也就是t=1时候的分布。

二、代码实现

1) 对t进行连续采样

 # 对时间特征t进行均匀采样random_t = torch.rand(x.x.shape[0]//30, device=device) * (1. - eps) + eps # 防止采样到0

2)定义权重函数

可以看到,这里定义的权重函数就是作者在上面提到的λ(t)\lambda(t)λ(t)函数。

def marginal_prob_std(t, sigma):# t = torch.tensor(t, device=device)return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / np.log(sigma))

3)对数据进行扰动

# 表征时间的特征t, 从0到1上进行均匀采样
random_t = torch.rand(batchsize, device=device) * (1. - eps) + eps # 这里的eps是为了防止采样到t=0# 构造一个与原始数据结构一样的向量,并在[0,1)上进行均匀采样。
z = torch.randn_like(x.x)# 利用前面均匀采样的时间特征t,求得权重函数的值,这个权重函数的目的就是为了使得t=1时的扰动数据达到一个正太分布的结果。重复30遍的目的是因为一轮训练中设置的batch_size = 30
std = marginal_prob_std_func(random_t).repeat(1, 30).view(-1, 1)# 这里将噪声与标准差相乘,
perturbed_x = copy.deepcopy(x)
perturbed_x.x += z * std

4)利用扰动的数据进行训练

需要补充一下,为了训练积分函数模型,目前的目标函数变成了下面这个样子:
Et∈u(0,T)Ept(x)[λ(t)∣∣∇xlogpt(x)−sθ(x,t)∣∣22]\mathbb{E}_{t\in u(0,T)}\mathbb{E}_{p_t(x)}[\lambda(t)||\nabla_xlog\ p_t(x)-s_\theta(x,t)||_2^2] Et∈u(0,T)​Ept​(x)​[λ(t)∣∣∇x​log pt​(x)−sθ​(x,t)∣∣22​]

这里是最基本的目标函数的样子:
Ep(x)[∣∣∇xlogp(x)−sθ(x)∣∣22]=∫p(x)∣∣∇xlogp(x)−sθ(x)∣∣22dx.\mathbb{E}_{p(x)}[{||\nabla_xlog\ p(x)\ -\ s_\theta(x)||}_2^2]\ =\ \int\ p(x)||\nabla_x\ log\ p(x)\ -\ s_\theta(x)||_2^2dx. Ep(x)​[∣∣∇x​log p(x) − sθ​(x)∣∣22​] = ∫ p(x)∣∣∇x​ log p(x) − sθ​(x)∣∣22​dx.
为了估计这个目标函数,需要如下估计,即使用Score Matching的方法进行估计(Hyvärinen 2005):

可以看到,去估计如下的目标函数是可以达到的。

Epdata(x)[12∣∣sθ(x)∣∣22+trace(∇xsθ(x))]\mathbb{E}_{p_{data}(x)}[\frac12||s_\theta(x)||_2^2+trace(\nabla_xs_\theta(x))] Epdata​(x)​[21​∣∣sθ​(x)∣∣22​+trace(∇x​sθ​(x))]

具体上,体现在代码上,用的是如下的公式:
1N∑i=1N[12∣∣sθ(xi)∣∣22+trace(∇xsθ(xi))]≈1N∑i=1N[12∣∣sθ(xi)∣∣22+trace(∇xsθ(xi))\frac1N\sum^N_{i=1}[\frac12||s_\theta(x_i)||_2^2+trace(\nabla_xs_\theta(x_i))] \\ \approx \frac1N \sum_{i=1}^N [\frac12||s_\theta(x_i)||_2^2+trace(\nabla_xs_\theta(x_i)) N1​i=1∑N​[21​∣∣sθ​(xi​)∣∣22​+trace(∇x​sθ​(xi​))]≈N1​i=1∑N​[21​∣∣sθ​(xi​)∣∣22​+trace(∇x​sθ​(xi​))

# 计算积分函数的值
output = model(perturbed_x, random_t)
# score matching的损失函数,与上式不一致的原因在于,本文的目标函数中还有一个参数\lambda(t),所以表现为如下的形式。
loss_ = torch.mean(torch.sum(((output * std + z)**2).view(batch_size, -1)), dim=-1)
# 一轮训练之后,将score matching的目标函数的结果返回
return loss_

[论文笔记随手] Training with Weighted Sum of Denoising Score Matching Objectives相关推荐

  1. 论文笔记 OHEM: Training Region-based Object Detectors with Online Hard Example Mining

    CVPR2016的文章,CMU与rbg大神的合作.原谅我一直没有对这篇文章做一个笔记~~ 文章提出了一种通过online hard example mining(OHEM)算法训练基于区域的卷积检测算 ...

  2. 联邦学习论文笔记——FedFair: Training Fair Models In Cross-Silo Fedrated Learning

    cross device 跨设备 / cross silo 跨数据井 介绍 ①我们能否轻松地扩展现有的加强公平的方法,以适应协作和隐私保护的需要? NO:现有的公平性增强方法主要是在一个统一的可用训练 ...

  3. 虚拟换衣 CP-VTON 论文笔记

    CP-VTON 介绍 论文笔记 算法目的 主要贡献 CP-VTON 算法框架 Geometric Matching Module(几何匹配模块) Try-on Module(试穿模块) 总结 参考文献 ...

  4. 论文笔记【A Comprehensive Study of Deep Video Action Recognition】

    论文链接:A Comprehensive Study of Deep Video Action Recognition 目录 A Comprehensive Study of Deep Video A ...

  5. 【论文笔记-NER综述】A Survey on Deep Learning for Named Entity Recognition

    本笔记理出来综述中的点,并将大体的论文都列出,方便日后调研使用查找,详细可以看论文. 神经网络的解释: The forward pass com- putes a weighted sum of th ...

  6. 【论文笔记】半监督的多视图学习:Semi-supervised Multi-view Deep Discriminant Representation Learning

    [论文笔记]Semi-supervised Multi-view Deep Discriminant Representation Learning 1. 概念 多视图学习(Multiview Lea ...

  7. 论文笔记:DeepFuse: A Deep Unsupervised Approach for Exposure Fusion with Extreme Exposure Image Pairs

    论文笔记:DeepFuse: A Deep Unsupervised Approach for Exposure Fusion with Extreme Exposure Image Pairs co ...

  8. 论文笔记目录(ver2.0)

    1 时间序列 1.1 时间序列预测 论文名称 来源 主要内容 论文笔记:DCRNN (Diffusion Convolutional Recurrent Neural Network: Data-Dr ...

  9. 【论文笔记】用循环一致性避免形变场重叠的医学图像配准网络

    本文是论文<Cycle-Consistent Training for Reducing Negative Jacobian Determinant in Deep Registration N ...

最新文章

  1. 在Mac上控制Alt Delete-如何在Macbook上打开任务管理器
  2. 从思维导图学习操作系统(二)
  3. 深度学习常见概念解析
  4. “以毒攻毒”?阿里将上线“二哈”防骚扰电话应用程序
  5. mysql status uptime_MySQL优化(四) 慢查询的定位及优化
  6. mysql 查询某个值非空_MySQL查询在单行中计算非空值
  7. 5加载stm32 keil_【STM32笔记】在SRAM、FLASH中调试代码的配置方法(附详细步骤)...
  8. 全球国家(和地区)信息JSON数据
  9. I2S音频接口的理解
  10. 计算机-计算机发展史
  11. html设置本地字体文件
  12. 如何培养孩子的记忆力?猿辅导:这个方法家长一定要知道
  13. 华为计算机视觉博士,华为视觉计划发布,要做“智能世界的眼睛”
  14. 蛋糕究竟是怎样做大的
  15. 一些杂事之后,该收心了
  16. 基于深度学习和光流的地铁乘客上下车自动检测算法
  17. 入选31个细分领域丨通付盾荣登嘶吼安全产业研究院《2022网络安全产业图谱》
  18. DPI、像素与分辨率的区别和联系
  19. 巨亏超10亿!“汽车金融第一股”易鑫业绩腰斩,上半年却傍上腾讯
  20. UDP FLood拒绝服务攻击

热门文章

  1. 企业、政府单位微信公众号名称怎么修改?
  2. DUToj1085 Water Problem(矩阵快速幂)
  3. 【Python学习笔记】下划线的含义
  4. 【经典书】统计学中的因果推断
  5. 信息检索(基础知识一)——词项-文档关联矩阵及倒排索引构建
  6. HEIC图片格式如何快速转换呢?
  7. Hexo 博客优化之实用功能添加系列(持续更新)
  8. v-model组件使用
  9. 微型计算机sp作用,微机借口与技术试卷
  10. Android当中的MVP模式(七)终篇---关于对MVP模式中代码臃肿