前向扩散
定义
考虑这样一个过程,对于一张图片 ,我们不断为其增加噪声:(注意这表示对图片中每个像素点的每个颜色通道归一化到 后的值同时进行操作)
其中 表示均值为 ,方差为 的正态分布(标准正态分布),则相当于 从均值为 ,方差为 的正态分布中抽样。
其中 和 是人为设置的超参数,应有 ,保证每一步加入的噪声比较少。
我们的目标是,在足够多轮后(例如 ),使得可以近似 。这样就可以从标准正态分布随机采样一个 ,再一步一步倒回来得到 。
注意前向扩散过程中方差 的变化:
我们希望 。不妨假定 (这可以通过对数据集进行标准化保证),也就是说我们希望扩散完成后方差依然为 ,则为了保证扩散过程中的稳定性,不妨对于所有 ,都保证 。
那么代入:
故一般保证 。
一步到位
考虑 和 的关系:
由正态分布的叠加性,对于两个相互独立的正态分布,有 ,所以有:
而由于 ,故:
所以:
以此类推,有:
反向生成
预测什么
为了实现反向生成,我们需要预测每一步加入的噪声 ,从而根据 推出 :
对比下面两条式子:
第一条是直接训练模型去拟合每一步加入的噪声,第二条是训练模型去拟合一步到位的噪声。
不难发现,根据第一条去训练就需要每次采样 和两个服从 的随机变量,比下面那一条多一个需要采样的变量,故较难收敛,训练起来更加费时费力。
所以不妨直接训练模型去拟合一步到位的噪声 ,即每次训练:
- 采样 ,再从 采样一个 ;
- 计算 ;
- 根据模型的输出 和 的差异,进行反向传播,更新参数;
那么根据 ,我们可以得知:
但是直接一步到位去预测 是不现实的,下面来推导一下怎么根据 的输出得到每一步加入的噪声 。
逐步逆向
令:
- 为已知原图为 ,第 步扩散结果等于 的概率;
- 为已知原图为 ,第 步扩散结果等于 且第 步扩散结果等于 的概率;
- 为已知第 步扩散结果等于 且原图等于 的前提下,第 步扩散结果等于 的概率, 同理。
那么只要求出 ,我们就能代入 和模型预测的 ,依据这个关于 的概率密度函数直接采样得到 。
根据贝叶斯公式,有:
根据之前的定义,有:
- :;
- :;
- :;
由于正态分布 的概率密度函数(抽样得到 的概率)为 ,故 等于:
对 里的东西变形,注意到我们只关心和 有关的项(常数不影响后续的配方):
接下来对其进行配方,使其变为 的形式,则:
所以:
注意到方差与模型的预测无关,所以也有人说 DDPM 就是在预测每一步的均值。
根据模型的预测,将 代入:
那么我们就得到了反向过程中每一步的采样公式:
将 和 展开就得到了原论文中的形式:
实现细节
超参设置
Liner Schedule
原论文中 ,,即 从 线性递增到 。
这样设置符合直觉,毕竟 越大 就越接近随机噪声,对 的破坏量 就可以相应地增大。而经计算可得 ,符合 的分布接近 的需求。
Cosine Schedule
观察 Liner Schedule 的 的曲线就会发现,似乎后面有很多步都是没用的,前一半扩散得太快:

所以后来提出了一种改进方法,设置 ,其中 是为了稳定数值引入的极小量,一般设置为 :(图中蓝色曲线)

实验表明,这样设置超参数可以利用好后面的扩散进程,效果更好。
实际应用中,一般令 ,,并将 裁剪到 ,其中 一般设置为 。
模型及损失函数选取
对于这种输入和输出张量形状相同的任务,一般选用 U-Net 模型。而由于模型是在预测一个服从正态分布的变量,故输出应较为接近 ,故可用 (均方误差)损失函数。
需要注意的是,时间步 也要输入模型中,一般是以 Transformer 中三角函数位置编码的形式嵌入进 U-Net 中的。
这里给出对于 的三通道图片的 U-Net 结构:

其中 DownBlock 和 UpBlock 结构一样:(TimeEmbedding 即为和 Transformer 中位置编码一样的三角函数编码,经过一个简单的多层感知机(MLP)变换后的结果)

MiddleBlock 则是一个对称的结构,不过其实也可以贪方便直接使用 DownBlock 替代掉:

