深度卷积 DCGAN - 生成模型 GAN 网络 | 莫烦Python

深度卷积 DCGAN

作者: 莫烦 编辑: 莫烦 2021-03-23

学习资料:

怎么了

前文我们介绍了最基础的GAN网络,如果你对GAN生成网络还不了解, 直接从这个我简短介绍GAN的教学开始对你更友善。

简单来说,GAN是一种生成网络,它通过判别器(discriminator)和生成器(generator)打配合,最终训练出一个可以无限制生成数据的模型。 而在前文介绍的这个最经典款GAN,它使用到了卷积神经网络 CNN 来做生成。 从下图在论文中截取的图中,可以看到生成的能力还是十分有限的。发白发胡的地方很多。

gan res

接下来,很多研究者见识到GAN可以大有所为,纷纷都朝这个方向做研究,其中有一个打开了一种CNN生成图片的标配方案。它精心设计的CNN的 decoder,让生成的图片质量提升了不少。下图是论文中生成的床照图:

bed

可以发现,这些图明显比原始的GAN要好很多。就此,无数研究者开启了秃头模式,开始摸索GAN的无数改进方式。 为了让你保持发量,我在这系列教程中,全部使用mnist手写数据集的数据,让各种GAN的方法都能匹配mnist,做到横向对比。 今天你将看到的DCGAN学习效果如下。让我们开启愉快的学习之旅吧。

results

怎么训练

在DCGAN中,相比经典的GAN, 它的DC是什么呢?这个DC是 Deep Convolutional 的意思, 同时也表明了,这种网络,在生成器上选取了一个生成效果比较好的 Deconvolution (DCGAN论文觉得叫Deconvolution不太好,管他呢,反正大家都已经这么叫了) 方案。 什么是 Deconvolution 呢?也就是反卷积的意思。

  • 卷积:图片 -> 特征向量
  • 反卷积:特征向量(噪声)-> 图片

而DCGAN提出的是一种相对有效的反卷积方案。其他的配置上,和传统的GAN也没多大差别。在里面值得注意的是,经典CNN中,我们常用pooling来进行信息筛减。 但是对于生成中,需要图片信息扩充这种操作的时候,pooling并不能很有效地做到这点,因为pooling不是矩阵运算,而是简单的求平均或者是取最大值。 所以pooling适合信息筛减而不适合信息扩充生成。

deconv

效果

DCGAN 的训练过程,和经典GAN的过程没什么差别,我们就按照上次提到的方法训练就好。 不过我们将上次使用的一元二次数据换成了mnist。minst手写数据长成下面这样。我们的目的就是让DCGAN生成和下面很像的图片。

mnist

而生成器训练的过程和效果就体现在下面这张动图。

trained

秀代码

因为代码量还算比较多,我来展示最核心的一些部分,如果想直接看全部代码, 请点击这里去往我的github.

和以前经典GAN的差异在于:

  1. generator/discriminator的网络定义不同
  2. 数据不同

我们就根据这个来修改一下对应的代码吧。我有一个自己产生数据的功能 get_half_batch_ds(),这个 get_half_batch_ds 定义在了自己写的 mnist_ds.py 中。 所有依赖文件你都可以在 https://github.com/MorvanZhou/mnistGANs 中找到。

Training 的步骤没有太多变化,还是epoch的方式,只是我们在每一个step训练时,把 mnist 预加载到一个 tensorflow 的 dataset。每一步从dataset拿图片数据。

那么 DCGAN 怎么定义呢?相比经典GAN,就只是 generator 和 discriminator 不太一样。

  • 生成器:
    • 输入:随机噪点(想像力)
    • 输出:和mnist一样结构的图片
  • 判别器:
    • 输入:图片
    • 输入:yes/no 二分类

后面的教程中,为了统一 generator 的CNN结构,便于不同算法的横向对比,我规定了下面这种 deconvolution 方式。 简单来说,就是将噪点数据通过全连接,再 reshape 成一个三维的 tensor,把这个三维 tensor 不断 deconvolution,转化成一张图片。 下面也详细描述了这种 deconvolution 的 shape 变化。

deconv

Conv2DTranspose() 是用来扩张图片信息的做法。正常的卷积操作是将上面大的图片信息抽取到西面小的特征图中,而 Conv2DTranspose() 则是将下面的当作输入, 输出上面的大图。这张解释的gif动图是我从这里拔下来的, 他还做了GAN的很多其他的可视化,都挺直观的,建议也看看。

而discriminator就更好理解了,其实就是一个用于识别的CNN。定义如下,我就不详细说明了,学过卷积神经网络的同学都很清楚。 同样,为了兼容后续的教程,我这里还定义了一个是否使用 batchNormalization, 有些论文说 batchNorm 可以有效增强生成的效果,有些又说不好, 我在后续的教程中会继续讨论这个点。

mnist_uni_disc_cnn()mnist_uni_gen_cnn() 我都封装到了 https://github.com/MorvanZhou/mnistGANs 中的 gan_cnn.py 文件中,作为项目的依赖。

在DCGAN中,我们使用这两种定义好的网络作为它的 generator 和 discriminator。

你会发现 _get_discriminator() 中,我还嵌套了一个 keras.layers.Dense(1),为什么我不写在 mnist_uni_disc_cnn() 这个里面呢? 其实这里是别有用心,为了以后可以兼容到后面不同算法的 discriminator 的。后面不同的 GAN 算法,它的 discriminator 有可能不同。 哈哈,这其实是一个预告/剧透。

其他的训练方法和 经典GAN 就没有差别了,我这里就快速过一下, 学过我的经典GAN教学的同学,后面这一步可以直接跳过。

最终训练20个epoch的结果还不错,能够生成人模人样的手写数字了,虽然有些效果可能还不是很好,但是你还是可以辨别出它写的大概是啥。

res

哈哈,我想起来一个小插曲,和你分享一下,我在 github 上传一些 GAN 生成出来的 mnist 动图时。github 突然把我的账号封了。。说我涉嫌发布色情视频。 汗。截个图记录一下。从此,我就对账号安全提起了警戒。在国内立马开了一个代码仓库的备份账号 https://gitee.com/MorvanZhou

github suspend

res

从这个动图看起来,还的确有那么几番特殊体态的动作的意思,捂脸。

括展

对于用噪声生成图片,我们还可以玩点有趣的,对噪点(特征)进行加减法,我们就能得到图片含义上的编辑。 比如 笑脸女 - 正常女 + 正常男 = 笑脸男

ops

眼镜男 - 正常男 + 正常女 = 眼镜女

ops

这些都是在原始噪点上的筛选和编辑,后续我们还会将如何按照你的个人意图,做更加可控的生成方案。

总结

GAN来生成图片是GAN最基本,也是最主流的使用方法,这个DCGAN是一个开端,但是 DCGAN 仍然还有很多问题,比如生成图片质量不算高, 图片再大一点就很难train出好效果,图片大的话,训练稳定性也有问题。所以后续我们还会介绍更多GAN算法。


降低知识传递的门槛

莫烦很常从互联网上学习知识,开源分享的人是我学习的榜样。 他们的行为也改变了我对教育的态度: 降低知识传递的门槛免费 奉献我的所学正是受这种态度的影响。 通过 【赞助莫烦】 能让我感到认同,我也更有理由坚持下去。