WGAN with Gradient Penalty (WGAN-gp) 惩罚该罚的
学习资料:
- 我制作的GAN简介短片
- 我制作的GAN的问题漏洞短片
- 论文 Improved Training of Wasserstein GANs
- 本节代码
- 代码有我自己定义的依赖utils.py, visual.py, mnist_ds.py
怎么了¶
我们已经在这个介绍短片中、 LSGAN和WGAN 中, 都已经详细介绍了GAN的各种难训练的地方,比如生成器和判别器的能力不平衡,梯度传导问题,模式坍缩 mode collapse等。 LSGAN采用了一种 Least Squares loss的计算方式来加大gradient的有效传递, 而WGAN也给出了一种新的 Wasserstein distance 来解决梯度和坍缩的问题。 这次要提到的 WGAN-gp 是WGAN的一种改良版本,全称是 WGAN with gradient penalty。
在前面的教程中我们已经证明了推土机距离的有效性,同时也提到了优化时的一个约束条件 - 1-Lipschitz, WGAN 为了实现这个约束,使用了 clip 截断了判别器 weights,不让判别器浪起来。 如果用一句话概括 WGAN-gp 的改良:用一种梯度惩罚的方法替换 clip weights,让判别器在该浪的地方浪,不该浪的地方不浪。
在WGAN-gp论文中,它提到了WGAN使用clip方式所引发的问题,我们重点看看下面的右边(b)这张图,很多颜色线条那个是随着判别器层数增加, Clip 方案中梯度传导是有问题的,要么爆炸要么消失了,而 Gradient penalty 方案可以让每一层的梯度都比较稳定。再来看看最右边的图, Clip 方案网络中 weights 参数都跑到的极端的地方,要么最大,要么最小,而 Gradient penalty 方案可以让 weights 比较均匀地分布。
在我这个mnist案例中,也有同样的规律,对比WGAN的结果,WGAN-gp的效果显然更好。
怎么训练¶
如果你还没有WGAN的概念,不知道WGAN在做什么,我强烈建议你先看完我写的WGAN教学。我们将基于WGAN的理解, 继续下面的探讨。在WGAN-gp原文有这么一句话。
We now propose an alternative way to enforce the Lipschitz constraint. A differentiable function is 1-Lipschtiz if and only if it has gradients with norm at most 1 everywhere, so we consider directly constraining the gradient norm of the critic’s output with respect to its input.
其实我也不是太懂为什么 Lipschitz 约束和 gradients 的 norm=1 有什么关系,能解释这点的朋友也欢迎留言讨论。 反正我就拿着这个原理来用了。从上面的这句表述,作者就提出了一个 gradient penalty 方法。公式在下面, 里面表达的是它在WGAN的loss上加了一个惩罚项,如果判别器的 gradient 的 norm,离 1 越远,那么 loss 的惩罚力度越高。
还要注意的是,一般的GAN,我们通常会在判别器上加一些 batchNorm,但对于WGAN-gp的判别器,是不能加 batchNorm 的,原因很简单, 是因为WGAN-gp的惩罚项计算中,惩罚的是单个数据的gradient norm,如果使用 batchNorm,就会扰乱这种惩罚,让这种特别的惩罚失效。 当然你可以绕过 batchNorm, 使用 layerNorm 或者 InstanceNorm。 这两种 Norm 的方式你可以在这篇文章 (中文也有类似介绍)详细了解其不同。我也放一个文章中的使用的非常直观的图。
下图就是论文中使用GP和其他的GAN方案在床照中的对比了,其中包括 DCGAN,LSGAN 和原始的WGAN。
最后我们看看这么牛逼的 WGAN-gp 到底怎么写的。还是会有一个循环套循环的步骤,在每一次训练生成器时,要多训练几次判别器, 判别器首先需要采样一次正式数据和生成数据,然后拿着生成数据和真实数据去计算 gradient penalty. 计算 gradient penalty 的时候有几个步骤。
- 拿到生成数据
- 将生成数据和真实数据按一个比例混合(在照片数据值上的混合),这是因为 1-Lipschitz 的假设条件推理出来的,解释写在下面。
- 用这个数据输入判别器,拿到输入判别器图片数据的梯度,注意这里并不是判别器网络weights的梯度
- 对梯度计算 norm,看看这个 norm 离单位距离 1 有多远(离1越近,惩罚越小)
对于上面第2点,为什么要用真假数据进行一个插值处理?我在这个知乎上看到了一个比较好的解释: 但问题是我们要求 ‖T‖L ≤ 1 是在每一处都成立,所以数据应该是全空间的均匀分布才行, 显然这很难做到。所以作者采用了一个非常机智(也有点流氓)的做法: 在真假样本之间随机插值来惩罚,这样保证真假样本之间的过渡区域满足 1-Lipschitz 约束。
秀代码¶
如果想直接看全部代码, 请点击这里去往我的github.
WGAN—gp 和以前WGAN的在代码上的差异不算太多,所以我下面会基于WGAN的代码来写WGAN-gp的代码。 所以如果你还不太清楚 WGAN 的工作机制,我强烈建议你先看我写的这个WGAN教学。正题来了, 我选择直接继承WGAN的 class,我们在哪些地方修改WGAN呢?
_get_discriminator(use_bn=False)
构建判别器的时候不能用 batchNorm, 前文也提了原因gp()
要多一个 gradient penalty 的计算过程train_d()
的时候要加入 gradient penalty
其他步骤就和训练一个 WGAN 一模一样啦。那我们先从 train_d()
这个方法开始。代码非常直白,基本上不用解释,就是在原本的 w_loss
后面加上了 self.lambda_ * gp
。
重点是在如何计算 gp 上,所以我们需要单独定义一个 gp()
功能。下面这种写法在tensorflow上还不是很流行。 步骤拆解成代码还算比较容易的,只是在计算 gradient penalty 的时候我们的要拿到的 gradient 并不是判别器参数的 gradient, 而是加工后的输入图片的 gradient
,也就是说,把图片当成网络 weight,原本的网络 weight 当成神经网络输入,去获取图片的 gradient。
最后一个epoch训练出来的效果还挺好的,比WGAN好多了。
总结¶
WGAN在为了保证 1-Lipschitz 这种约束条件时,使用了 clip weights 的方法,局限住判别器的能力,但是这种一刀切肯定是不好的。 WGAN-gp为了解决这种一刀切,提出了一种更柔和的方式,也就是使用 gradient penalty 来约束。 从训练结果来看,这种约束的确也发挥了非常好的效果。