激活函数则一般使用 SiLU 这种一阶导平滑连续的函数,增强其在 附近的表达能力(实测 ReLU 和 Leaky_ReLU 效果都不太好)。
这是一个小小的 Demo,使用了自己的 C++ 机器学习库、开源库 stb 和 C++ 图形库 EasyX,实现了生成指定类型的手写数字:
U_NET.h
#pragma once
#include <iostream>
#include <cstdio>
#include <cmath>
#include <fstream>
#include <random>
#include <ctime>
#include <vector>
#include "./network_h/network.h"
using namespace std;
using namespace network;
const dim_t Batch_Size = 32;
const dim_t time_dim = 256, emb_dim = 32, model_dim = 64;
inline void time_embedding(dim_t t, dim_t dim, float* out)
{
for (dim_t i = 0; i < dim; i++)
{
float wk = pow(10000, -(i / 2 * 2) / (float)dim);
if (i & 1 ^ 1) out[i] = sin(wk * t);
else out[i] = cos(wk * t);
}
}
class BLOCK : public OP_Base
{
public:
FC* fc1, * fc2; // time_liner & type_liner
CONV* c0, * c1, * c2, * c3;
GN* gn1, * gn2, * gn3;
BLOCK(OP_Base* fap, dim_t d, dim_t d2, dim_t h, dim_t w) :OP_Base(fap)
{
if (d != d2) c0 = get<CONV>(af::dim4{ h,w,d,0 }, make_pair(3, 3), d2, make_pair(1, 1), make_pair(1, 1));
else c0 = NULL;
fc1 = get<FC>(time_dim, d2), fc2 = get<FC>(emb_dim, d2);
c1 = get<CONV>(af::dim4{ h,w,d,0 }, make_pair(3, 3), d2, make_pair(1, 1), make_pair(1, 1));
gn1 = get<GN>(2, d2, 32, true);
c2 = get<CONV>(af::dim4{ h,w,d2,0 }, make_pair(3, 3), d2, make_pair(1, 1), make_pair(1, 1));
gn2 = get<GN>(2, d2, 32, true);
c3 = get<CONV>(af::dim4{ h,w,d2,0 }, make_pair(3, 3), d2, make_pair(1, 1), make_pair(1, 1));
gn3 = get<GN>(2, d2, 32, true);
}
val4d* operator()(val4d* x, val4d* time_emb, val4d* type_emb)
{
val4d* y;
if (c0 != NULL) y = (*c0)(x), y = silu(y, true);
else y = x;
// add time_emb
x = (*c1)(x), x = silu(x, true);
time_emb = (*fc1)(time_emb), time_emb = silu(time_emb, true);
x = add(x, tile(time_emb, af::dim4{ x->dims(0),x->dims(1),1,1 }));
x = (*gn1)(x);
// add type_emb
x = (*c2)(x), x = silu(x, true);
type_emb = (*fc2)(type_emb), type_emb = silu(type_emb, true);
x = add(x, tile(type_emb, af::dim4{ x->dims(0),x->dims(1),1,1 }));
x = (*gn2)(x);
// add short cut
x = (*c3)(x), x = silu(x, true);
x = add(x, y);
x = (*gn3)(x);
return x;
}
};
class DOWN :public OP_Base
{
public:
CONV* c;
GN* gn;
DOWN(OP_Base* fap, dim_t d, dim_t h, dim_t w) :OP_Base(fap)
{
c = get<CONV>(af::dim4{ h,w,d,0 }, make_pair(4, 4), d * 2, make_pair(2, 2), make_pair(1, 1));
gn = get<GN>(2, d * 2, 32, true);
}
val4d* operator()(val4d* x)
{
x = (*c)(x), x = silu(x, true), x = (*gn)(x);
return x;
}
};
class UP :public OP_Base
{
public:
CONV* c1, *c2;
GN* gn1, * gn2;
UP(OP_Base* fap, dim_t d, dim_t h, dim_t w) :OP_Base(fap)
{
c1 = get<CONV>(af::dim4{ h * 2,w * 2,d,0 }, make_pair(3, 3), d, make_pair(1, 1), make_pair(1, 1));
gn1 = get<GN>(2, d, 32, true);
c2 = get<CONV>(af::dim4{ h * 2,w * 2,d,0 }, make_pair(3, 3), d / 2, make_pair(1, 1), make_pair(1, 1));
gn2 = get<GN>(2, d / 2, 32, true);
}
val4d* operator()(val4d* x)
{
x = upsample(x, { 2,2 });
x = (*c1)(x), x = silu(x, true), x = (*gn1)(x);
x = (*c2)(x), x = silu(x, true), x = (*gn2)(x);
return x;
}
};
class U_NET : public OP_Base
{
public:
/****Time Embedding****/
FC* fc0, * fc1;
/****Type Embedding****/
EMBEDDING* emb;
/****In****/
CONV* c0;
/****Encoder****/
BLOCK* down_b[3][2];
DOWN* down[3];
/****Middle*****/
BLOCK* mid[3];
/****Decoder****/
UP* up[3];
BLOCK* up_b[3][2];
/****Out*****/
CONV* c1;
float in[Batch_Size * 1 * 32 * 32];
float ftime_emb[Batch_Size * time_dim];
dim_t in_t[Batch_Size], in_type[Batch_Size];
val4d* out;
U_NET() :OP_Base(NULL)
{
// Time Embedding
fc0 = get<FC>(time_dim, time_dim * 4), fc1 = get<FC>(time_dim * 4, time_dim);
// Type Embedding
emb = get<EMBEDDING>(10, 1, emb_dim);
// U-Net
c0 = get<CONV>(af::dim4{ 32,32,1,0 }, make_pair(3, 3), model_dim, make_pair(1, 1), make_pair(1, 1));
dim_t d = model_dim, h = 32, w = 32;
for (int i = 0; i < 3; i++)
{
for (int j = 0; j < 2; j++) down_b[i][j] = get<BLOCK>(d, d, h, w);
down[i] = get<DOWN>(d, h, w);
d *= 2, h /= 2, w /= 2;
}
for (int i = 0; i < 3; i++) mid[i] = get<BLOCK>(d, d, h, w);
for (int i = 0; i < 3; i++)
{
up[i] = get<UP>(d, h, w);
h *= 2, w *= 2, d /= 2;
for (int j = 0; j < 2; j++) up_b[i][j] = get<BLOCK>(d * 2, d, h, w);
}
c1 = get<CONV>(af::dim4{ h,w,d,0 }, make_pair(1, 1), 1, make_pair(1, 1), make_pair(0, 0), make_pair(1, 1), Init_Xavier);
}
inline void forward()
{
init_forward();
dim_t n = eval ? 1 : Batch_Size;
val4d* time_emb, * type_emb;
{
for (int i = 0; i < n; i++) time_embedding(in_t[i], time_dim, ftime_emb + i * time_dim);
af::dim4 emb_s = af::dim4{ 1,1,time_dim,n };
time_emb = tmp<val4d>(emb_s);
time_emb->data() = af::array(emb_s, ftime_emb);
time_emb = (*fc0)(time_emb), time_emb = silu(time_emb, true);
time_emb = (*fc1)(time_emb), time_emb = silu(time_emb, true);
}
{
type_emb = (*emb)(n, in_type);
}
af::dim4 ins = { 32,32,1,eval ? 1 : Batch_Size };
val4d* x = tmp<val4d>(ins);
x->data() = af::array(ins, in);
x = (*c0)(x);
vector<val4d*> que;
for (int i = 0; i < 3; i++)
{
for (int j = 0; j < 2; j++)
{
x = (*down_b[i][j])(x, time_emb, type_emb);
que.push_back(x);
}
x = (*down[i])(x);
}
for (int i = 0; i < 3; i++) x = (*mid[i])(x, time_emb, type_emb);
for (int i = 0; i < 3; i++)
{
x = (*up[i])(x);
for (int j = 0; j < 2; j++)
{
x = (*up_b[i][j])(concat({ x,que.back() }, 2), time_emb, type_emb);
que.pop_back();
}
}
x = (*c1)(x), x = tanh(x, true);
out = x;
}
};
class WARMUP_COSINE
{
private:
int t, tot_t;
double preres;
public:
double mnl, mxl;
int WarmUp;
double T_MAX, T_MLT;
bool REPET;
public:
void save(ofstream& ouf)
{
writf(ouf, mnl), writf(ouf, mxl);
writf(ouf, WarmUp);
writf(ouf, T_MAX), writf(ouf, T_MLT);
writf(ouf, REPET);
writf(ouf, t), writf(ouf, tot_t);
writf(ouf, preres);
}
void load(ifstream& inf)
{
readf(inf, mnl), readf(inf, mxl);
readf(inf, WarmUp);
readf(inf, T_MAX), readf(inf, T_MLT);
readf(inf, REPET);
readf(inf, t), readf(inf, tot_t);
readf(inf, preres);
}
public:
WARMUP_COSINE(double LLRT, double RLRT, int MAX_T, bool REP = true, int WUP = 0, double MLT_T = 1)
{
mnl = LLRT, mxl = RLRT, T_MAX = MAX_T, T_MLT = MLT_T, WarmUp = WUP;
REPET = REP;
t = tot_t = 0;
preres = WUP == 0 ? mxl : mnl;
}
double get() { return preres; }
void step()
{
tot_t++;
if (tot_t <= WarmUp)
{
preres += (mxl - mnl) / (double)WarmUp;
return;
}
if (t > T_MAX) return;
preres = mnl + (1 / (double)2) * (mxl - mnl) * (1 + cos(t / (double)T_MAX * acos(-1)));
t++;
if (REPET && t == T_MAX + 1) t = 0, T_MAX *= T_MLT;
}
};
main.cpp
#define NDEBUG
#include <iostream>
#include <cstdio>
#include <cmath>
#include <fstream>
#include <random>
#include <ctime>
#include <cmath>
#include <sstream>
#include <chrono>
#include <io.h>
#include <conio.h>
using namespace std;
#include "./U_NET.h"
#include "stb_image.h"
#include "stb_image_write.h"
#include <easyx.h>
const int tot_dat = 60000;
const int T = 1000;
const int save_t = 500;
const float grad_l = -1, grad_r = 1;
const float lrt_l = 0.00001, lrt_r = 0.0005;
const int total_batch = 10000, warm_up = 100;
mt19937 rndgen(time(NULL));
float a[T + 5], b[T + 5], a_[T + 5];
float dat[tot_dat][1 * 32 * 32];
int type[tot_dat];
U_NET brn;
ADAM opt(brn.parameter(), lrt_l, 0.9, 0.999, 0.01);
WARMUP_COSINE lrt_gen(lrt_l, lrt_r, total_batch, false, warm_up);
void loaddata(string imgpath, string anspath)
{
FILE* fimg = fopen(imgpath.c_str(), "rb");
FILE* fans = fopen(anspath.c_str(), "rb");
if (fimg == NULL)
{
puts("加载图片数据失败\n");
system("pause");
exit(1);
}
if (fans == NULL)
{
puts("加载答案数据失败\n");
system("pause");
exit(1);
}
fseek(fimg, 16, SEEK_SET);
fseek(fans, 8, SEEK_SET);
unsigned char* img = new unsigned char[28 * 28];
for (int cas = 0; cas < tot_dat; cas++)
{
fread(img, 1, 28 * 28, fimg);
for (int i = 0; i < 32 * 32; i++) dat[cas][i] = -1;
for (int i = 0; i < 28 * 28; i++)
{
int x = i / 28, y = i % 28;
dat[cas][(x + 2) * 32 + (y + 2)] = img[i] / (float)255 * 2 - 1;
}
unsigned char num;
fread(&num, 1, 1, fans);
type[cas] = num;
}
delete[] img;
fclose(fimg), fclose(fans);
}
inline void putimg(int sx, int sy, float* a) // put a 64*64 img
{
for (int i = 0; i < 32; i++)
{
for (int j = 0; j < 32; j++)
{
for (int k = 0; k < 2; k++)
{
for (int l = 0; l < 2; l++)
{
putpixel(sx + j * 2 + k, sy + i * 2 + l, RGB(
(a[i * 32 + j] + 1) / 2 * 255,
(a[i * 32 + j] + 1) / 2 * 255,
(a[i * 32 + j] + 1) / 2 * 255));
}
}
}
}
}
inline void init()
{
// Liner schedule
/*
float bl = 1e-4, br = 0.02;
for (int i = 1; i <= T; i++)
{
b[i] = bl + (float)(i - 1) / (T - 1) * (br - bl);
a[i] = 1 - b[i];
}
a_[1] = a[1];
for (int i = 2; i <= T; i++) a_[i] = a_[i - 1] * a[i];
*/
// Cosine schedule
float pi = acos(-1), eps = 1e-3;
auto f = [&](int t) {return (float)pow(cos(((float)t / T + eps) / (1 + eps) * pi / 2), 2); };
a[1] = (std::min)((std::max)(f(1) / f(0), eps), 1 - eps);
for (int i = 2; i <= T; i++) a[i] = (std::min)((std::max)(f(i) / f(i - 1), eps), 1 - eps);
for (int i = 1; i <= T; i++) b[i] = 1 - a[i];
a_[1] = a[1];
for (int i = 2; i <= T; i++) a_[i] = a_[i - 1] * a[i];
}
float x0[1 * 32 * 32], noise[Batch_Size * 1 * 32 * 32];
float train()
{
uniform_int_distribution<int> rndid(0, tot_dat - 1);
uniform_int_distribution<int> rndt(1, T);
normal_distribution<float> n(0, 1);
for (int bs = 0; bs < Batch_Size; bs++)
{
int id = rndid(rndgen);
int ad = bs * 1 * 32 * 32;
for (int i = 0; i < 1 * 32 * 32; i++) x0[i] = dat[id][i];
for (int i = 0; i < 1 * 32 * 32; i++) noise[ad + i] = n(rndgen);
int t = rndt(rndgen);
brn.in_t[bs] = t;
brn.in_type[bs] = type[id];
for (int i = 0; i < 1 * 32 * 32; i++) brn.in[ad + i] = sqrt(a_[t]) * x0[i] + sqrt(1 - a_[t]) * noise[ad + i];
}
brn.forward();
opt.clear_grad();
float loss = MSEloss(brn.out, af::array({ 32,32,1,Batch_Size }, noise));
brn.out->backward();
for (auto t : brn.parameter()) *t.second = (af::max)((af::min)(*t.second, grad_r), grad_l); // grad_clip
opt.lrt = lrt_gen.get();
opt.step(), lrt_gen.step();
return loss;
}
unsigned char tmp[32][32], res[1 * 32 * 32];
inline void work(int type, int sx, int sy)
{
normal_distribution<float> n(0, 1);
for (int i = 0; i < 1 * 32 * 32; i++) brn.in[i] = n(rndgen);
for (int i = T; i >= 1; i--)
{
brn.in_t[0] = i;
brn.in_type[0] = type;
brn.forward();
float* out = brn.out->data().host<float>();
for (int j = 0; j < 1 * 32 * 32; j++)
{
brn.in[j] = 1 / sqrt(a[i]) * (brn.in[j] - b[i] / sqrt(1 - a_[i]) * out[j])
+ (i > 1 ? sqrt((1 - a_[i - 1]) / (1 - a_[i]) * b[i]) * n(rndgen) : 0);
brn.in[j] = max((float)-1, min((float)1, brn.in[j]));
}
af::freeHost(out);
if (i % 1 == 0)
{
putimg(sx, sy, brn.in);
wstringstream ssm;
ssm << "T = " << i - 1 << " ";
outtextxy(80, 180, ssm.str().c_str());
}
}
}
void load(int idx)
{
stringstream ssm;
ssm << idx << ".ai";
ifstream inf(ssm.str(), ios::in | ios::binary);
brn.load(inf);
opt.load(inf);
lrt_gen.load(inf);
inf.close();
}
void save(int idx)
{
stringstream ssm;
ssm << idx << ".ai";
ofstream ouf(ssm.str(), ios::out | ios::binary);
brn.save(ouf);
opt.save(ouf);
lrt_gen.save(ouf);
ouf.close();
}
int main()
{
init();
printf("模式选择:\n");
printf("[1] 训练\n");
printf("[2] 运行\n");
int op;
scanf("%d", &op);
if (op == 1)
{
system("cls");
string imgpath = "../../../../data/MNIST/";
//string imgpath = "./data/";
printf("图片文件夹:%s\n\n", imgpath.c_str());
loaddata(imgpath + "img", imgpath + "ans");
printf("模式选择:\n");
printf("[1] 重新训练\n");
printf("[2] 读取并继续训练(自动填充 ai 路径)\n");
int op;
scanf("%d", &op);
system("cls");
int idx;
if (op == 1) idx = 0;
else
{
printf("断点 id:\n");
scanf("%d", &idx);
load(idx);
system("cls");
}
initgraph(256, 256, EX_SHOWCONSOLE);
putimg(48, 48, dat[0]);
brn.set_eval(false);
for (int tme = idx + 1;; tme++)
{
auto start = std::chrono::high_resolution_clock::now();
printf("%d th loss: %f\tlearning rate: %f\n", tme, train(), opt.lrt);
auto stop = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
printf("Time use: %d ms\n\n", (int)duration);
if (_kbhit())
{
char ch;
while (_kbhit()) ch = _getch();
if (ch == 27)
{
save(tme);
break;
}
if (ch == 't')
{
printf("输入一个 [0,9] 中的整数:");
int x;
cin >> x;
brn.set_eval(true);
work(x, 48, 48);
outtextxy(80, 180, L"Finished");
brn.set_eval(false);
}
}
if (tme % save_t == 0) save(tme);
}
}
else
{
system("cls");
printf("ai 文件路径:\n");
string path;
cin >> path;
ifstream inf(path, ios::in | ios::binary);
brn.load(inf);
inf.close();
initgraph(640, 256, EX_SHOWCONSOLE);
brn.set_eval(true);
for (int i = 0; i < 10; i++)
{
wstringstream ssm;
ssm << i;
outtextxy(30 + i * 64, 80, ssm.str().c_str());
work(i, i * 64, 0);
}
system("pause");
}
return 0;
}
这是训练了 14876 个 Batch 后的生成效果:
