Wasserstein GAN (WGAN) 解决本质问题 - 生成模型 GAN 网络 | 莫烦Python

Wasserstein GAN (WGAN) 解决本质问题

作者: 莫烦 编辑: 莫烦 2021-03-25

学习资料:

怎么了

这个介绍短片中LSGAN 中, 我们都提到了 GAN 会有的一些显著问题。比如生成数据量大(比如高于64*64像素)就很容易因为GAN的能力不足而 model collapse 模型坍缩了; 甚至还有时候根本就 train 不起来,一直是生成四不像的状态。这还不像监督学习,只要加深网络或者改变一下参数配置就可以提升识别效果。 训练 GAN 真的是一件掉头发的事情,在行内还有人说,因为训练GAN时,生成器和判别器是一种对抗的状态, 任何一方太强,都会碾压对方,对抗的平衡被打破,训练就会失败。以至于需要自己摸索出一套 magic tricks 才能偶然训练好。

听到这种结论,真的会劝退一波刚入坑的 GANer 们。好在也有非常多的学者在研究怎么解决这些问题。当中比较有名的一项研究就是 WGAN 了。 一句话描述 WGAN:它用一种更好的方式来测量生成数据与真实数据的差距,并用这种方式来拉近他们的距离。 这使得 GAN 对对抗平衡不那么敏感。我在下面一节会详细阐述。

先来看看这次的训练结果吧。下面看到原始的 WGAN 训练出来的 mnist 并不完美, 但是后续我们会介绍一种 WGAN 的变种 WGAN gp, 它又将训练效果往上推了很大一步。

results

什么问题

在优化生成数据和真实数据的时候,GAN的核心任务就是去拉近生成数据和真实数据的数据分布。

有同学可能对数据分布没有什么概念。数据都有数值,谈论数据分布,实际在谈论的是一批数据的数值情况,从分布上可以看出他们大部分落在那些区间。 也就是开可以看出它们的分布情况。

data distribution

而GAN要处理的有两个数据分布,一个是真实数据的分布,一个是生成数据的分布。如果生成的数据和真实数据一点都不像,那么我们就说生成的数据分布和真实数据分布离得很远, 反之离得近。像下图中,old product 和 new product 都有自己的数据分布,在 income distribution 上,两者数据分布离得远,重合部分少。 而在 age distribution 上,两者离得近,分布重合多。

two distribution

理解了这些,我们可以开始讨论 GAN 是怎么玩数据分布的。经典 GAN 中,它实际上使用的是一种叫做 JS 散度的数据分布测量方法。 知乎上有一篇文章令人拍案叫绝的Wasserstein GAN, 它从数学公式的角度上分析了经典 GAN 用下面的loss时,在 JS 散度上的一种缺陷。 分析过程太长了,有兴趣的同学可以去这个知乎看一看,里面的结论我在这里说一下。

loss

两个数据分布的 JS 散度越小,则两个数据分布越近,生成数据和真实数据也越像,GAN就像通过拉近 JS 散度来优化模型,这没什么问题。 但是 JS 散度本身有一个缺陷。当两个分布没有重叠部分,或重叠部分比较小的话,它的可以用来更新模型的梯度就是0,或忽略不计。 这种梯度为0的情况很常发生在当你的判别器能力比生成器要强很多的时候(对抗平衡被打破)。下图来自这里

js distribution

好,不重叠会有问题,那么这种分布不重叠的发生的概率大吗?的确挺大的。上面那个知乎链接解释得很到位了, 我就不复述了(涉及了太多数学和空间的概念,理解起来的确有些难度)。 结论就是:如果你不小心(非常容易不小心)把你的判别器训练得比生成器好(非常容易训练得比生成器好),那么就很容易导致更新网络的梯度很小,学不动,学废了。

上面讲述了经典 GAN 在梯度上不给力的问题,还有一个问题也是 WGAN 想要解决的,这就是 model collapse 模型坍缩成生成固定的模式,生成多样性非常少。 比如下图中,上半部分是我们希望的多样性,但是模型可能只生成下半部分的6,这就是 collapse 了。

collapse

判别器因为GAN原始目标函数目标不一致(上面知乎有细说)。 导致模型更愿意放弃多样性,生成一种模式,这样它认为更安全一点。下图是论文中的对比图,都是床照,左边的是WGAN在不同限制条件下生成的床照, 右边是DCGAN在同样限制下生成的。比如中间张,如果拿掉BatchNorm,WGAN好像没受什么影响,但是DCGAN已经泣不成声了;最下面一张,WGAN用全连接代替CNN生成网络, DCGAN用等参数量的配置,DCGAN再一次泣不成声,而且还 model collapse。

