A First Look at Generative Diffusion Models: DDPM Study Notes

Forward Diffusion

Definition

Consider the following process: starting from an image $x_0$, we keep adding noise to it. (This operates simultaneously on every color channel of every pixel, with values normalized to $[-1,1]$.)

$$x_t\sim\sqrt{\alpha_t}\times x_{t-1}+\sqrt{\beta_t}\times \mathcal{N}(0,1)$$

Here $\mathcal{N}(0,1)$ denotes the normal distribution with mean $0$ and variance $1$ (the standard normal distribution), so this amounts to sampling $x_t$ from a normal distribution with mean $\sqrt{\alpha_t}\times x_{t-1}$ and variance $\beta_t$.

The $\alpha_t$ and $\beta_t$ are manually chosen hyperparameters, with $\beta_t\approx 0$ so that each step adds only a little noise.

Our goal is that after enough steps (say $T=1000$), we approximately have $x_T\sim \mathcal{N}(0,1)$. We can then draw a random $x_T$ from the standard normal distribution and walk it back step by step to obtain an $x_0$.

Note how the variance $\sigma^2_t$ evolves during forward diffusion:

$$\sigma^2_{t}=\alpha_t\times\sigma^2_{t-1}+\beta_t$$

We want $\sigma^2_T\approx 1$. Suppose $\sigma^2_0=1$ (which can be guaranteed by standardizing the dataset); in other words, we want the variance to still be $1$ after diffusion finishes. For stability throughout the process, we may as well require $\sigma_t^2=1$ for every $t$.

Substituting this in:

$$1=\alpha_t+\beta_t$$

So one generally enforces $\alpha_t+\beta_t=1$.
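To make the update concrete, here is a minimal sketch of a single forward step in plain C++ (the function name and signature are made up for these notes; the image is assumed to be flattened to a float array with values in $[-1,1]$):

#include <cmath>
#include <random>
#include <vector>

// One forward-diffusion step: x_t = sqrt(alpha_t) * x_{t-1} + sqrt(beta_t) * eps,
// applied independently to every channel value, with alpha_t = 1 - beta_t as derived above.
std::vector<float> forward_step(const std::vector<float>& x_prev, float beta_t,
                                std::mt19937& rng)
{
    std::normal_distribution<float> noise(0.0f, 1.0f);
    const float alpha_t = 1.0f - beta_t;
    std::vector<float> x_t(x_prev.size());
    for (std::size_t i = 0; i < x_prev.size(); i++)
        x_t[i] = std::sqrt(alpha_t) * x_prev[i] + std::sqrt(beta_t) * noise(rng);
    return x_t;
}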

Diffusing in One Step

Consider the relation between $x_t$ and $x_{t-2}$:

$$\begin{aligned} x_t&\sim\sqrt{\alpha_t}\times x_{t-1}+\sqrt{\beta_t}\times \mathcal{N}(0,1)\\ &\sim\sqrt{\alpha_t}\times\left(\sqrt{\alpha_{t-1}}\times x_{t-2}+\sqrt{\beta_{t-1}}\times \mathcal{N}(0,1)\right)+\sqrt{\beta_t}\times \mathcal{N}(0,1)\\ &\sim\sqrt{\alpha_t\alpha_{t-1}}\times x_{t-2}+\underbrace{\sqrt{\alpha_t\beta_{t-1}}\times \mathcal{N}(0,1)}_{\text{normal, variance }\alpha_t\beta_{t-1}}+\underbrace{\sqrt{\beta_t}\times \mathcal{N}(0,1)}_{\text{normal, variance }\beta_{t}} \end{aligned}$$

By the additivity of normal distributions, for two mutually independent normals we have $\mathcal{N}(\mu_1,\sigma^2_1)+\mathcal{N}(\mu_2,\sigma^2_2)=\mathcal{N}(\mu_1+\mu_2,\sigma^2_1+\sigma^2_2)$, so:

$$\underbrace{\sqrt{\alpha_t\beta_{t-1}}\times \mathcal{N}(0,1)}_{\text{normal, variance }\alpha_t\beta_{t-1}}+\underbrace{\sqrt{\beta_t}\times \mathcal{N}(0,1)}_{\text{normal, variance }\beta_{t}}=\underbrace{\sqrt{\alpha_t\beta_{t-1}+\beta_t}\times \mathcal{N}(0,1)}_{\text{normal, variance }\alpha_t\beta_{t-1}+\beta_{t}}$$

And since $\alpha_t+\beta_t=1$:

$$\begin{aligned} \alpha_t\beta_{t-1}+\beta_t&=\alpha_t(1-\alpha_{t-1})+1-\alpha_t\\ &=\alpha_t-\alpha_t\alpha_{t-1}+1-\alpha_t\\ &=1-\alpha_t\alpha_{t-1} \end{aligned}$$

So:

$$x_t\sim\sqrt{\alpha_t\alpha_{t-1}}\times x_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\times \mathcal{N}(0,1)$$

Continuing in the same way gives:

$$x_t\sim \sqrt{\underbrace{\alpha_t\alpha_{t-1}\dots\alpha_1}_{\text{denoted }\bar\alpha_t}}\times x_0+ \sqrt{\underbrace{1-\alpha_t\alpha_{t-1}\dots\alpha_1}_{\text{denoted }\bar\beta_t}}\times\mathcal{N}(0,1)$$
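This closed form is what makes training cheap: $x_t$ can be produced from $x_0$ in one shot. A minimal sketch, again with illustrative names, where `beta` holds the noise schedule $\beta_1,\dots,\beta_T$:

#include <cmath>
#include <random>
#include <vector>

// One-shot forward formula: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps.
std::vector<float> diffuse_to(const std::vector<float>& x0, int t,
                              const std::vector<float>& beta, std::mt19937& rng)
{
    float abar = 1.0f; // abar_t = product of alpha_i for i <= t
    for (int i = 1; i <= t; i++) abar *= 1.0f - beta[i];
    std::normal_distribution<float> noise(0.0f, 1.0f);
    std::vector<float> xt(x0.size());
    for (std::size_t j = 0; j < x0.size(); j++)
        xt[j] = std::sqrt(abar) * x0[j] + std::sqrt(1.0f - abar) * noise(rng);
    return xt;
}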

Reverse Generation

What to Predict

To reverse the process and generate, we need to predict the noise $\epsilon_t\sim \mathcal{N}(0,1)$ added at each step, so that $x_{t-1}$ can be recovered from $x_t$:

$$x_{t}=\sqrt{\alpha_t}\times x_{t-1}+\sqrt{\beta_t}\times \epsilon \quad\Longrightarrow\quad x_{t-1}=\frac{1}{\sqrt{\alpha_t}}\times \left(x_{t}-\sqrt{\beta_t}\times\epsilon\right)$$

Compare the following two expressions:

$$x_t\sim \sqrt{\alpha_t}\times \left(\sqrt{\bar\alpha_{t-1}}\times x_0+\sqrt{\bar\beta_{t-1}}\times \mathcal{N}(0,1)\right) + \sqrt{\beta_t}\times \epsilon_1$$
$$x_t\sim \sqrt{\bar\alpha_t}\times x_0+\sqrt{\bar\beta_t}\times \epsilon_2$$

The first corresponds to training the model to fit the noise added at each individual step; the second corresponds to fitting the one-shot noise.

Notice that training against the first requires sampling $x_0$, $t$ and two $\mathcal{N}(0,1)$ variables each time, one more random variable than the second, so it converges less easily and is more expensive to train.

So we may as well train the model to fit the one-shot noise $\epsilon_2$ directly; each training step is then (a minimal sketch follows the list):

  • Sample $x_0$ and $t$, and draw an $\epsilon$ from $\mathcal{N}(0,1)$;
  • Compute $x_t=\sqrt{\bar\alpha_t}\times x_0+\sqrt{\bar\beta_t}\times \epsilon$;
  • Backpropagate on the difference between the model output $M(x_t,t)$ and $\epsilon$, and update the parameters.
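A hedged skeleton of this loop, where `Model` and `LossFn` are placeholders (any callable where `M(xt, t)` returns the predicted noise, and `mse_backprop` stands in for the MSE loss, backpropagation, and optimizer step); neither is a real API from the demo below:

#include <cmath>
#include <random>
#include <vector>

template <class Model, class LossFn>
void train_step(Model& M, LossFn mse_backprop,
                const std::vector<std::vector<float>>& dataset,
                const std::vector<float>& abar, int T, std::mt19937& rng)
{
    std::uniform_int_distribution<int> pick_img(0, (int)dataset.size() - 1);
    std::uniform_int_distribution<int> pick_t(1, T);
    std::normal_distribution<float> noise(0.0f, 1.0f);

    const std::vector<float>& x0 = dataset[pick_img(rng)]; // sample x_0
    const int t = pick_t(rng);                             // sample t

    std::vector<float> eps(x0.size()), xt(x0.size());
    for (std::size_t j = 0; j < x0.size(); j++)
    {
        eps[j] = noise(rng);                               // sample eps ~ N(0,1)
        xt[j] = std::sqrt(abar[t]) * x0[j]                 // one-shot x_t
              + std::sqrt(1.0f - abar[t]) * eps[j];
    }
    mse_backprop(M(xt, t), eps); // regress the model output onto eps
}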

Then, from $x_t\sim \sqrt{\bar\alpha_t}\times x_0+\sqrt{\bar\beta_t}\times \epsilon$, we know:

$$x_0=\frac{1}{\sqrt{\bar\alpha_t}}\times\left(x_t-\sqrt{\bar\beta_t}\times M(x_t,t)\right)$$

But jumping straight to a predicted $x_0$ in a single step is unrealistic. Below we derive how to use the output of $M(x_t,t)$ to undo the diffusion one step at a time.

Stepwise Reversal

Define:

  • $P(x_t|x_0)$: the probability that the $t$-th diffusion step yields $x_t$, given that the original image is $x_0$;
  • $P(x_{t-1},x_t|x_0)$: the probability that step $t-1$ yields $x_{t-1}$ and step $t$ yields $x_t$, given that the original image is $x_0$;
  • $P(x_t|x_{t-1},x_0)$: the probability that step $t$ yields $x_t$, given that step $t-1$ yielded $x_{t-1}$ and the original image is $x_0$; $P(x_{t-1}|x_{t},x_0)$ is defined analogously.

Once we have $P(x_{t-1}|x_{t},x_0)$, we can plug in $x_t$ together with the model's predicted $x_0$ and sample $x_{t-1}$ directly from this probability density over $x_{t-1}$.

By Bayes' rule:

$$P(x_{t-1}|x_{t},x_0)=\frac{P(x_{t-1},x_t|x_0)}{P(x_{t}|x_0)}=\frac{P(x_t|x_{t-1},x_0)\times P(x_{t-1}|x_0)}{P(x_{t}|x_0)}$$

From the earlier definitions:

  • $P(x_t|x_{t-1},x_0)$: $x_t\sim \mathcal{N}(\sqrt{\alpha_t}x_{t-1},\beta_t)$;
  • $P(x_{t-1}|x_0)$: $x_{t-1}\sim\mathcal{N}(\sqrt{\bar\alpha_{t-1}}x_{0},\bar\beta_{t-1})$;
  • $P(x_{t}|x_0)$: $x_{t}\sim\mathcal{N}(\sqrt{\bar\alpha_t}x_{0},\bar\beta_t)$.

Since the probability density of $\mathcal{N}(\mu,\sigma^2)$ at $x$ is proportional to $\exp\left(-\frac{1}{2}\times \frac{(x-\mu)^2}{\sigma^2}\right)$ (the normalizing constants cancel out here), $P(x_{t-1}|x_{t},x_0)$ is proportional to:

$$\exp\left(-\frac{1}{2}\times \left(\frac{\left(x_t-\sqrt{\alpha_t}x_{t-1}\right)^2}{\beta_t}+\frac{\left(x_{t-1}-\sqrt{\bar\alpha_{t-1}}x_0\right)^2}{\bar\beta_{t-1}}-\frac{\left(x_{t}-\sqrt{\bar\alpha_{t}}x_0\right)^2}{\bar\beta_t}\right)\right)$$

12×()-\frac{1}{2}\times(\dots) 里的东西变形,注意到我们只关心和 xt1x_{t-1} 有关的项(常数不影响后续的配方):

$$\begin{gathered} \frac{\left(x_t-\sqrt{\alpha_t}x_{t-1}\right)^2}{\beta_t}+\frac{\left(x_{t-1}-\sqrt{\bar\alpha_{t-1}}x_0\right)^2}{\bar\beta_{t-1}}-\frac{\left(x_{t}-\sqrt{\bar\alpha_{t}}x_0\right)^2}{\bar\beta_t}\\ \frac{x_t^2-2\sqrt{\alpha_t}x_tx_{t-1}+\alpha_tx_{t-1}^2}{\beta_t}+\frac{x_{t-1}^2-2\sqrt{\bar\alpha_{t-1}}x_{t-1}x_0+\bar\alpha_{t-1}x_0^2}{\bar\beta_{t-1}}\\ \frac{-2\sqrt{\alpha_t}x_t}{\beta_t}\times x_{t-1}+\frac{\alpha_t}{\beta_t}\times x_{t-1}^2+\frac{1}{\bar\beta_{t-1}}\times x_{t-1}^2+\frac{-2\sqrt{\bar\alpha_{t-1}}x_0}{\bar\beta_{t-1}}\times x_{t-1}\\ \left(\frac{\alpha_t}{\beta_t}+\frac{1}{\bar\beta_{t-1}}\right)\times x_{t-1}^2-2\left(\frac{\sqrt{\alpha_t}x_t}{\beta_t}+\frac{\sqrt{\bar\alpha_{t-1}}x_0}{\bar\beta_{t-1}}\right)\times x_{t-1} \end{gathered}$$

Next we complete the square to bring this into the form $\frac{(x_{t-1}-\mu)^2}{\sigma^2}$. Expanding that target form:

$$\begin{gathered} \frac{(x_{t-1}-\mu)^2}{\sigma^2}=\frac{x_{t-1}^2-2\mu x_{t-1}+\mu^2}{\sigma^2}\\ \frac{1}{\sigma^2}\times x_{t-1}^2-2\frac{\mu}{\sigma^2}\times x_{t-1} \end{gathered}$$

Matching coefficients:

$$\begin{aligned} \sigma^2&=\frac{1}{\frac{\alpha_t}{\beta_t}+\frac{1}{\bar\beta_{t-1}}}\\ &=\frac{\beta_t\bar\beta_{t-1}}{\alpha_t\bar\beta_{t-1}+\beta_t}\\ &=\frac{\beta_t\bar\beta_{t-1}}{\alpha_t-\bar\alpha_{t}+1-\alpha_t}\\ &=\frac{\beta_t\bar\beta_{t-1}}{\bar\beta_t}\\[1ex] \mu&=\left(\frac{\sqrt{\alpha_t}x_t}{\beta_t}+\frac{\sqrt{\bar\alpha_{t-1}}x_0}{\bar\beta_{t-1}}\right)\times \sigma^2\\ &=\frac{\bar\beta_{t-1}\sqrt{\alpha_t}x_t+\beta_t\sqrt{\bar\alpha_{t-1}}x_0}{\beta_t\bar\beta_{t-1}}\times \frac{\beta_t\bar\beta_{t-1}}{\bar\beta_t}\\ &=\frac{\bar\beta_{t-1}\sqrt{\alpha_t}x_t+\beta_t\sqrt{\bar\alpha_{t-1}}x_0}{\bar\beta_t} \end{aligned}$$

Note that the variance does not depend on the model's prediction at all, which is why some say DDPM is really predicting the mean of each reverse step.

Substituting the model's prediction for $x_0$:

$$\begin{aligned} \mu&=\frac{\bar\beta_{t-1}\sqrt{\alpha_t}x_t+\beta_t\sqrt{\bar\alpha_{t-1}}x_0}{\bar\beta_t}\\ &=\frac{\bar\beta_{t-1}\sqrt{\alpha_t}x_t+\beta_t\sqrt{\bar\alpha_{t-1}}\frac{1}{\sqrt{\bar\alpha_t}}\left(x_t-\sqrt{\bar\beta_t}M(x_t,t)\right)}{\bar\beta_t}\\ &=\frac{\bar\beta_{t-1}\sqrt{\alpha_t}x_t+\frac{\beta_t\sqrt{\bar\alpha_{t-1}}}{\sqrt{\bar\alpha_t}}x_t-\frac{\beta_t\sqrt{\bar\alpha_{t-1}}\sqrt{\bar\beta_t}}{\sqrt{\bar\alpha_t}}M(x_t,t)}{\bar\beta_t}\\ &=\frac{\bar\beta_{t-1}\sqrt{\alpha_t}\sqrt{\bar\alpha_t}+\beta_t\sqrt{\bar\alpha_{t-1}}}{\sqrt{\bar\alpha_t}\bar\beta_t}x_t-\frac{\beta_t\sqrt{\bar\alpha_{t-1}}\sqrt{\bar\beta_t}}{\sqrt{\bar\alpha_t}\bar\beta_t}M(x_t,t)\\ &=\frac{\left(\bar\beta_{t-1}\alpha_t+\beta_t\right)\sqrt{\bar\alpha_{t-1}}}{\sqrt{\bar\alpha_t}\bar\beta_t}x_t-\frac{\beta_t}{\sqrt{\alpha_t}\sqrt{\bar\beta_t}}M(x_t,t)\\ &=\frac{\left(\alpha_t-\bar\alpha_t+1-\alpha_t\right)\sqrt{\bar\alpha_{t-1}}}{\sqrt{\bar\alpha_t}\bar\beta_t}x_t-\frac{\beta_t}{\sqrt{\alpha_t}\sqrt{\bar\beta_t}}M(x_t,t)\\ &=\frac{1}{\sqrt{\alpha_t}}x_t-\frac{\beta_t}{\sqrt{\alpha_t}\sqrt{\bar\beta_t}}M(x_t,t)\\ &=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{\bar\beta_t}}M(x_t,t)\right) \end{aligned}$$

This gives the sampling formula for each step of the reverse process:

$$x_{t-1}\sim \frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{\bar\beta_t}}M(x_t,t)\right)+\sqrt{\frac{\beta_t\bar\beta_{t-1}}{\bar\beta_t}}\times\mathcal{N}(0,1)$$

β\betaβˉ\bar\beta 展开就得到了原论文中的形式:

$$x_{t-1}\sim \frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}M(x_t,t)\right)+\sqrt{\beta_t\times \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_{t}}}\times\mathcal{N}(0,1)$$
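Putting the pieces together, here is a minimal sampling-loop sketch (the model is again a placeholder `M`, `beta`/`abar` hold the $\beta_t$/$\bar\alpha_t$ schedules, and the noise term is dropped at $t=1$):

#include <cmath>
#include <random>
#include <vector>

// Reverse process: start from pure noise and apply the per-step formula above.
template <class Model>
std::vector<float> sample(Model& M, int T, const std::vector<float>& beta,
                          const std::vector<float>& abar, std::size_t dim,
                          std::mt19937& rng)
{
    std::normal_distribution<float> noise(0.0f, 1.0f);
    std::vector<float> x(dim);
    for (auto& v : x) v = noise(rng); // x_T ~ N(0, 1)
    for (int t = T; t >= 1; t--)
    {
        std::vector<float> eps = M(x, t); // predicted one-shot noise
        float sigma = (t > 1)
            ? std::sqrt(beta[t] * (1.0f - abar[t - 1]) / (1.0f - abar[t]))
            : 0.0f;
        for (std::size_t j = 0; j < dim; j++)
        {
            float mu = (x[j] - beta[t] / std::sqrt(1.0f - abar[t]) * eps[j])
                       / std::sqrt(1.0f - beta[t]); // 1/sqrt(alpha_t) * (...)
            x[j] = mu + sigma * noise(rng);
        }
    }
    return x;
}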

Implementation Details

Hyperparameter Settings

Linear Schedule

In the original paper, $T=1000$ and $\beta_t=10^{-4}+\frac{t-1}{T-1}\times (0.02-10^{-4})$, i.e. $\beta_t$ increases linearly from $10^{-4}$ to $0.02$.

This matches intuition: the larger $t$ gets, the closer $x_t$ is to pure noise, so the per-step corruption $\beta_t$ can grow accordingly. Computing the cumulative product gives $\bar\alpha_{T}\approx 4\times 10^{-5}$, which meets the requirement that the distribution of $x_T$ be close to $\mathcal{N}(0,1)$.
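As a quick sanity check (a standalone sketch, not part of the demo), the cumulative product can be computed directly:

#include <cstdio>

int main()
{
    const int T = 1000;
    double abar = 1.0;
    for (int t = 1; t <= T; t++)
    {
        double beta = 1e-4 + (double)(t - 1) / (T - 1) * (0.02 - 1e-4);
        abar *= 1.0 - beta; // abar_T = product of alpha_t
    }
    std::printf("abar_T = %g\n", abar); // prints roughly 4e-5
    return 0;
}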

Cosine Schedule

Looking at the $\bar\alpha_t$ curve of the linear schedule, many of the later steps seem wasted: the first half of the diffusion moves too fast.

A later improvement therefore sets $\bar\alpha_t=\cos\left(\frac{\frac{t}{T}+\epsilon}{1+\epsilon}\times \frac{\pi}{2}\right)^2$, where $\epsilon$ is a small quantity introduced for numerical stability, typically $10^{-3}$.

Experiments show that this choice of hyperparameters makes good use of the later diffusion steps and performs better.

In practice, one usually defines $f(t)=\cos\left(\frac{\frac{t}{T}+\epsilon}{1+\epsilon}\times \frac{\pi}{2}\right)^2$ and $\bar\alpha_t=\frac{f(t)}{f(0)}$, and clips $\alpha_t$ to $[\epsilon,1-\epsilon]$, with $\epsilon$ again typically $10^{-3}$.
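A sketch of this schedule, consistent with the `init()` code in the demo below (`std::clamp` needs C++17; the function name is made up for these notes):

#include <algorithm>
#include <cmath>
#include <vector>

// Cosine schedule: alpha_t = abar_t / abar_{t-1} = f(t)/f(t-1), clipped to [eps, 1-eps].
std::vector<float> cosine_alphas(int T, float eps = 1e-3f)
{
    const float pi = std::acos(-1.0f);
    auto f = [&](int t) {
        float c = std::cos(((float)t / T + eps) / (1.0f + eps) * pi / 2.0f);
        return c * c;
    };
    std::vector<float> alpha(T + 1);
    for (int t = 1; t <= T; t++)
        alpha[t] = std::clamp(f(t) / f(t - 1), eps, 1.0f - eps);
    return alpha;
}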

Model and Loss Function

For a task like this, where the input and output tensors have the same shape, a U-Net is the usual choice. And since the model predicts a variable that follows a standard normal distribution, its output should stay fairly close to $0$, so the MSE (mean squared error) loss works well.

Note that the timestep $t$ also has to be fed into the model; it is usually embedded into the U-Net in the form of the sinusoidal positional encoding used in the Transformer.

Here is the U-Net structure for $32\times 32$ three-channel images.

DownBlock and UpBlock share the same structure. (TimeEmbedding is the same sinusoidal encoding as the Transformer positional encoding, passed through a small multi-layer perceptron (MLP).)

MiddleBlock is a symmetric structure, although for convenience it can simply be replaced with a DownBlock.

For the activation function, one generally uses something like SiLU, whose first derivative is smooth and continuous, to strengthen expressiveness near $0$ (in my tests both ReLU and Leaky ReLU performed noticeably worse).
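For reference, SiLU is simply $x$ times the logistic sigmoid of $x$; a one-liner:

#include <cmath>

// SiLU (a.k.a. swish): x * sigmoid(x). Smooth everywhere, unlike ReLU.
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }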

Here is a small demo that generates handwritten digits of a requested class, built with my own open-source C++ machine-learning library, stb, and the C++ graphics library EasyX:

U_NET.h

#pragma once

#include <iostream>
#include <cstdio>
#include <cmath>
#include <fstream>
#include <random>
#include <ctime>
#include <vector>

#include "./network_h/network.h"

using namespace std;

using namespace network;

const dim_t Batch_Size = 32;
const dim_t time_dim = 256, emb_dim = 32, model_dim = 64;

inline void time_embedding(dim_t t, dim_t dim, float* out)
{
	for (dim_t i = 0; i < dim; i++)
	{
		float wk = pow(10000, -(i / 2 * 2) / (float)dim); // paired dims share one frequency
		if ((i & 1) == 0) out[i] = sin(wk * t); // even index -> sin, odd -> cos
		else out[i] = cos(wk * t);
	}
}

class BLOCK : public OP_Base
{
public:
	FC* fc1, * fc2; // time & type linear layers
	CONV* c0, * c1, * c2, * c3;
	GN* gn1, * gn2, * gn3;
	BLOCK(OP_Base* fap, dim_t d, dim_t d2, dim_t h, dim_t w) :OP_Base(fap)
	{
		if (d != d2) c0 = get<CONV>(af::dim4{ h,w,d,0 }, make_pair(3, 3), d2, make_pair(1, 1), make_pair(1, 1));
		else c0 = NULL;
		fc1 = get<FC>(time_dim, d2), fc2 = get<FC>(emb_dim, d2);
		c1 = get<CONV>(af::dim4{ h,w,d,0 }, make_pair(3, 3), d2, make_pair(1, 1), make_pair(1, 1));
		gn1 = get<GN>(2, d2, 32, true);
		c2 = get<CONV>(af::dim4{ h,w,d2,0 }, make_pair(3, 3), d2, make_pair(1, 1), make_pair(1, 1));
		gn2 = get<GN>(2, d2, 32, true);
		c3 = get<CONV>(af::dim4{ h,w,d2,0 }, make_pair(3, 3), d2, make_pair(1, 1), make_pair(1, 1));
		gn3 = get<GN>(2, d2, 32, true);
	}
	val4d* operator()(val4d* x, val4d* time_emb, val4d* type_emb)
	{
		val4d* y;
		if (c0 != NULL) y = (*c0)(x), y = silu(y, true);
		else y = x;
		// add time_emb
		x = (*c1)(x), x = silu(x, true);
		time_emb = (*fc1)(time_emb), time_emb = silu(time_emb, true);
		x = add(x, tile(time_emb, af::dim4{ x->dims(0),x->dims(1),1,1 }));
		x = (*gn1)(x);
		// add type_emb
		x = (*c2)(x), x = silu(x, true);
		type_emb = (*fc2)(type_emb), type_emb = silu(type_emb, true);
		x = add(x, tile(type_emb, af::dim4{ x->dims(0),x->dims(1),1,1 }));
		x = (*gn2)(x);
		// add shortcut (residual) connection
		x = (*c3)(x), x = silu(x, true);
		x = add(x, y);
		x = (*gn3)(x);
		return x;
	}
};

class DOWN :public OP_Base
{
public:
	CONV* c;
	GN* gn;
	DOWN(OP_Base* fap, dim_t d, dim_t h, dim_t w) :OP_Base(fap)
	{
		c = get<CONV>(af::dim4{ h,w,d,0 }, make_pair(4, 4), d * 2, make_pair(2, 2), make_pair(1, 1));
		gn = get<GN>(2, d * 2, 32, true);
	}
	val4d* operator()(val4d* x)
	{
		x = (*c)(x), x = silu(x, true), x = (*gn)(x);
		return x;
	}
};

class UP :public OP_Base
{
public:
	CONV* c1, *c2;
	GN* gn1, * gn2;
	UP(OP_Base* fap, dim_t d, dim_t h, dim_t w) :OP_Base(fap)
	{
		c1 = get<CONV>(af::dim4{ h * 2,w * 2,d,0 }, make_pair(3, 3), d, make_pair(1, 1), make_pair(1, 1));
		gn1 = get<GN>(2, d, 32, true);
		c2 = get<CONV>(af::dim4{ h * 2,w * 2,d,0 }, make_pair(3, 3), d / 2, make_pair(1, 1), make_pair(1, 1));
		gn2 = get<GN>(2, d / 2, 32, true);
	}
	val4d* operator()(val4d* x)
	{
		x = upsample(x, { 2,2 });
		x = (*c1)(x), x = silu(x, true), x = (*gn1)(x);
		x = (*c2)(x), x = silu(x, true), x = (*gn2)(x);
		return x;
	}
};

class U_NET : public OP_Base
{
public:
	/****Time Embedding****/
	FC* fc0, * fc1;

	/****Type Embedding****/
	EMBEDDING* emb;


	/****In****/
	CONV* c0;

	/****Encoder****/
	BLOCK* down_b[3][2];
	DOWN* down[3];

	/****Middle*****/
	BLOCK* mid[3];

	/****Decoder****/
	UP* up[3];
	BLOCK* up_b[3][2];

	/****Out*****/
	CONV* c1;

	float in[Batch_Size * 1 * 32 * 32];
	float ftime_emb[Batch_Size * time_dim];
	dim_t in_t[Batch_Size], in_type[Batch_Size];
	val4d* out;

	U_NET() :OP_Base(NULL)
	{
		// Time Embedding
		fc0 = get<FC>(time_dim, time_dim * 4), fc1 = get<FC>(time_dim * 4, time_dim);
		// Type Embedding
		emb = get<EMBEDDING>(10, 1, emb_dim);
		// U-Net
		c0 = get<CONV>(af::dim4{ 32,32,1,0 }, make_pair(3, 3), model_dim, make_pair(1, 1), make_pair(1, 1));
		dim_t d = model_dim, h = 32, w = 32;
		for (int i = 0; i < 3; i++)
		{
			for (int j = 0; j < 2; j++) down_b[i][j] = get<BLOCK>(d, d, h, w);
			down[i] = get<DOWN>(d, h, w);
			d *= 2, h /= 2, w /= 2;
		}
		for (int i = 0; i < 3; i++) mid[i] = get<BLOCK>(d, d, h, w);
		for (int i = 0; i < 3; i++)
		{
			up[i] = get<UP>(d, h, w);
			h *= 2, w *= 2, d /= 2;
			for (int j = 0; j < 2; j++) up_b[i][j] = get<BLOCK>(d * 2, d, h, w);
		}
		c1 = get<CONV>(af::dim4{ h,w,d,0 }, make_pair(1, 1), 1, make_pair(1, 1), make_pair(0, 0), make_pair(1, 1), Init_Xavier);
	}
	inline void forward()
	{
		init_forward();
		dim_t n = eval ? 1 : Batch_Size;
		val4d* time_emb, * type_emb;
		{
			for (int i = 0; i < n; i++) time_embedding(in_t[i], time_dim, ftime_emb + i * time_dim);
			af::dim4 emb_s = af::dim4{ 1,1,time_dim,n };
			time_emb = tmp<val4d>(emb_s);
			time_emb->data() = af::array(emb_s, ftime_emb);
			time_emb = (*fc0)(time_emb), time_emb = silu(time_emb, true);
			time_emb = (*fc1)(time_emb), time_emb = silu(time_emb, true);
		}
		{
			type_emb = (*emb)(n, in_type);
		}
		af::dim4 ins = { 32,32,1,eval ? 1 : Batch_Size };
		val4d* x = tmp<val4d>(ins);
		x->data() = af::array(ins, in);
		x = (*c0)(x);
		vector<val4d*> que;
		for (int i = 0; i < 3; i++)
		{
			for (int j = 0; j < 2; j++)
			{
				x = (*down_b[i][j])(x, time_emb, type_emb);
				que.push_back(x);
			}
			x = (*down[i])(x);
		}
		for (int i = 0; i < 3; i++) x = (*mid[i])(x, time_emb, type_emb);
		for (int i = 0; i < 3; i++)
		{
			x = (*up[i])(x);
			for (int j = 0; j < 2; j++)
			{
				x = (*up_b[i][j])(concat({ x,que.back() }, 2), time_emb, type_emb);
				que.pop_back();
			}
		}
		x = (*c1)(x), x = tanh(x, true);
		out = x;
	}
};

class WARMUP_COSINE
{
private:
	int t, tot_t;
	double preres;
public:
	double mnl, mxl;
	int WarmUp;
	double T_MAX, T_MLT;
	bool REPET;

public:
	void save(ofstream& ouf)
	{
		writf(ouf, mnl), writf(ouf, mxl);
		writf(ouf, WarmUp);
		writf(ouf, T_MAX), writf(ouf, T_MLT);
		writf(ouf, REPET);
		writf(ouf, t), writf(ouf, tot_t);
		writf(ouf, preres);
	}
	void load(ifstream& inf)
	{
		readf(inf, mnl), readf(inf, mxl);
		readf(inf, WarmUp);
		readf(inf, T_MAX), readf(inf, T_MLT);
		readf(inf, REPET);
		readf(inf, t), readf(inf, tot_t);
		readf(inf, preres);
	}

public:
	WARMUP_COSINE(double LLRT, double RLRT, int MAX_T, bool REP = true, int WUP = 0, double MLT_T = 1)
	{
		mnl = LLRT, mxl = RLRT, T_MAX = MAX_T, T_MLT = MLT_T, WarmUp = WUP;
		REPET = REP;
		t = tot_t = 0;
		preres = WUP == 0 ? mxl : mnl;
	}
	double get() { return preres; }
	void step()
	{
		tot_t++;
		if (tot_t <= WarmUp)
		{
			preres += (mxl - mnl) / (double)WarmUp;
			return;
		}
		if (t > T_MAX) return;
		preres = mnl + (1 / (double)2) * (mxl - mnl) * (1 + cos(t / (double)T_MAX * acos(-1)));
		t++;
		if (REPET && t == T_MAX + 1) t = 0, T_MAX *= T_MLT;
	}
};

main.cpp

#define NDEBUG

#include <iostream>
#include <cstdio>
#include <cmath>
#include <fstream>
#include <random>
#include <ctime>
#include <cmath>
#include <sstream>
#include <chrono>
#include <io.h>
#include <conio.h>

using namespace std;

#include "./U_NET.h"

#include "stb_image.h"
#include "stb_image_write.h"

#include <easyx.h>

const int tot_dat = 60000;
const int T = 1000;
const int save_t = 500;
const float grad_l = -1, grad_r = 1;
const float lrt_l = 0.00001, lrt_r = 0.0005;
const int total_batch = 10000, warm_up = 100;

mt19937 rndgen(time(NULL));
float a[T + 5], b[T + 5], a_[T + 5];
float dat[tot_dat][1 * 32 * 32];
int type[tot_dat];
U_NET brn;
ADAM opt(brn.parameter(), lrt_l, 0.9, 0.999, 0.01);
WARMUP_COSINE lrt_gen(lrt_l, lrt_r, total_batch, false, warm_up);

void loaddata(string imgpath, string anspath)
{
	FILE* fimg = fopen(imgpath.c_str(), "rb");
	FILE* fans = fopen(anspath.c_str(), "rb");
	if (fimg == NULL)
	{
		puts("加载图片数据失败\n");
		system("pause");
		exit(1);
	}
	if (fans == NULL)
	{
		puts("加载答案数据失败\n");
		system("pause");
		exit(1);
	}
	fseek(fimg, 16, SEEK_SET);
	fseek(fans, 8, SEEK_SET);
	unsigned char* img = new unsigned char[28 * 28];
	for (int cas = 0; cas < tot_dat; cas++)
	{
		fread(img, 1, 28 * 28, fimg);
		for (int i = 0; i < 32 * 32; i++) dat[cas][i] = -1;
		for (int i = 0; i < 28 * 28; i++)
		{
			int x = i / 28, y = i % 28;
			dat[cas][(x + 2) * 32 + (y + 2)] = img[i] / (float)255 * 2 - 1;
		}
		unsigned char num;
		fread(&num, 1, 1, fans);
		type[cas] = num;
	}
	delete[] img;
	fclose(fimg), fclose(fans);
}
inline void putimg(int sx, int sy, float* a) // draw a 32x32 image at 2x scale (64x64 px on screen)
{
	for (int i = 0; i < 32; i++)
	{
		for (int j = 0; j < 32; j++)
		{
			for (int k = 0; k < 2; k++)
			{
				for (int l = 0; l < 2; l++)
				{
					putpixel(sx + j * 2 + k, sy + i * 2 + l, RGB(
						(a[i * 32 + j] + 1) / 2 * 255,
						(a[i * 32 + j] + 1) / 2 * 255,
						(a[i * 32 + j] + 1) / 2 * 255));
				}
			}
		}
	}
}

inline void init()
{
	// Linear schedule
	/*
	float bl = 1e-4, br = 0.02;
	for (int i = 1; i <= T; i++)
	{
		b[i] = bl + (float)(i - 1) / (T - 1) * (br - bl);
		a[i] = 1 - b[i];
	}
	a_[1] = a[1];
	for (int i = 2; i <= T; i++) a_[i] = a_[i - 1] * a[i];
	*/
	
	// Cosine schedule
	float pi = acos(-1), eps = 1e-3;
	auto f = [&](int t) {return (float)pow(cos(((float)t / T + eps) / (1 + eps) * pi / 2), 2); };
	a[1] = (std::min)((std::max)(f(1) / f(0), eps), 1 - eps);
	for (int i = 2; i <= T; i++) a[i] = (std::min)((std::max)(f(i) / f(i - 1), eps), 1 - eps);
	for (int i = 1; i <= T; i++) b[i] = 1 - a[i];
	a_[1] = a[1];
	for (int i = 2; i <= T; i++) a_[i] = a_[i - 1] * a[i];
}

float x0[1 * 32 * 32], noise[Batch_Size * 1 * 32 * 32];

float train()
{
	uniform_int_distribution<int> rndid(0, tot_dat - 1);
	uniform_int_distribution<int> rndt(1, T);
	normal_distribution<float> n(0, 1);
	for (int bs = 0; bs < Batch_Size; bs++)
	{
		int id = rndid(rndgen);
		int ad = bs * 1 * 32 * 32;
		for (int i = 0; i < 1 * 32 * 32; i++) x0[i] = dat[id][i];
		for (int i = 0; i < 1 * 32 * 32; i++) noise[ad + i] = n(rndgen);
		int t = rndt(rndgen);
		brn.in_t[bs] = t;
		brn.in_type[bs] = type[id];
		for (int i = 0; i < 1 * 32 * 32; i++) brn.in[ad + i] = sqrt(a_[t]) * x0[i] + sqrt(1 - a_[t]) * noise[ad + i];
	}
	brn.forward();
	opt.clear_grad();
	float loss = MSEloss(brn.out, af::array({ 32,32,1,Batch_Size }, noise));
	brn.out->backward();
	for (auto t : brn.parameter()) *t.second = (af::max)((af::min)(*t.second, grad_r), grad_l); // grad_clip
	opt.lrt = lrt_gen.get();
	opt.step(), lrt_gen.step();
	return loss;
}

unsigned char tmp[32][32], res[1 * 32 * 32];

inline void work(int type, int sx, int sy)
{
	normal_distribution<float> n(0, 1);
	for (int i = 0; i < 1 * 32 * 32; i++) brn.in[i] = n(rndgen);
	for (int i = T; i >= 1; i--)
	{
		brn.in_t[0] = i;
		brn.in_type[0] = type;
		brn.forward();
		float* out = brn.out->data().host<float>();
		for (int j = 0; j < 1 * 32 * 32; j++)
		{
			brn.in[j] = 1 / sqrt(a[i]) * (brn.in[j] - b[i] / sqrt(1 - a_[i]) * out[j])
				+ (i > 1 ? sqrt((1 - a_[i - 1]) / (1 - a_[i]) * b[i]) * n(rndgen) : 0);
			brn.in[j] = max((float)-1, min((float)1, brn.in[j]));
		}
		af::freeHost(out);
		if (i % 1 == 0) // redraw every step; raise the modulus to redraw less often
		{
			putimg(sx, sy, brn.in);
			wstringstream ssm;
			ssm << "T = " << i - 1 << "   ";
			outtextxy(80, 180, ssm.str().c_str());
		}
	}
}

void load(int idx)
{
	stringstream ssm;
	ssm << idx << ".ai";
	ifstream inf(ssm.str(), ios::in | ios::binary);
	brn.load(inf);
	opt.load(inf);
	lrt_gen.load(inf);
	inf.close();
}

void save(int idx)
{
	stringstream ssm;
	ssm << idx << ".ai";
	ofstream ouf(ssm.str(), ios::out | ios::binary);
	brn.save(ouf);
	opt.save(ouf);
	lrt_gen.save(ouf);
	ouf.close();
}

int main()
{
	init();
	printf("模式选择:\n");
	printf("[1] 训练\n");
	printf("[2] 运行\n");
	int op;
	scanf("%d", &op);
	if (op == 1)
	{
		system("cls");
		string imgpath = "../../../../data/MNIST/";
		//string imgpath = "./data/";
		printf("图片文件夹:%s\n\n", imgpath.c_str());
		loaddata(imgpath + "img", imgpath + "ans");
		printf("模式选择:\n");
		printf("[1] 重新训练\n");
		printf("[2] 读取并继续训练(自动填充 ai 路径)\n");
		int op;
		scanf("%d", &op);
		system("cls");
		int idx;
		if (op == 1) idx = 0;
		else
		{
			printf("断点 id:\n");
			scanf("%d", &idx);
			load(idx);
			system("cls");
		}
		initgraph(256, 256, EX_SHOWCONSOLE);
		putimg(48, 48, dat[0]);
		brn.set_eval(false);
		for (int tme = idx + 1;; tme++)
		{
			auto start = std::chrono::high_resolution_clock::now();
			printf("%d th loss: %f\tlearning rate: %f\n", tme, train(), opt.lrt);
			auto stop = std::chrono::high_resolution_clock::now();
			auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
			printf("Time use: %d ms\n\n", (int)duration);
			if (_kbhit())
			{
				char ch;
				while (_kbhit()) ch = _getch();
				if (ch == 27)
				{
					save(tme);
					break;
				}
				if (ch == 't')
				{
					printf("输入一个 [0,9] 中的整数:");
					int x;
					cin >> x;
					brn.set_eval(true);
					work(x, 48, 48);
					outtextxy(80, 180, L"Finished");
					brn.set_eval(false);
				}
			}
			if (tme % save_t == 0) save(tme);
		}
	}
	else
	{
		system("cls");
		printf("ai 文件路径:\n");
		string path;
		cin >> path;
		ifstream inf(path, ios::in | ios::binary);
		brn.load(inf);
		inf.close();
		initgraph(640, 256, EX_SHOWCONSOLE);
		brn.set_eval(true);
		for (int i = 0; i < 10; i++)
		{
			wstringstream ssm;
			ssm << i;
			outtextxy(30 + i * 64, 80, ssm.str().c_str());
			work(i, i * 64, 0);
		}
		system("pause");
	}
	return 0;
}

Here is the generation result after training for 14,876 batches.