4.1.2 启发源——粒子运动的 Langevin 动力学# Langevin 方程
Langevin 方程是描述一个粒子在势能场中受随机热扰动的动力学方程。设粒子位置为 x t ∈ R d x_t \in \mathbb{R}^d x t ∈ R d ,受到来自势能 U ( x t ) U(x_t) U ( x t ) 的确定性力 − ∇ x t U ( x t ) -\nabla_{x_t} U(x_t) − ∇ x t U ( x t ) 和来自热随机噪声的扰动 2 k T d w t \sqrt{2kT}dw_t 2 k T d w t 的影响,其 Langevin 动力学方程为
d x t = − ∇ x t U ( x t ) d t + 2 k T d w t dx_t = -\nabla_{x_t} U(x_t) dt + \sqrt{2kT}dw_t d x t = − ∇ x t U ( x t ) d t + 2 k T d w t 其中 d w t dw_t d w t 是标准布朗运动增量,k k k 是 Boltzmann 常数,T T T 是系统温度。
Langevin 动力学的稳态分布
Langevin 动力学方程定义了粒子在各个时刻位置的概率分布 p t ( x ) p_t(x) p t ( x ) ,其演化由 Fokker-Planck 方程描述
∂ p t ( x ) ∂ t = ∇ x ⋅ [ p t ( x ) ∇ x U ( x ) ] + k T Δ x p t ( x ) \frac{\partial p_t(x)}{\partial t} = \nabla_x \cdot \left[p_t(x) \nabla_x U(x)\right] + kT \Delta_x p_t(x) ∂ t ∂ p t ( x ) = ∇ x ⋅ [ p t ( x ) ∇ x U ( x ) ] + k T Δ x p t ( x ) 其中 Δ x p t ( x ) = ∇ x ⋅ ∇ x p t ( x ) \Delta_x p_t(x) = \nabla_x \cdot \nabla_x p_t(x) Δ x p t ( x ) = ∇ x ⋅ ∇ x p t ( x ) 用到了拉普拉斯算子表示。假设该动力学过程存在稳态分布 p s s ( x ) = lim t → ∞ p t ( x ) p_{ss}(x) = \lim_{t \to \infty}p_t(x) p ss ( x ) = lim t → ∞ p t ( x ) ,在 t → ∞ t \to \infty t → ∞ 时可以达到,则必定满足
∂ p t ( x ) ∂ t = 0 ⇒ ∇ x ⋅ [ p s s ( x ) ∇ x U ( x ) ] + k T Δ x p s s ( x ) = 0 \frac{\partial p_t(x)}{\partial t} = 0 \quad \Rightarrow \quad \nabla_x \cdot \left[p_{ss}(x) \nabla_x U(x)\right] + kT \Delta_x p_{ss}(x) = 0 ∂ t ∂ p t ( x ) = 0 ⇒ ∇ x ⋅ [ p ss ( x ) ∇ x U ( x ) ] + k T Δ x p ss ( x ) = 0 此偏微分方程的解是著名的 Boltzmann-Gibbs 分布
p s s ( x ) = 1 Z exp ( − 1 k T U ( x ) ) \boxed{ p_{ss}(x) = \frac{1}{Z} \exp(-\frac{1}{kT}U(x)) } p ss ( x ) = Z 1 exp ( − k T 1 U ( x )) 其中 Z = ∫ exp ( − 1 k T U ( x ) ) d x Z = \int \exp(-\frac{1}{kT}U(x)) dx Z = ∫ exp ( − k T 1 U ( x )) d x 是归一化常数。这个解的得出可以通过代入验证。具体过程详见附录 C.2 。
Boltzmann-Gibbs 分布表明,粒子在势能场 U ( x ) U(x) U ( x ) 中的稳态分布与势能成指数关系,低势能区域对应高概率密度。通过调节温度 T T T ,可以控制分布的平滑程度。
4.1.3 数学形式# 能量函数与概率分布
在概率能量模型中,数据分布 p d a t a ( x ) p_{data}(x) p d a t a ( x ) 通常表示为一个基于势能函数 U d a t a ( x ) U_{data}(x) U d a t a ( x ) 的形式
p d a t a ( x ) = 1 Z exp ( − U d a t a ( x ) ) p_{data}(x) = \frac{1}{Z} \exp(-U_{data}(x)) p d a t a ( x ) = Z 1 exp ( − U d a t a ( x )) 其中 Z = ∫ exp ( − U d a t a ( x ) ) d x Z = \int \exp(-U_{data}(x)) dx Z = ∫ exp ( − U d a t a ( x )) d x 是归一化常数。势能函数 U d a t a ( x ) U_{data}(x) U d a t a ( x ) 衡量数据点 x x x 的“非典型性”,低能量对应高概率。例如,Boltzmann-Gibbs 分布就是这种形式的一个例子。
通过对数梯度变换,可以得到数据分布的对数概率密度的梯度为
∇ x log p d a t a ( x ) = − ∇ x U d a t a ( x ) \nabla_x \log p_{data}(x) = -\nabla_x U_{data}(x) ∇ x log p d a t a ( x ) = − ∇ x U d a t a ( x ) 象征着数据分布的“力场”,指向概率密度增加的方向。我们把这个梯度定义为分数函数 (score function),写作
s d a t a ( x ) = ∇ x log p d a t a ( x ) = − ∇ x U d a t a ( x ) s_{data}(x) = \nabla_x \log p_{data}(x) = -\nabla_x U_{data}(x) s d a t a ( x ) = ∇ x log p d a t a ( x ) = − ∇ x U d a t a ( x ) 分数函数形式的 Langevin 动力学
受到概率能量模型的启发,如果可以准确估计出数据分布的分数函数 s d a t a ( x ) = ∇ x log p d a t a ( x ) s_{data}(x) = \nabla_x \log p_{data}(x) s d a t a ( x ) = ∇ x log p d a t a ( x ) ,就相当于得到了势能函数 U d a t a ( x ) U_{data}(x) U d a t a ( x ) 对应的保守力。根据 Langevin 动力学方程,最终方程对应的边缘概率分布 p t ( x t ) p_t(x_t) p t ( x t ) 是会收敛到数据分布 p d a t a ( x ) p_{data}(x) p d a t a ( x ) 的。这样,通过 Langevin SDE 采样,即可实现从先验分布 p p r i o r ( x ) p_{prior}(x) p p r i or ( x ) 到数据分布 p d a t a ( x ) p_{data}(x) p d a t a ( x ) 的转换。
替换 − ∇ x U ( x ) -\nabla_x U(x) − ∇ x U ( x ) 为 ∇ x log p d a t a ( x ) \nabla_x \log p_{data}(x) ∇ x log p d a t a ( x ) ,得到分数形式的 Langevin 动力学
d x t = η t 2 ∇ x t log p d a t a ( x t ) d t + η t d w t , t = 1 , … , T dx_t = \frac{\eta_t}{2} \nabla_{x_t}\log p_{data}(x_t) dt + \sqrt{\eta_t} dw_t, \quad t = 1, \ldots, T d x t = 2 η t ∇ x t log p d a t a ( x t ) d t + η t d w t , t = 1 , … , T 其中 η t \eta_t η t 是时间相关的步长参数,控制扩散和梯度更新的权重。
去噪分数匹配与目标函数
(i) 初始目标函数
分数匹配的核心在于如何估计数据分布的分数函数 s d a t a ( x ) = ∇ x log p d a t a ( x ) s_{data}(x) = \nabla_x \log p_{data}(x) s d a t a ( x ) = ∇ x log p d a t a ( x ) 。假设我们用一个参数化的神经网络模型 s θ ( x ) s_\theta(x) s θ ( x ) 来近似分数函数。我们可以通过最小化如下分数匹配目标函数来训练模型
L S M ( θ ) = E x ∼ p d a t a ( x ) ∥ s θ ( x ) − ∇ x log p d a t a ( x ) ∥ 2 \mathcal{L}_{SM}(\theta) = \mathbb{E}_{x \sim p_{data}(x)} \left\| s_\theta(x) - \nabla_x \log p_{data}(x) \right\|^2 L S M ( θ ) = E x ∼ p d a t a ( x ) ∥ s θ ( x ) − ∇ x log p d a t a ( x ) ∥ 2 然而,直接计算 ∇ x log p d a t a ( x ) \nabla_x \log p_{data}(x) ∇ x log p d a t a ( x ) 通常是不可行的,因为 p d a t a ( x ) p_{data}(x) p d a t a ( x ) 的形式未知且难以估计。
(ii) 去噪分数匹配
为了解决这个问题,SMLD 引入了去噪分数匹配 (Denoising Score Matching, DSM) 的概念。对原始数据 z ∼ p d a t a ( z ) z \sim p_{data}(z) z ∼ p d a t a ( z ) (替换之前的符号 x x x 为 z z z 以示区分) 添加高斯噪声,得到带噪数据
x σ = z + σ ϵ , ϵ ∼ N ( 0 , I d ) x_\sigma = z + \sigma\epsilon,\quad \epsilon \sim \mathcal{N}(0, I_d) x σ = z + σ ϵ , ϵ ∼ N ( 0 , I d ) 其中 σ \sigma σ 是噪声标准差,控制噪声水平。由此得到带噪数据的条件分布为
q σ ( x σ ∣ z ) = N ( x σ ; z , σ 2 I d ) = 1 2 π exp [ − ( x σ − z ) 2 2 σ 2 ] q_\sigma(x_\sigma | z) = \mathcal{N}(x_\sigma; z, \sigma^2 I_d) = \frac{1}{\sqrt{2\pi}}\exp\left[-\frac{(x_\sigma - z)^2}{2\sigma^2}\right] q σ ( x σ ∣ z ) = N ( x σ ; z , σ 2 I d ) = 2 π 1 exp [ − 2 σ 2 ( x σ − z ) 2 ] 通过积分,得到带噪数据的边缘分布为
q σ ( x σ ) = ∫ q σ ( x σ ∣ z ) p d a t a ( z ) d z q_\sigma(x_\sigma) = \int q_\sigma(x_\sigma | z) p_{data}(z) dz q σ ( x σ ) = ∫ q σ ( x σ ∣ z ) p d a t a ( z ) d z (iii) 去噪分数匹配与条件去噪分数匹配的等价性
展开边缘含噪数据分布
∇ x σ log q σ ( x σ ) = ∇ x σ ∫ q σ ( x σ ∣ z ) p d a t a ( z ) d z q σ ( x σ ) = ∫ ∇ x σ q σ ( x σ ∣ z ) p d a t a ( z ) d z q σ ( x σ ) = ∫ ∇ x σ log q σ ( x σ ∣ z ) p ( z ∣ x σ ) d z = E z ∼ p ( z ∣ x σ ) [ ∇ x σ log q σ ( x σ ∣ z ) ] \begin{aligned} \nabla_{x_\sigma}\log q_\sigma(x_\sigma) &= \frac{\nabla_{x_\sigma}\int q_\sigma(x_\sigma | z)p_{data}(z)dz}{q_\sigma(x_\sigma)} \\ &= \frac{\int \nabla_{x_\sigma}q_\sigma(x_\sigma | z)p_{data}(z)dz}{q_\sigma(x_\sigma)} \\ &= \int \nabla_{x_\sigma}\log q_\sigma(x_\sigma | z) p(z | x_\sigma)dz \\ &= \mathbb{E}_{z \sim p(z | x_\sigma)}\left[\nabla_{x_\sigma}\log q_\sigma(x_\sigma | z)\right] \end{aligned} ∇ x σ log q σ ( x σ ) = q σ ( x σ ) ∇ x σ ∫ q σ ( x σ ∣ z ) p d a t a ( z ) d z = q σ ( x σ ) ∫ ∇ x σ q σ ( x σ ∣ z ) p d a t a ( z ) d z = ∫ ∇ x σ log q σ ( x σ ∣ z ) p ( z ∣ x σ ) d z = E z ∼ p ( z ∣ x σ ) [ ∇ x σ log q σ ( x σ ∣ z ) ] 得到 Tweedie-Stein 公式
∇ x σ log q σ ( x σ ) = E z ∼ p ( z ∣ x σ ) [ ∇ x σ log q σ ( x σ ∣ z ) ] \nabla_{x_\sigma}\log q_\sigma(x_\sigma) = \mathbb{E}_{z \sim p(z | x_\sigma)}\left[\nabla_{x_\sigma}\log q_\sigma(x_\sigma | z)\right] ∇ x σ log q σ ( x σ ) = E z ∼ p ( z ∣ x σ ) [ ∇ x σ log q σ ( x σ ∣ z ) ] 设去噪分数匹配的目标函数为
L D S M ( θ ; σ ) = E x σ ∼ q σ ( x σ ) [ 1 2 ∥ s θ ( x σ ) − ∇ x σ log q σ ( x σ ) ∥ 2 ] \mathcal{L}_{DSM}(\theta; \sigma) = \mathbb{E}_{x_\sigma \sim q_\sigma(x_\sigma)}\left[\frac{1}{2}\left\Vert s_\theta(x_\sigma) - \nabla_{x_\sigma}\log q_\sigma(x_\sigma)\right\Vert^2\right] L D S M ( θ ; σ ) = E x σ ∼ q σ ( x σ ) [ 2 1 ∥ s θ ( x σ ) − ∇ x σ log q σ ( x σ ) ∥ 2 ] 展开可以得到
L D S M ( θ ; σ ) = 1 2 ( E x σ ∥ s θ ( x σ ) ∥ 2 + E x σ ∥ ∇ x σ log q σ ( x σ ) ∥ 2 − 2 E x σ [ s θ ( x σ ) ⊤ ∇ x σ log q σ ( x σ ) ] ) \mathcal{L}_{DSM}(\theta; \sigma) = \frac{1}{2}\left(\mathbb{E}_{x_\sigma}\left\Vert s_\theta(x_\sigma)\right\Vert^2 + \mathbb{E}_{x_\sigma}\left\Vert \nabla_{x_\sigma}\log q_\sigma(x_\sigma)\right\Vert^2 - 2\mathbb{E}_{x_\sigma} \left[s_\theta(x_\sigma)^\top \nabla_{x_\sigma}\log q_\sigma(x_\sigma)\right]\right) L D S M ( θ ; σ ) = 2 1 ( E x σ ∥ s θ ( x σ ) ∥ 2 + E x σ ∥ ∇ x σ log q σ ( x σ ) ∥ 2 − 2 E x σ [ s θ ( x σ ) ⊤ ∇ x σ log q σ ( x σ ) ] ) 再定义条件去噪分数匹配的目标函数为
L C D S M ( θ ; σ ) = E z ∼ p d a t a ( z ) , x σ ∼ q σ ( x σ ∣ z ) [ 1 2 ∥ s θ ( x σ ) − ∇ x σ log q σ ( x σ ∣ z ) ∥ 2 ] \mathcal{L}_{CDSM}(\theta; \sigma) = \mathbb{E}_{z\sim p_{data}(z), x_\sigma \sim q_\sigma(x_\sigma | z)}\left[\frac{1}{2}\left\Vert s_\theta(x_\sigma) - \nabla_{x_\sigma}\log q_\sigma(x_\sigma | z) \right\Vert^2 \right] L C D S M ( θ ; σ ) = E z ∼ p d a t a ( z ) , x σ ∼ q σ ( x σ ∣ z ) [ 2 1 ∥ s θ ( x σ ) − ∇ x σ log q σ ( x σ ∣ z ) ∥ 2 ] 展开同样可以得到
L C D S M ( θ ; σ ) = E x σ ∼ q σ ( x σ ) , z ∼ p ( z ∣ x σ ) [ 1 2 ∥ s θ ( x σ ) − ∇ x σ log q σ ( x σ ∣ z ) ∥ 2 ] = 1 2 ( E x σ ∥ s θ ( x σ ) ∥ 2 + E x σ , z ∥ ∇ x σ log q σ ( x σ ∣ z ) ∥ 2 − 2 E x σ , z [ s θ ( x σ ) ⊤ ∇ x σ log q σ ( x σ ∣ z ) ] ) \begin{aligned} \mathcal{L}_{CDSM}(\theta; \sigma) &= \mathbb{E}_{x_\sigma\sim q_\sigma(x_\sigma), z \sim p(z | x_\sigma)}\left[\frac{1}{2}\left\Vert s_\theta(x_\sigma) - \nabla_{x_\sigma}\log q_\sigma(x_\sigma | z) \right\Vert^2 \right] \\ &= \frac{1}{2}\left(\mathbb{E}_{x_\sigma}\left\Vert s_\theta(x_\sigma)\right\Vert^2 + \mathbb{E}_{x_\sigma, z}\left\Vert \nabla_{x_\sigma}\log q_\sigma(x_\sigma | z)\right\Vert^2 - 2\mathbb{E}_{x_\sigma, z} \left[s_\theta(x_\sigma)^\top \nabla_{x_\sigma}\log q_\sigma(x_\sigma | z)\right]\right) \end{aligned} L C D S M ( θ ; σ ) = E x σ ∼ q σ ( x σ ) , z ∼ p ( z ∣ x σ ) [ 2 1 ∥ s θ ( x σ ) − ∇ x σ log q σ ( x σ ∣ z ) ∥ 2 ] = 2 1 ( E x σ ∥ s θ ( x σ ) ∥ 2 + E x σ , z ∥ ∇ x σ log q σ ( x σ ∣ z ) ∥ 2 − 2 E x σ , z [ s θ ( x σ ) ⊤ ∇ x σ log q σ ( x σ ∣ z ) ] ) 因为
E x σ , z [ s θ ( x σ ) ⊤ ∇ x σ log q σ ( x σ ∣ z ) ] = E x σ [ s θ ( x σ ) ⊤ E z ∼ p ( z ∣ x σ ) [ ∇ x σ log q σ ( x σ ∣ z ) ] ] = E x σ [ s θ ( x σ ) ⊤ ∇ x σ log q σ ( x σ ) ] \mathbb{E}_{x_\sigma, z} \left[s_\theta(x_\sigma)^\top \nabla_{x_\sigma}\log q_\sigma(x_\sigma | z)\right] = \mathbb{E}_{x_\sigma} \left[s_\theta(x_\sigma)^\top \mathbb{E}_{z \sim p(z | x_\sigma)}\left[\nabla_{x_\sigma}\log q_\sigma(x_\sigma | z)\right]\right] = \mathbb{E}_{x_\sigma}\left[s_\theta(x_\sigma)^\top \nabla_{x_\sigma}\log q_\sigma(x_\sigma)\right] E x σ , z [ s θ ( x σ ) ⊤ ∇ x σ log q σ ( x σ ∣ z ) ] = E x σ [ s θ ( x σ ) ⊤ E z ∼ p ( z ∣ x σ ) [ ∇ x σ log q σ ( x σ ∣ z ) ] ] = E x σ [ s θ ( x σ ) ⊤ ∇ x σ log q σ ( x σ ) ] 且
E x σ , z ∥ ∇ x σ log q σ ( x σ ∣ z ) ∥ 2 − E x σ ∥ ∇ x σ log q σ ( x σ ) ∥ 2 = const \mathbb{E}_{x_\sigma, z}\left\Vert \nabla_{x_\sigma}\log q_\sigma(x_\sigma | z)\right\Vert^2 - \mathbb{E}_{x_\sigma}\left\Vert \nabla_{x_\sigma}\log q_\sigma(x_\sigma)\right\Vert^2 = \text{const} E x σ , z ∥ ∇ x σ log q σ ( x σ ∣ z ) ∥ 2 − E x σ ∥ ∇ x σ log q σ ( x σ ) ∥ 2 = const 由此可知,去噪分数匹配的目标函数 L D S M ( θ ; σ ) \mathcal{L}_{DSM}(\theta; \sigma) L D S M ( θ ; σ ) 和条件去噪分数匹配的目标函数 L C D S M ( θ ; σ ) \mathcal{L}_{CDSM}(\theta; \sigma) L C D S M ( θ ; σ ) 在优化上是等价的
L D S M ( θ ; σ ) = L C D S M ( θ ; σ ) + const \mathcal{L}_{DSM}(\theta; \sigma) = \mathcal{L}_{CDSM}(\theta; \sigma) + \text{const} L D S M ( θ ; σ ) = L C D S M ( θ ; σ ) + const 因此,我们可以选择更易计算的条件去噪分数匹配目标函数 L C D S M ( θ ; σ ) \mathcal{L}_{CDSM}(\theta; \sigma) L C D S M ( θ ; σ ) 进行优化。
(iv) 去噪分数匹配的渐进收敛性
当噪声水平 σ → 0 \sigma \to 0 σ → 0 时,带噪数据分布 q σ ( x σ ) q_\sigma(x_\sigma) q σ ( x σ ) 渐进收敛到真实数据分布 p d a t a ( x ) p_{data}(x) p d a t a ( x ) 。此时,去噪分数匹配目标函数 L D S M ( θ ; σ ) \mathcal{L}_{DSM}(\theta; \sigma) L D S M ( θ ; σ ) 也渐进收敛到原始的分数匹配目标函数 L S M ( θ ) \mathcal{L}_{SM}(\theta) L S M ( θ ) 。由此带来的是,含噪分布的分数函数 s θ ( x σ ) s_\theta(x_\sigma) s θ ( x σ ) 也会逐渐逼近真实数据分布的分数函数 ∇ x σ log p d a t a ( x σ ) \nabla_{x_\sigma}\log p_{data}(x_\sigma) ∇ x σ log p d a t a ( x σ )
lim σ → 0 L D S M ( θ ; σ ) = L S M ( θ ) + const \lim_{\sigma \to 0} \mathcal{L}_{DSM}(\theta; \sigma) = \mathcal{L}_{SM}(\theta) + \text{const} σ → 0 lim L D S M ( θ ; σ ) = L S M ( θ ) + const lim σ → 0 s θ ( x σ ) = ∇ x log p d a t a ( x ) \lim_{\sigma \to 0} s_\theta(x_\sigma) = \nabla_{x}\log p_{data}(x) σ → 0 lim s θ ( x σ ) = ∇ x log p d a t a ( x ) 因此用去噪分数匹配训练得到的模型 s θ ( x σ ) s_\theta(x_\sigma) s θ ( x σ ) ,在 σ \sigma σ 足够小时,可以很好地近似真实数据分布的分数函数。
(v) 最终的目标函数
在实际应用中,我们通常会在多个噪声水平 { σ i } i = 1 L \{\sigma_i\}_{i=1}^L { σ i } i = 1 L 下训练模型,一方面是为了增强模型的鲁棒性和泛化能力,另一方面是为了实现渐进性。经过适当计算,可以得到
∇ x σ log q σ ( x σ ∣ z ) = z − x σ σ 2 = − ϵ σ \nabla_{x_\sigma} \log q_\sigma(x_\sigma | z) = \frac{z - x_\sigma}{\sigma^2} = -\frac{\epsilon}{\sigma} ∇ x σ log q σ ( x σ ∣ z ) = σ 2 z − x σ = − σ ϵ 为了便于模型学习,将 σ \sigma σ 作为条件输入给神经网络。于是,最终的目标函数为
L D S M ′ ( θ ) = E σ ∼ p ( σ ) , z ∼ p d a t a ( z ) , x σ ∼ q σ ( x σ ∣ z ) [ λ ( σ ) ∥ s θ ( x σ , σ ) − ∇ x σ log q σ ( x σ ∣ z ) ∥ 2 ] = E σ ∼ p ( σ ) , z ∼ p d a t a ( z ) , ϵ ∼ N ( 0 , I d ) [ λ ( σ ) ∥ s θ ( z + σ ϵ , σ ) + ϵ σ ∥ 2 ] \boxed{ \begin{aligned} \mathcal{L}_{DSM}'(\theta) &= \mathbb{E}_{\sigma \sim p(\sigma), z \sim p_{data}(z), x_\sigma \sim q_\sigma(x_\sigma | z)}\left[\lambda(\sigma) \left\Vert s_\theta(x_\sigma, \sigma) - \nabla_{x_\sigma}\log q_\sigma(x_\sigma | z) \right\Vert^2\right] \\ &= \mathbb{E}_{\sigma \sim p(\sigma), z \sim p_{data}(z), \epsilon \sim \mathcal{N}(0, I_d)}\left[\lambda(\sigma)\left\Vert s_\theta(z + \sigma\epsilon, \sigma) + \frac{\epsilon}{\sigma}\right\Vert^2\right] \end{aligned} } L D S M ′ ( θ ) = E σ ∼ p ( σ ) , z ∼ p d a t a ( z ) , x σ ∼ q σ ( x σ ∣ z ) [ λ ( σ ) ∥ s θ ( x σ , σ ) − ∇ x σ log q σ ( x σ ∣ z ) ∥ 2 ] = E σ ∼ p ( σ ) , z ∼ p d a t a ( z ) , ϵ ∼ N ( 0 , I d ) [ λ ( σ ) s θ ( z + σ ϵ , σ ) + σ ϵ 2 ] 其中 λ ( σ ) \lambda(\sigma) λ ( σ ) 是噪声水平 σ \sigma σ 的权重函数,通常取 λ ( σ ) = σ 2 \lambda(\sigma) = \sigma^2 λ ( σ ) = σ 2 。
4.1.4 训练过程与推理过程# 训练过程
SMLD 的训练过程旨在学习一个能够估计不同噪声水平下数据分布分数函数的神经网络 s θ ( x , σ ) s_\theta(x, \sigma) s θ ( x , σ ) 。
数据与噪声 :从真实数据分布 p d a t a ( z ) p_{data}(z) p d a t a ( z ) 中采样一个批次的数据 { z i } \{z_i\} { z i } 。 定义一个噪声水平的几何序列 { σ i } i = 1 L \{\sigma_i\}_{i=1}^L { σ i } i = 1 L ,其中 0 < σ L < σ L − 1 < ⋯ < σ 2 < σ 1 0 < \sigma_L < \sigma_{L-1} < \dots < \sigma_2 < \sigma_1 0 < σ L < σ L − 1 < ⋯ < σ 2 < σ 1 。σ 1 \sigma_1 σ 1 足够大,使得加噪后的数据分布接近高斯分布;σ L \sigma_L σ L 足够小,使得加噪后的数据分布接近真实数据分布。 加噪与目标计算 :对于每个数据点 z z z ,从 { σ i } i = 1 L \{\sigma_i\}_{i=1}^L { σ i } i = 1 L 中随机选择一个噪声水平 σ \sigma σ 。 从标准正态分布 N ( 0 , I d ) \mathcal{N}(0, I_d) N ( 0 , I d ) 中采样一个噪声向量 ϵ \epsilon ϵ 。 生成带噪数据 x σ = z + σ ϵ x_\sigma = z + \sigma\epsilon x σ = z + σ ϵ 。 此时,我们知道真实的分数是 ∇ x σ log q σ ( x σ ∣ z ) = − ϵ σ \nabla_{x_\sigma}\log q_\sigma(x_\sigma | z) = -\frac{\epsilon}{\sigma} ∇ x σ log q σ ( x σ ∣ z ) = − σ ϵ 。 模型优化 :将带噪数据 x σ x_\sigma x σ 和噪声水平 σ \sigma σ 输入到分数模型 s θ ( x σ , σ ) s_\theta(x_\sigma, \sigma) s θ ( x σ , σ ) 中。 计算损失函数,即模型预测分数与真实分数之间的加权均方误差: L D S M ′ ( θ ) = 1 L ∑ i = 1 L σ i 2 E z ∼ p d a t a ( z ) , ϵ ∼ N ( 0 , I d ) [ ∥ s θ ( z + σ i ϵ , σ i ) + ϵ σ i ∥ 2 ] \mathcal{L}_{DSM}'(\theta) = \frac{1}{L}\sum_{i=1}^L \sigma_i^2 \mathbb{E}_{z \sim p_{data}(z), \epsilon \sim \mathcal{N}(0, I_d)}\left[\left\Vert s_\theta(z + \sigma_i\epsilon, \sigma_i) + \frac{\epsilon}{\sigma_i}\right\Vert^2\right] L D S M ′ ( θ ) = L 1 i = 1 ∑ L σ i 2 E z ∼ p d a t a ( z ) , ϵ ∼ N ( 0 , I d ) [ s θ ( z + σ i ϵ , σ i ) + σ i ϵ 2 ] 更新模型参数 θ \theta θ 以最小化该损失。 推理过程 (退火 Langevin 动力学采样)
训练完成后,模型 s θ ( x , σ ) s_\theta(x, \sigma) s θ ( x , σ ) 可以用来生成新样本。这个过程通过模拟一个从高噪声状态向低噪声状态“退火”的 Langevin 动力学过程来实现。
初始化 :从一个简单的先验分布(如均匀分布或高斯分布)中采样初始样本 x L x_L x L 。这个初始样本可以看作是处在最高噪声水平 σ 1 \sigma_1 σ 1 下的随机状态。退火采样循环 :按照预设的噪声水平序列 { σ 1 , σ 2 , … , σ L } \{\sigma_1, \sigma_2, \ldots, \sigma_L\} { σ 1 , σ 2 , … , σ L } 从高噪声水平到低噪声水平进行迭代。对于每一个噪声水平 σ i \sigma_i σ i :Langevin 动力学更新 :执行 K K K 步 Langevin MCMC 更新。在第 k k k 步 (k = 1 , … , K k=1, \dots, K k = 1 , … , K ): x i ( k ) = x i ( k − 1 ) + α i s θ ( x i ( k − 1 ) , σ i ) + 2 α i ξ k x_{i}^{(k)} = x_{i}^{(k-1)} + \alpha_i s_\theta(x_{i}^{(k-1)}, \sigma_i) + \sqrt{2\alpha_i} \xi_k x i ( k ) = x i ( k − 1 ) + α i s θ ( x i ( k − 1 ) , σ i ) + 2 α i ξ k 其中:x i ( k − 1 ) x_{i}^{(k-1)} x i ( k − 1 ) 是上一步的样本(对于 k = 1 k=1 k = 1 ,它是在上一个噪声水平 σ i − 1 \sigma_{i-1} σ i − 1 结束时得到的样本,即 x i ( 0 ) = x i − 1 ( K ) x_{i}^{(0)} = x_{i-1}^{(K)} x i ( 0 ) = x i − 1 ( K ) )。s θ ( x i ( k − 1 ) , σ i ) s_\theta(x_{i}^{(k-1)}, \sigma_i) s θ ( x i ( k − 1 ) , σ i ) 是模型在当前样本和噪声水平下预测的分数。α i \alpha_i α i 是与 σ i \sigma_i σ i 相关的步长,通常设置为 α i ∝ σ i 2 \alpha_i \propto \sigma_i^2 α i ∝ σ i 2 。ξ k ∼ N ( 0 , I d ) \xi_k \sim \mathcal{N}(0, I_d) ξ k ∼ N ( 0 , I d ) 是一个随机高斯噪声,用于维持随机性。传递样本 :完成 K K K 步更新后,将最终得到的样本 x K x_K x K 作为下一个更低噪声水平 σ i + 1 \sigma_{i+1} σ i + 1 的初始样本。输出样本 :当完成所有噪声水平的迭代后(即在最低噪声水平 σ L \sigma_L σ L 下完成 K K K 步更新),最终得到的样本即为生成的结果。退火过程的直观解释 :在高噪声水平下,数据分布平滑且简单,Langevin 动力学可以轻松探索整个样本空间,避免陷入局部最优。随着噪声水平的逐步降低,样本被逐渐“雕琢”和“精炼”,以匹配真实数据分布中更精细的结构和特征,最终生成高质量、高保真度的样本。