dcgan comparison

怎么训练

不研究不知道,一研究就发现原来GAN有这么多问题。基于上面的分析,WGAN 就尝试使用不同的 loss 来解决上述问题。如果说 JS 散度对GAN的训练不友善, 那 WGAN 就换掉它,取而代之的是 Wasserstein 距离,也称为 Earth-Mover (EM) 推土距离。

wasserstein distance

这个公式乍一看没啥头绪,我们用可视化的方法解释一下它在干什么。 下面的图都是从 这里 运过来的, 你也可以在这个 知乎 上找到中文翻译。

earth mover

还记得我们前文提过的,GAN 学着拉近生成数据和真实数据的 数据分布 吗。这个 Earth-Mover 也是要拉近分布,它用的方法是把两个数据的分布变成土堆。 把生成数据的土堆,推成真实数据的土堆。怎么推?就选消耗的工作量最少的推法。 上图就是将高低不同的土堆量给填充补全成目标土堆的样子。而最优的推土搬运成本就是 Wasserstein 距离。 所以 Wasserstein 距离也就被称为 Earth-Mover distance 推土机距离。

max W distance

最后的训练伪代码就如下面所示了。从真实和生成样本中抽样后, 让 discriminator 计算这批数据中真实=真生成=真的强度,两者之差是 EM 距离。discriminator 本身要最大化这个距离, generator 的任务就变成了最小化 EM 距离。 避免以前 loss 中的 sigmoid 和 log 的部分后,梯度也顺畅很多。

pseudocode

这里还有一个 clip 网络参数的操作,说是因为判别器需要满足 1-Lipschitz(利普希茨连续),将其的更新约束起来, 限制判别器的能力,让判别器别那么浪, 而 clip 是一种简单粗暴的满足方法。这是一个数学概念,我其实也没有特别理解(知道比较形象解释的同学欢迎在下方留言)。 后面的 WGAN-gp 用了一种比 clip 更好的方式,后面我们接着讲。

秀代码

如果想直接看全部代码, 请点击这里去往我的github.

下面我们先看看整体的训练循环,在WGAN中,需要将判别器训练多次,目的是为了得到一个稍微好点的判别器, 应该也有可能是为了使受到更新约束(1-Lipschitz)的判别器在约束下多学一点,不至于太菜。接着在训练一次生成器。 所以这里就明摆着,我的训练出一个好点的判别器,然后让这个判别器好好调教一下生成器。而经典GAN就不能这么做,经典GAN需要让判别器和生成器实力相当。

下面我们看看WGAN的构造可以怎么搭建,我省去了 _get_generator()_get_discriminator() 的步骤,因为这两步和我们之前的 DCGANLSGAN 一样。都是拿一个反卷积一个卷积出来。 下面代码里我们着重突出 WGAN 的特别之处。所以你想看没有被省略的代码的话,请点我的Github.

我们在上面定义的 w_distance() 这个就是 WGAN 的核心重点了。而且 clip 网络参数的步骤你也能在 train_d() 里找到,当我们训练完判别器,在截断它的参数, 让它别限制起来,不让它浪。 但是有一个不好的地方,因为在训练时,让判别器多训练了很久,所以训练时间要长上不少,所以你最好要有一个GPU。

最后的训练结果是这样,可能也是我的CLIP卡得太死了,导致判别器的能力也被卡得太死,学出来的效果不是很好。你可以尝试加大Discriminator的网络来解决这个问题。 后面的WGAN-gp教程我们就来解决这个卡死的问题。

wgan res

总结

WGAN 从 GAN 的理论上深度剖析了 经典GAN loss 的劣势,并提出了释放GAN学习能力的方法, 重新用 Earth-Mover (EM) distance 或者说是 Wasserstein (W) distance 释放GAN的学习力。不过由于里面用了一个 clip 截断网络参数的做法, 让GAN的学习力还没有完全的发挥出来,所以后续也有人提出了更好的释放学习力方法。 接下来的 WGAN-gp教程 中,我们就一起看看吧。


降低知识传递的门槛

莫烦很常从互联网上学习知识,开源分享的人是我学习的榜样。 他们的行为也改变了我对教育的态度: 降低知识传递的门槛免费 奉献我的所学正是受这种态度的影响。 通过 【赞助莫烦】 能让我感到认同,我也更有理由坚持下去。