经典款 GAN 网络

经典款 GAN 网络

作者: 莫烦 编辑: 莫烦 发布于: 2021-03-21

学习资料:

怎么了

机器学习,深度学习他们通常只在做一件事 - 分析输入,找到这个输入的一个标签。我们通常将这种模式称之为识别。 如果我想凭空让机器产生一幅画,一段曲子,用监督学习学习出来的模型,是不具备创作能力的。为什么是这样呢?

原来,监督学习的训练目标是给每一个数据,找到一个正确的标签/描述/数值。模型并不想锻炼想象力

我们在GAN的简介中提到了英语的考试, 监督学习就好比信息分析型考试,而则是信息输出型,这两种技能项是有差别的。

同理,监督学习学的是如何分析信息,而生成模型想要解决的是怎么输出信息。

另外一个重要的点是,生成模型不是监督学习,它算是一种非监督学习,导致它生成的东西可能在现实世界中压根不存在。 你也可以认为它这是在发挥想象力

这次教学,我们用一个最简单的例子来展示,生成模型在做的是什么事情。 下面这种动图就是生成模型学着画出一条一元二次曲线的过程。 你可以看到最开始的阶段,它画得并不好,但是随着学习的渐进,它能越画越像一条正宗的一元二次曲线。

results

很有幸,这一节内容,你是可以用CPU就能完成的,因为GAN实在太难训练了,后续我们要做生成图片的时候,训练时间会消耗更久, 我这系列都是用手写数字mnist来做实验,比较复杂的模型,在CPU上的训练时间可能要达到1个小时甚至更久。 后续的教学,你最好是有一张GPU卡,它能大大加速你的训练时间。

所以为了不吓唬你,至少让你入个门,我这一节,专门设置了一个CPU友善型的训练任务 - 生成一元二次曲线。

GAN的训练方法相比传统的监督学习不太一样,究竟不一样在哪,我们下面来介绍。

怎么训练

首先你需要搞懂,是谁在引导GAN的生成模型学习,它的老师是谁?在监督学习中,我们有一个监督信号,预测的Y和真实Y的差别就是模型的老师。

supervised

而在GAN中,我们是凭空生成一个Y。当然,这里的凭空是一个假凭空,这里的凭空指的是用想象力 X 来生成一个 Y。就像这里用想象力生成可爱猫咪的照片。 但是我怎么判断生成的猫咪好不好呢?

cats

如果没有接触过GAN的同学可能凭直觉就能想到,我就拿真实猫咪的target,让模型学习就好啦。当然这的确是一种可以接受的方法。 不过这种方法并不能达到我们的目的。因为每种想象力可能对应会生成不同的猫咪,但是我们怎么确定哪种想象力对应哪种target的猫咪呢? 显然用监督学习的方法在这里的确是有问题的。

cat pred target

还有什么办法解决这个对应错位的问题呢?如果考核target换成判断它是否生成的是一只猫?如果从想象力成功生成了一只猫,这不也能达到我们的生成要求吗? 通过不断地想象训练,它甚至是可以自己归纳那些想象力生成那些品种的猫。

is cat

这种方式貌似可让生成器慢慢学着用想象力画猫。可是判断这画出来的猫是不是猫这件事谁来做?叫一个人过来看着模型一个个生成成千上万的猫咪图像, 然后再手动实时给他一个标签反馈显然也不现实。还能怎么干?

is cat model

不如我们再定义一个模型,它的任务就是来判断这只想象出来的猫是不是长得像真猫。但是光让这个模型看想象猫,它没有参考物,也不会知道是否像真猫, 所以我们偷偷还要给这个模型看一些真猫的图片,让它知道哪些是真猫,哪些是想象猫

is cat model

这样,这个辨别模型就可以学着判断哪些是想象猫,哪些是真猫了。用这种方法来训练一个辨别模型当然是没问题的,可是我们的最终目的是为了让生成模型生成真假难辨的想象猫, 上面的步骤貌似并不能达成这个目标,充其量,训练出了一个牛逼的辨别模型。好,我们能不能将辨别模型的能力转接到生成模型呢?让辨别模型指导生成模型进化? GAN牛逼就是因为这。

yes cat back

上图如果转化成白话,意思就是,判决模型拿着对真猫的理解,和生成模型说:你看,你要照着我对真猫的想法来,修改一下这里这里,你就能画出更像真猫的想象猫了。 也就是告诉生成模型,怎么画可以画出我觉得想真猫的猫。

这就是GAN的前世今生啦。下面我们就来用代码实现最简单的一种GAN架构吧。

效果

这次任务是用GAN生成一元二次方程的风格,所以我们会给判别模型看很多真实的一元二次方程线条,然后让它指导生成器学习。 这个任务是我这个GAN系列中, 唯一一个用CPU训练也比较舒服的案例,其他用mnist手写数据来训练的GAN,最好还是有一张GPU比较好。

上面这段代码就是我用来生成一元二次方程的方法,通过改变不同 a 的值,我们就能获得不同的曲线样式。 刚开始训练的时候,生成器会生成乱七八糟的线条。

untrained

训练了20个epoch后,生成器的效果如下,可以看到生成的多条一元二次线条已经非常接近于真实的一元二次方程线条了。

trained

秀代码

因为代码量还算比较多,我来展示最核心的一些部分,如果想直接看全部代码, 请点击这里去往我的github.

下面就是训练的主循环了,其实和监督学习没多大差别,要说有差别的话,就是可以没有label。只需要传入真实一元二次方程数据(data)就好。

那么这里的GAN class怎么定义呢?首先,里面必然会有一个 generatordiscriminator, 这两个分别是生成器和判别器,后续就会使用生成器来生成一元二次方程,判别器来判断真假。

其实这个一元二次方程的生成器和判别器都很简单,我们用浅层神经网络就可以搞定,下面就是用keras定义的两个浅层神经网络。 唯一我想提的一点是他们的输入输出是什么?

  • 生成器:
    • 输入:随机噪点(想像力)
    • 输出:一元二次的线段
  • 判别器:
    • 输入:一元二次的线段
    • 输入:yes/no 二分类

训练GAN的生成器和判别器也是有讲究的。首先可以先让判别器(discriminator)来训练生成器(generator), 这时我们就要给判别器传入正确的label,让判别器告诉生成器如何朝正确的方向去发展。接着将生成器生成出来的假数据和真数据打包, 让判别器去学着判断哪些是真的哪些是假的,单独训练判别器。

那我们是如何具体训练判别器呢?将传入的真假数据给与真假标签,让判别器学着去做0/1分类就好了。

gradient

你在论文中可能会看到判别器的 gradient 定义为上面这样,乍看起来很高级,但是它本质上就是一个 0/1 的逻辑回归分类误差。 不过表现形式上,逻辑回归又有差别,下面是逻辑回归的 cost,我们对比一下。

logistic loss

逻辑回归中的x都是来自于同一批数据,而GAN中的前面的x是真实数据,后面还有一个用 G(z) 表示的 x,这个 x=G(z) 是生成出来的数据, 所以这两个 x 来自不同的数据,这点就和逻辑回归不同了。不过他们优化的目的都是一样的,最终都是让模型能区分 0 (假) 和 1(真)。 这里的 D(G(z)) 出来的结果越趋近 0 越好(梯度下降)。我在上面的代码中做了一个更简单的处理,设定好真假标签, 直接使用 keras.losses.BinaryCrossentropy() 做逻辑回归的分类。

然而训练生成器才是GAN的精华,这里发生了前所未有的训练模式,虽然生成器生成的数据我们都知道它是假的, 但是在这里我们故意用真标签去诱导判别器给出变成真数据的梯度下降方向,所以判别器才会生成器去生成更真的数据。 这就是GAN的奥秘了。

你有没有想过,想象力是什么?在GAN里面,想象力就是随机噪声,我们可以用 tf.random.normal((n, self.latent_dim) 这种方法把噪声做出来。

这部分的 gradient 计算方式也是逻辑回归的其中一部分。目的是让判别器 D 用它对的认知,帮助 G 提升生成 的能力。 所以这里的 D(G(z)) 出来的结果越趋近 1 越好(梯度上升)。

generator gradient

总结

在GAN的学术研究中,特别是生成图片的研究,研究人员发现,GAN的学习非常不稳定。稍微大点的数据生成,比如生成32*32像素的图片,就很容易学废。 在后续的发展中,众多研究者提出了非常多的改进方式,我们在后面一一展开。 GAN的发展还在持续的迭代中,也说不好未来的生成技术不一定是GAN的方式。 只是目前,GAN也能被训练得很好,生成特别优质的图片。比如让逝去家人的照片动起来~ 甚至还可以根据不同的要求条件来生成。 GAN理所应当地成为了当今最流行,也是最主要的生成模型方法。


降低知识传递的门槛

莫烦经常从互联网上学习知识,开源分享的人是我学习的榜样。 他们的行为也改变了我对教育的态度: 降低知识传递的门槛免费 奉献我的所学正是受这种态度的影响。 【支持莫烦】 能让我感到认同,我也更有理由坚持下去。

我组建了微信群,欢迎大家加入,交流经验,提出问题,互相帮持。 扫码后,请一定备注"莫烦",否则我不会同意你的入群申请。

wechat