DDPM
约 943 字大约 3 分钟
生成模型
2025-04-23
Denoising Diffusion Probabilistic Models
1. Difussion
graph LR
A[x0]-->B[x1]-.-C[xt-1]-->D[xt]-.-E[xT-1]-->F[xT]
扩散过程:
- 一个固定过程
- 扩散超参 β
xt−1→xt:
xt=1−βtxt−1+βtzt, zt∼N(0,I)
β: 10−4∼2−2, linear, T≈2000
令 1−βt=αt
xt=αtxt+1−αtzt=αt(αt−1xt−2+1−αt−1zt−1)+1−αtzt
根据高斯分布的叠加方法:
xt=αtαt−1xt−2+1−αtαt−1z, z∼N(0,I)
令 αˉt=∏i=1Tαi,得到 x0 与 xt 的关系:
xt=αˉtx0+1−αˉtz
如何取β的值?:αˉ→0, xT→N(0,I)
train a model to fit noise in each steps
2. Training
假设batch_size = 4,T = 2000
def diffusion_loss_fn(model,x_0,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps):
"""对任意时刻t进行采样计算loss"""
batch_size = x_0.shape[0]
#对一个batchsize样本生成随机的时刻t
t = torch.randint(0,n_steps,size=(batch_size//2,)).to(device)
t = torch.cat([t,n_steps-1-t],dim=0)
t = t.unsqueeze(-1)
#x0的系数
a = alphas_bar_sqrt[t]
#eps的系数
aml = one_minus_alphas_bar_sqrt[t]
#生成随机噪音eps
e = torch.randn_like(x_0).to(device)
#构造模型的输入
x = x_0 * a + e * aml
#送入模型,得到t时刻的随机噪声预测值
output = model(x,t.squeeze(-1))
#与真实噪声一起计算误差,求平均值
return (e - output).square().mean()
3. Sampling
根据贝叶斯理论:
q(xt−1∣xt)=q(xt)q(xt,xt−1)=q(xt)q(xt∣xt−1)q(xt−1)
目前已知的有:xt−1→xt,x0→xt
xtxt=αtxt−1+1−αtz∼N(αtxt−1,(1−αt)I)⟶q(xt∣xt−1)=αˉtx0+1−αˉtz∼N(αˉtx0,(1−αˉt)I)⟶q(xt)
q(xt∣xt−1)q(xt)q(xt−1)∼N(αtxt−1,(1−αt)I)∼N(αˉtx0,(1−αˉt)I)∼N(αˉt−1x0,(1−αˉt−1)I)
又因为,高斯分布的密度函数可以写成:N(μ,σ2)∝exp(−21σ2(x−μ)2)
q(xt)q(xt∣xt−1)q(xt−1)∝exp[−21(1−αt(xt−αtxt−1)2+1−αˉt−1(xt−1−αˉt−1x0)2−1−αˉt(xt−αˉtx0)2)]∝exp{−21[(βtαt+1−αˉt−11)xt−12−2(βtαtxt+1−αˉt−1αˉt−1x0)xt−1]+ ...}∝exp[−21(Axt−12+Bxt+C)]∝exp[−21A(xt−1+2AB)2+ ...]
提取系数,A=βtαt+1−αˉt−11,B=−2(βtαtxt+1−αˉt−1αˉt−1x0)
所以,μ=2AB,σ2=A1
σ2=A1=βtαt+1−αˉt−111=1−αtαˉt−1βt(1−αˉt−1)=1−αˉt1−αˉt−1βt
μ=2AB=(βtαtxt+1−αˉt−1αˉt−1x0)1−αˉt1−αˉt−1βt=1−αˉtαtxt(1−αˉt−1)+1−αˉtαˉt−1x0βt=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βtx0
因为 xt=αtxt−1+1−αtz ,x0=αˉt1(xt−1−αˉtz~),带入 μ
μ=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βtαˉt1(xt−1−αˉtz~)=αˉtxt(1−αˉtαt−αˉt+βt)+αtz~1−αˉtβt=αt1(xt−1−αˉtβtz~)