WGAN with Gradient Penalty (WGAN-gp) 惩罚该罚的 - 生成模型 GAN 网络 | 莫烦Python

WGAN with Gradient Penalty (WGAN-gp) 惩罚该罚的

作者: 莫烦 编辑: 莫烦 2021-03-26

学习资料:

怎么了

我们已经在这个介绍短片中LSGANWGAN 中, 都已经详细介绍了GAN的各种难训练的地方,比如生成器和判别器的能力不平衡,梯度传导问题,模型坍缩 model collapse等。 LSGAN采用了一种 Least Squares loss的计算方式来加大gradient的有效传递, 而WGAN也给出了一种新的 Wasserstein distance 来解决梯度和坍缩的问题。 这次要提到的 WGAN-gp 是WGAN的一种改良版本,全称是 WGAN with gradient penalty。

前面的教程中我们已经证明了推土机距离的有效性,同时也提到了优化时的一个约束条件 - 1-Lipschitz, WGAN 为了实现这个约束,使用了 clip 截断了判别器 weights,不让判别器浪起来。 如果用一句话概括 WGAN-gp 的改良:用一种梯度惩罚的方法替换 clip weights,让判别器在该浪的地方浪,不该浪的地方不浪。

在WGAN-gp论文中,它提到了WGAN使用clip方式所引发的问题,我们重点看看下面的右边(b)这张图,很多颜色线条那个是随着判别器层数增加, Clip 方案中梯度传导是有问题的,要么爆炸要么消失了,而 Gradient penalty 方案可以让每一层的梯度都比较稳定。再来看看最右边的图, Clip 方案网络中 weights 参数都跑到的极端的地方,要么最大,要么最小,而 Gradient penalty 方案可以让 weights 比较均匀地分布。

paper analysis

在我这个mnist案例中,也有同样的规律,对比WGAN的结果,WGAN-gp的效果显然更好。

results

怎么训练

如果你还没有WGAN的概念,不知道WGAN在做什么,我强烈建议你先看完我写的WGAN教学。我们将基于WGAN的理解, 继续下面的探讨。在WGAN-gp原文有这么一句话。

We now propose an alternative way to enforce the Lipschitz constraint. A differentiable function is 1-Lipschtiz if and only if it has gradients with norm at most 1 everywhere, so we consider directly constraining the gradient norm of the critic’s output with respect to its input.

其实我也不是太懂为什么 Lipschitz 约束和 gradients 的 norm=1 有什么关系,能解释这点的朋友也欢迎留言讨论。 反正我就拿着这个原理来用了。从上面的这句表述,作者就提出了一个 gradient penalty 方法。公式在下面, 里面表达的是它在WGAN的loss上加了一个惩罚项,如果判别器的 gradient 的 norm,离 1 越远,那么 loss 的惩罚力度越高。

gp

还要注意的是,一般的GAN,我们通常会在判别器上加一些 batchNorm,但对于WGAN-gp的判别器,是不能加 batchNorm 的,原因很简单, 是因为WGAN-gp的惩罚项计算中,惩罚的是单个数据的gradient norm,如果使用 batchNorm,就会扰乱这种惩罚,让这种特别的惩罚失效。 当然你可以绕过 batchNorm, 使用 layerNorm 或者 InstanceNorm。 这两种 Norm 的方式你可以在这篇文章 (中文也有类似介绍)详细了解其不同。我也放一个文章中的使用的非常直观的图。

different norm

下图就是论文中使用GP和其他的GAN方案在床照中的对比了,其中包括 DCGAN,LSGAN 和原始的WGAN。

gans comparison

最后我们看看这么牛逼的 WGAN-gp 到底怎么写的。还是会有一个循环套循环的步骤,在每一次训练生成器时,要多训练几次判别器, 判别器首先需要采样一次正式数据和生成数据,然后拿着生成数据和真实数据去计算 gradient penalty. 计算 gradient penalty 的时候有几个步骤。

  1. 拿到生成数据
  2. 将生成数据和真实数据按一个比例混合(在照片数据值上的混合),这是因为 1-Lipschitz 的假设条件推理出来的,解释写在下面。
  3. 用这个数据输入判别器,拿到输入判别器图片数据的梯度,注意这里并不是判别器网络weights的梯度
  4. 对梯度计算 norm,看看这个 norm 离单位距离 1 有多远(离1越近,惩罚越小)

对于上面第2点,为什么要用真假数据进行一个插值处理?我在这个知乎上看到了一个比较好的解释: 但问题是我们要求 ‖T‖L ≤ 1 是在每一处都成立,所以数据应该是全空间的均匀分布才行, 显然这很难做到。所以作者采用了一个非常机智(也有点流氓)的做法: 在真假样本之间随机插值来惩罚,这样保证真假样本之间的过渡区域满足 1-Lipschitz 约束。

pseuducode

秀代码

如果想直接看全部代码, 请点击这里去往我的github.

WGAN—gp 和以前WGAN的在代码上的差异不算太多,所以我下面会基于WGAN的代码来写WGAN-gp的代码。 所以如果你还不太清楚 WGAN 的工作机制,我强烈建议你先看我写的这个WGAN教学。正题来了, 我选择直接继承WGAN的 class,我们在哪些地方修改WGAN呢?

  1. _get_discriminator(use_bn=False) 构建判别器的时候不能用 batchNorm, 前文也提了原因
  2. gp() 要多一个 gradient penalty 的计算过程
  3. train_d() 的时候要加入 gradient penalty

其他步骤就和训练一个 WGAN 一模一样啦。那我们先从 train_d() 这个方法开始。代码非常直白,基本上不用解释,就是在原本的 w_loss 后面加上了 self.lambda_ * gp

重点是在如何计算 gp 上,所以我们需要单独定义一个 gp() 功能。下面这种写法在tensorflow上还不是很流行。 步骤拆解成代码还算比较容易的,只是在计算 gradient penalty 的时候我们的要拿到的 gradient 并不是判别器参数的 gradient, 而是加工后的输入图片的 gradient,也就是说,把图片当成网络 weight,原本的网络 weight 当成神经网络输入,去获取图片的 gradient。

最后一个epoch训练出来的效果还挺好的,比WGAN好多了。

res

总结

WGAN在为了保证 1-Lipschitz 这种约束条件时,使用了 clip weights 的方法,局限住判别器的能力,但是这种一刀切肯定是不好的。 WGAN-gp为了解决这种一刀切,提出了一种更柔和的方式,也就是使用 gradient penalty 来约束。 从训练结果来看,这种约束的确也发挥了非常好的效果。


降低知识传递的门槛

莫烦很常从互联网上学习知识,开源分享的人是我学习的榜样。 他们的行为也改变了我对教育的态度: 降低知识传递的门槛免费 奉献我的所学正是受这种态度的影响。 通过 【赞助莫烦】 能让我感到认同,我也更有理由坚持下去。