3.2.2 条件流匹配目标函数#
引入条件流匹配目标函数 (Conditional Flow Matching Objective)
LCFM(θ)=Et∼Uniform[0,1],z∼pdata(z),xt∼pt(xt∣z)[uθ(xt,t)−utarget(xt,t∣z)2]可以证明,条件流匹配目标函数 LCFM(θ) 与边缘流匹配目标函数 LFM(θ) 在最优解上是等价的。这意味着,通过最小化 LCFM(θ),我们可以找到一个神经网络 uθ(xt,t),它能够有效地逼近真实的边缘速度场 utarget(xt,t)。
证明: 首先,我们将 LCFM(θ) 的期望形式展开
LCFM(θ)=Et∼Uniform[0,1],z∼pdata(z),xt∼pt(xt∣z)[uθ(xt,t)−utarget(xt,t∣z)2]=∫01∫pdata(z)∫pt(xt∣z)uθ(xt,t)−utarget(xt,t∣z)2dxtdzdt根据贝叶斯定理,有 pt(xt,z)=pt(xt∣z)pdata(z)=p(z∣xt)pt(xt)。代入上式可得
LCFM(θ)=∫01∫pt(xt)∫p(z∣xt)uθ(xt,t)−utarget(xt,t∣z)2dzdxtdt=∫01∫pt(xt)Ez∼p(z∣xt)[uθ(xt,t)−utarget(xt,t∣z)2]dxtdt=Et∼Uniform[0,1],xt∼pt(xt)[Ez∼p(z∣xt)[uθ(xt,t)−utarget(xt,t∣z)2]]对于内层的期望,我们可以将其展开
= = = = Ez∼p(z∣xt)[uθ(xt,t)−utarget(xt,t∣z)2]Ez∼p(z∣xt)[(uθ(xt,t)−Ez∼p(z∣xt)[utarget(xt,t∣z)])+(Ez∼p(z∣xt)[utarget(xt,t∣z)]−utarget(xt,t∣z))2]uθ(xt,t)−Ez∼p(z∣xt)[utarget(xt,t∣z)]2+Ez∼p(z∣xt)[Ez∼p(z∣xt)[utarget(xt,t∣z)]−utarget(xt,t∣z)2]+2(uθ(xt,t)−Ez∼p(z∣xt)[utarget(xt,t∣z)])⊤0Ez∼p(z∣xt)[(Ez∼p(z∣xt)[utarget(xt,t∣z)]−utarget(xt,t∣z))]uθ(xt,t)−Ez∼p(z∣xt)[utarget(xt,t∣z)]2+Varz∼p(z∣xt)(utarget(xt,t∣z))uθ(xt,t)−utarget(xt,t)2+Varz∼p(z∣xt)(utarget(xt,t∣z))其中,uθ(xt,t) 和 Ez∼p(z∣xt)[utarget(xt,t∣z)] 相对于 z 是常数,因此第一项前面的期望符号去掉。此外,根据边缘性定理,utarget(xt,t)=Ez∼p(z∣xt)[utarget(xt,t∣z)],替换 Ez∼p(z∣xt)[utarget(xt,t∣z)]。因此,LCFM(θ) 可以重写为
LCFM(θ)=LFM(θ)Et∼Uniform[0,1],xt∼pt(xt)[uθ(xt,t)−utarget(xt,t)2]+与θ无关的项Et∼Uniform[0,1],xt∼pt(xt)[Varz∼p(z∣xt)(utarget(xt,t∣z))]上式表明,LCFM(θ) 等于 LFM(θ) 加上一个与模型参数 θ 无关的方差项。因此,最小化 LCFM(θ) 等价于最小化 LFM(θ)。
这个转换的优势在于,LCFM 的计算不需要边缘分布 pt(xt) 或边缘速度场 utarget(xt,t),而是依赖于条件概率路径 pt(xt∣z) 和条件速度场 utarget(xt,t∣z),这两者都可以被设计成简单的、具有解析表达式的形式,从而使得训练变得可行。
既然有了条件速度场的 utarget(xt,t∣z) 解析形式,那直接用解析形式生成数据不就行了?为什么还需要训练神经网络?
- 条件速度场需要确定的真实数据点: 条件速度场 utarget(xt,t∣z) 是基于特定的真实数据点 z 定义的。它的作用相当于是为数据的演化提供了一个方向信息。然而,推理过程中,我们不会有真实的数据点 z 可供使用,因此无法直接利用条件速度场来生成数据。
- utarget(xt,t)=Ez∼p(z∣xt)[utarget(xt,t∣z)] 的一种加权理解: 边缘速度场实际上相当于是对所有可能的真实数据点 z 的条件速度场的加权平均。这个加权过程隐含了对数据分布的整体理解,而不仅仅是单个数据点的演化方向。让 uθ(xt,t) 去拟合这个加权平均,可以让模型学会把握数据的整体模式。
3.2.3 条件概率路径与条件速度场#
为什么需要条件概率路径?
直接用设置有简单闭式表达式的边际概率路径 (marginal probability path) {pt(x)}t∈[0,1] 训练其对应的速度场 vθ(x,t) 是非常困难的。为了解决这个问题,流匹配模型引入了条件概率路径 (conditional probability path) 的概念。其核心思想是,不直接对整个分布的演化进行建模,而是为每一对噪声样本 x0∼p0(x) 和数据样本 x1∼p1(x) 定义一个简单的、确定的路径。
条件概率的边界条件
设 z∼pdata(z) 来自于真实数据分布。为了实现 p0(x0)=pprior(x0),p1(x1)=pdata(x1),条件概率族 {pt(xt∣z)}t∈[0,1] 同样需要满足一定的边界条件。
- 先验条件分布:p0(x0∣z)=p0(x0),即 x0 与 z 独立
- 数据条件分布:p1(x1∣z)=N(x1∣z,σ12Id),其中 σ1 是很小的正数,用于将分布紧紧包围在数据点 z 附近。事实上,当 σ1→0 时,就是 Dirac Delta 分布 p1(x1∣z)=δz(x1)={∞,0,x1=zx1=z
条件概率路径的设计
为了便于实现,将条件概率 pt(xt∣z) 设置为一个简单的、具有闭式表达式的路径族
pt(xt∣z)=N(xt∣μt(z),σt2Id)这个形式就将条件概率路径的设计转换成了对均值函数 μt(z) 和标准差函数 σt 的设计。现在我们只需要保证 μt(z) 和 σt 满足边界条件就可以了
μ0(z)=0,σ0=1,μ1(z)=z,σ1→0条件速度场
根据条件概率路径的定义,可以计算出对应的条件速度场。条件概率 pt(xt∣z) 和条件速度场 utarget(xt,t∣z) 满足条件 Liouville 方程:
∂t∂pt(x∣z)+∇x⋅[pt(x∣z)utarget(x,t∣z)]=0当条件概率路径被设计为高斯分布 pt(xt∣z)=N(xt∣μt(z),σt2Id) 时,可以推导出对应的条件速度场具有一个简单的解析形式:
utarget(xt,t∣z)=μt˙(z)+σtσt˙(xt−μt(z))其中 μt˙(z)=∂t∂μt(z),σt˙=dtdσt。
推导过程:
(i) 对数概率的推导
对于高斯分布 pt(xt∣z)=N(xt∣μt(z),σt2Id),其对数概率为
logpt(xt∣z)=−2σt2∥xt−μt(z)∥2−2dlog(2πσt2)对时间 t 求导,得到
∂t∂logpt(xt∣z)=σt2(xt−μt(z))⊤μ˙t(z)+σt3∥xt−μt(z)∥2σ˙t−σtσ˙td其中 μt˙(z)=∂t∂μt(z),σ˙t=dtdσt。
(ii) Liouville 方程的对数形式
另一方面,将 Liouville 方程两边除以 pt 得到
∂t∂logpt(xt∣z)=−pt(xt∣z)1∇xt⋅[pt(xt∣z)utarget(xt,t∣z)]=−[∇xtlogpt(xt∣z)]⊤utarget(xt,t∣z)−∇xt⋅utarget(xt,t∣z)其中有利用性质,∇⋅(av)=a∇⋅v+v⋅∇a。
(iii) 联立求解条件速度场
对于高斯分布
∇xtlogpt(xt∣z)=−σt2xt−μt代入上式
∂t∂logpt(xt∣z)=σt2(xt−μt(z))⊤utarget(xt,t∣z)−∇xt⋅utarget(xt,t∣z)将两个 ∂t∂logpt 的表达式联立,得到关于 utarget(xt,t∣z) 的方程
σt2(xt−μt(z))⊤μt˙(z)+σt3∥xt−μt(z)∥2σt˙−σtσt˙d=σt2(xt−μt(z))⊤utarget(xt,t∣z)−∇xt⋅utarget(xt,t∣z)可以发现当 utarget(xt,t∣z) 取以下形式时方程成立
utarget(xt,t∣z)=μt˙(z)+σtσt˙(xt−μt(z))(iv) 代入验证
∇x⋅utarget(xt,t∣z)σt2(xt−μt(z))⊤utarget=∇x⋅(μt˙(z)+σtσt˙(xt−μt(z)))=σtσt˙d=σt2(xt−μt(z))⊤(μt˙(z)+σtσt˙(xt−μt(z)))=σt2(xt−μt(z))⊤μt˙(z)+σt3∥xt−μt(z)∥2σt˙将这两项相减,正好得到 ∂t∂logpt(xt∣z) 的表达式,证明了该速度场形式的正确性。关于方程的求解,可以参考附录 C.1。
这个结果非常重要,因为它表明,我们可以设计出简单的均值函数 μt(z) 和标准差函数 σt,得到一个具有闭式解的条件速度场 utarget(xt,t∣z)。这使得条件流匹配的目标函数 LCFM(θ) 变得完全可计算。