Network as Generator
和以前的模型不同,这次输入会加入一个随机分布中取样的变量 ,从而使得与之相对的 不同,为另一个复杂的分布。
/IMG-20260131201603271.png)
当我们的任务需要一点创造新的时候,我们需要输出一个 distribution。
/IMG-20260131201603431.png)
GAN
- Unconditional Generation
也就是只输入随机变量 ,而不输入 。
服从的分布并没有特殊要求,对性能影响差异有些许差异,但总体不大。下文均假定
/IMG-20260131201604016.png)
Discriminator
输入图片,输出 scalar,相当于打分,越符合标准分数越高
/IMG-20260131201604586.png)
Basic Idea
/IMG-20260131201605168.png)
/IMG-20260131201605190.png)
也就是协同进化,Generator 的目的是生成足够符合特征的图片,骗过 Discriminator。而 Discriminator 会通过关注图片某些特征,鉴别真假图片,从而迫使 Generator 进行改正。
Algorithm
首先初始化 G 与 D,然后…
-
固定住 G,训练 D:用 G 生成出来的图片,和 Database 的真图片训练 D,使得 D 可以分类真假图片(可以当作 Classifier,也可以当作 Regression 的问题做,总之区分两种)
/IMG-20260131201605810.png)
-
固定 D,训练 G:用 G 生成的图片,交给 D 取鉴定,然后过程中不断训练、调整 G 的参数,使得得分越高越好。这样就可以说明我们的 G 开始骗过 D 了
/IMG-20260131201606397.png)
-
交替进行第一、二步
/IMG-20260131201606982.png)
The Theory
当我们训练模型的时候,我们实际在 minimize 什么?我们的 Loss 是什么?
/IMG-20260131201607598.png)
然而,Divergence 十分难以计算…
GAN 解决的问题就是,我们可以在不知道 的具体形式的时候,只通过采样估测他们的 Divergence
/IMG-20260131201608791.png)
/IMG-20260131201609536.png)
(我们把要 minimize 的称为 Loss Function,要 maximize 的称为 Objective Function)
直观理解:
/IMG-20260131201610559.png)
所以,这就是我们要做的,刚刚所提及的过程
/IMG-20260131201611221.png)
其中,也就是:
- 为 D 的 Object Function(也就是 D 的 Loss Function 为 )
- 而 是 G 的 Loss Function,其实也就是 (第一项与 G 无关), 实则为了便于训练,解决梯度消失问题,使用
实际上,这个过程是对 JS Divergence 的估计,我们最小化 G 的 Loss Function,本质就是最小化真实图片分布和生成图片分布的 JS Divergence。
/IMG-20260131201611734.png)
通用公式:
- 目标函数:
- Discriminator Loss:
- Generator Loss:
Training Tip:WGAN
实际上 JS divergence 有一些问题。 几乎没有任何重叠。
/IMG-20260131201612100.png)
因此算出来经常是 ,也就是 D 总能正确分类 G 生成的图片和真实图片,这导致了我们无法根据 Loss 判断是不是模型正在变好。
/IMG-20260131201612404.png)
所以换一种计算方式:Wasserstein Distance
/IMG-20260131201612607.png)
/IMG-20260131201612761.png)
用这个作为两个分布接近程度的衡量,就可以通过 Loss 观察我们模型优化的过程,可以看到模型变得更好的过程。
/IMG-20260131201612938.png)
所以 WGAN 用的是这个方法,用这个作为 D 的 Object Function。
/IMG-20260131201613270.png)
/IMG-20260131201613558.png)
但总之,GAN 还是很难以训练,因为一旦 G 或者 D 有一方没能持续进步,另一方的提升也会随之停滞。
GAN 文字生成
而 GAN 生成文字的训练更是格外困难,因为我们会取最大概率对应的词,作为输出的 token。而参数微小改变不影响虽然对具体概率有影响,但一般不影响他们的大小关系,最大概率对应的 token 不变,输出不变,导致不可微分。
/IMG-20260131201613754.png)
与 Max Pooling 的不同
- GAN 生成序列时的
max不可导:
- Decoder 输出的是一个概率分布,取概率最大的那个字的编号。
- 索引的变化是跳跃式的,不是连续的。
- 如果对 Decoder 的参数做微小的改变 ,虽然 Softmax 的概率分布会发生细微变化,但只要这个变化不足以让另一个字的概率超过当前最大值,
argmax的结果就完全不会变。()- 当两个字的概率正好相等时,函数发生突变,此时导数不存在。
- CNN 的 Max Pooling 可以求导:
- Max Pooling 虽然也取最大值,但它是在连续的特征值空间里操作的。
- Max Pooling 选出的是滑动窗口里那个最大的具体数值。
- 因此梯度是分段线性的:
- 假设输入是 ,输出 。
- 如果 ,那么 ,此时 。
- 如果 ,那么 ,此时 。
- 在反向传播时,梯度会顺着当初“胜出”的那个路径原路返回,而没被选中的路径梯度为 0。
SrachGAN: Training Language GANs from Scratch
Evaluation of GAN
Quality
/IMG-20260131201613944.png)
这个分类模型和 Discriminator 的区别在于:
- 考虑单一图片
- 其并非要区分真假图,而是要把假图作为输入,尝试对图片分类
- 一般是训练好的通用的模型,避免学习假图特征强行分类
- 使用此方法,不关心到底输入的是什么类型的东西,只要看分类出来概率是不是集中(也就是是不是有一个概率很高的类,机器笃定其为某一类),就能知道是不是一个特征明显的图片,也就是是不是高质量的图片
Diversity
但是还存在 Mode Collapse 的问题:
/IMG-20260131201614127.png)
也就是 G 发现了一种能够欺骗 D 的捷径,一直只生成一些类似的图,从而导致其生成的样本多样性极度匮乏的现象
以及 Mode Dropping 的问题:
/IMG-20260131201614311.png)
生成的图片有多样性,也符合目标的分布,但是只是目标分布的一部分,没有真正覆盖完整的真实分布,剩下的可能分布状态(Mode)没有学到。意味着 G 在学习过程中,为了降低风险(更容易骗过 D),决定放弃掉某些难画的模式,只保住几个最稳妥的模式。
所以…
/IMG-20260131201614423.png)
这个指标,与刚刚的评估 Quality 不同:
- 考虑多张图片
- 计算每一张图的分类概率,加和平均,观察是否平均
- 概率数值要分散/均匀。这代表“画得全”,识别了多种模式,具有好的多样性
也就是,两个评估都是要把生成的图送给分类模型评估分类概率,Quality 看的是每一张图片概率是否集中,Diversity 看的是所有图片的概率按类别的平均是否均匀
Inception Score(IS)
这即是 Inception Score,也就是:
也就是:
Frechet Inception Distance(FID)
/IMG-20260131201614549.png)
也比较常用,不过把分布视为 Gaussians 会有一些问题。
虽然还有另外的问题:如果 GAN 只是简单输出原始分布的图片呢?或者只是进行简单的“augment”
/IMG-20260131201614715.png)
Conditional Generation
- 文生图
/IMG-20260131201614781.png)
/IMG-20260131201614866.png)
- 图生图 (pix2pix)
/IMG-20260131201614965.png)
- 音生图
/IMG-20260131201615039.png)
Cycle GAN:Learning from Unpaired Data
GAN 也可以用在 Unsupervised Learning 上,例如图片风格转换、文字风格转换,一般没有成对数据标注。
/IMG-20260131201615140.png)
以前,我们输入一个正态分布,然后输出一个复杂分布。现在,我们把输入也变为一个复杂分布,输出另一个复杂分布。
/IMG-20260131201615268.png)
但是直接套用 Conditional GAN 的思路,会导致我们的输入可能被当作高斯噪声,直接无视,生成一张和原图无关的图。
/IMG-20260131201615363.png)
所以我们额外训练一个 Generator,用于还原
/IMG-20260131201615457.png)
可以再做一个双向的:
/IMG-20260131201615555.png)
也可以用于文字,但是对于 D 有问题(argmax 导致的梯度无法传递,详见前面的说明:GAN 文字生成),需要 RL。
/IMG-20260131201615678.png)