Least Squares GAN (LSGAN) 要不换一种loss吧 - 生成模型 GAN 网络 | 莫烦Python

Least Squares GAN (LSGAN) 要不换一种loss吧

作者: 莫烦 编辑: 莫烦 2021-03-23

学习资料:

怎么了

前面经典GAN教学DCGAN中, 我们的生成网络已经能有一个相当不错的效果了,但是当数据形态多样化后,数据尺寸变大后,一些问题还是慢慢浮现出来了。包括模型坍塌(model collapse), 训练不稳定,梯度消失,难训练等等问题都可以在我上一个简介视频中了解到。

而这次介绍的 LSGAN,主要就是为了解决梯度消失和模型坍塌这一系列问题所提出的一种方法。后续我们还会介绍更多方法,很多方法都是尝试用不同的loss定义来让模型更有效的传递梯度, 以至于更有效地更新网络。用一句话来描述 LSGAN:它是这样一种对训练 loss 的改进方法。它使用了 Least Squares 作为 loss,想让梯度能更有效的传递回去。

results

怎么训练

具体来说,前面我们介绍的经典GANDCGAN 使用的是 sigmoid cross-entropy 来计算 loss。这个函数的特性是什么呢?它的前向并没有什么问题,非线性变化后,将取值空间从原本的正负无穷压缩至了0-1的区间, 问题出在 sigmoid 的梯度上。众所周知,梯度反向传播是神经网络模型更新的命门,如果梯度不能有效传递会网络,那么1)网络更新十分缓慢;2)网络趋向于停止学习。

sigmoid

你在看看上图的红虚线部分,这是 sigmoid 的梯度范围,可见只有非常有限的区间才有梯度,而且梯度的值也比较小。相对比下面其他的梯度,对比起来就比较明显了。

activation derivative

只有 sigmoid 的梯度值是最小的。那么 LSGAN 就想通过修改计算 loss 的公式,将 sigmoid 替换成 least squares. 下面第一张有 log 的图是 sigmoid 方式, 而下面第二张没有 log 计算,取而代之的是利用方差的模式,也就是 MeanSquaredError 的计算公式了。这样,能传递回去的梯度就大很多了,理想情况下, 更新GAN的效果也会好很多。

sigmoid loss

ls loss

LSGAN相对后面会介绍的 WGAN,理解起来还是相对容易一点的,也是在训练GAN的时候,一种值得一试的 loss 替换方法。 在mnist手写数据集上的训练结果看起来也相当不错。

秀代码

如果想直接看全部代码, 请点击这里去往我的github.

和以前DCGAN的差异在于:

  1. step() 中的 loss function 不同

所以你会看到,我代码库中的 LSGAN 直接集成了 DCGAN。如果你还对DCGAN不太了解, 不知道GAN是怎么生成图片的,强烈建议你先过一遍我的 DCGAN 教学, 这里面的步骤非常详细。

下面我只用写 LSGAN 和 DCGAN 不同的 step() 更新功能,所以在此强调,你一定要有 DCGAN 的基础,

没啦,就这么简单,仅仅只是替换了一个loss而已。 最终训练20个epoch的结果还不错,能够生成人模人样的手写数字了,虽然有些效果可能还不是很好,但是你还是可以辨别出它写的大概是啥。

res

括展思考

在 LSGAN 的论文中有这样一张图。 他想表明 LSGAN 在生成多样性上的优势,也就是 LSGAN 能够学习到大多数数据的分布,而不会像经典GAN那样坍缩到一种数据分布上。 简单来说,就是GAN在学 mnist 的时候,只学会了生成某一种数字(比如1),其他数字GAN就不管了, 反正如果我只生成一种足够真实的数字,判别器也不能拿我怎样。那我还不如安全起见,只学会生成一种就好了。 比如下图中,上面部分是我们希望的多样性,但是模型可能只生成下半部分的6,这就是 collapse 了。

collapse

这就是困扰 GAN 的头号问题之一 model collapse. 可以说,gradient 的有效传递,的确可以减缓 model collapse, 但是也不能让 collapse 消失。 如果你做大型项目,LSGAN 还是可能坍缩到固定的某些生成方案上。

lsgan learned distribution

另外一个我想提的点是,在我的多次试验中,LSGAN 的确会比经典GAN有更好的收敛效果, 但是还是存在一个偶发性问题。有时候,在一开始训练不久后,就学废了,崩了。具体原因不明,但我猜测可能是初始的时候,有可能梯度过于大, 让网络更新力度太大,冲昏了头脑,所以学废了。

总结

经典GAN的loss定义是存在问题的,sigmoid函数导致梯度没办法有效传到,引发了 model collapse 和一些更新效率上的问题。LSGAN 尝试从 loss 定义上修改训练方案, 用 Least Squares 代替 sigmoid,让梯度传递更顺畅。后面我们沿着这样的角度,拓展一下,看看还有哪些方法能够解决梯度的这个问题。


降低知识传递的门槛

莫烦很常从互联网上学习知识,开源分享的人是我学习的榜样。 他们的行为也改变了我对教育的态度: 降低知识传递的门槛免费 奉献我的所学正是受这种态度的影响。 通过 【赞助莫烦】 能让我感到认同,我也更有理由坚持下去。