Progressive Growing GAN (PGGAN) 分阶段学习
学习资料:
- 我制作的GAN简介短片
- 我制作的GAN的问题漏洞短片
- 论文 Progressive Growing of GAN for Improved Quality, Stability, and Variation
- 本节代码
- 代码有我自己定义的依赖utils.py, visual.py, mnist_ds.py
怎么了¶
GAN 在小尺寸上图片(<60px)的生成任务已经可以达到相对比较好的效果了,但是继续扩大尺寸,比如到到1024px,生成的难度是指数型增加的, 因为可以尝试的排列组合数是指数型增高。训练难度的提高,将会出现几个问题:
- 训练时间的指数型加长
- 训练效果的降低
- 有可能压根训练不出来,或只能训练出模糊的图片
所以 Progressive Growing GAN (PGGAN) 打算从训练方法上革新。用一句话描述 PGGAN 的做法: 分阶段,从易到难,从简单的小任务到大任务过度型学习。
这次我在 mnist 上做实验后,明显可以看出有三个训练阶段。从低像素(7*7)开始训练,到 14*14, 最后到 28*28 的最终 mnist 尺寸。
如果从动态训练的角度来看,mnist 训练的整个过程也可以在下面的动图中看到,我训练了8个epoch,效果已经不错了。
怎么训练¶
思路其实很简单,就是分阶段训练。具体呢,也就是论文中这张截图显示的:
- Generator 生成 4*4 的图片,Discriminator 识别 4*4 的图片
- 生成 8*8, 识别 8*8
- 生成 16*16, 识别 16*16
- ...
- 生成 1024*1024, 识别 1024*1024
让生成器和判别器的能力是逐层提升的,充分证明中国的依据俗话 一口吃不成一个胖子
,所以我们小口小口来,一个阶梯一个阶梯上。这样拆解了任务, 单个的小任务更容易实现。但是在分阶段学习的时候还有一些细节要考虑。
比如在上图中,就是针对每次增加难度,要进入下一阶段训练的时候,我们有些小动作可以帮我们无缝顺滑进入下一阶段。 这个小动作就是图上 (b) 这个阶段做的事情。本质上 (b) 阶段做的是一件润滑剂的事情。事情是这样的,进入下一阶段, 是一种硬切换好呢,还是一种软切换好,有点像是在对比阶梯还是残疾人坡的区别。
在 (b) 这部分,生成器给出的图片,综合了
- 原本网络的图片
- 新添加层生成的图片
因为在一开始新增一层的阶段,这层新增的层因为随机初始化,效果肯定没有原始被训练的层好,所以我们给一个相对低的权重,不让训练突然话锋一转,转得过于厉害。 另外,我们再接一条管道直接连通原本的训练过的层和生成图片,并给一个稍微大点的权重。在训练过程中在缓慢调整这个权重,让权重分配越来越重视新增的层。 对于判别器,同理也是一样的处理过程。
这就是 PGGAN 最核心的工作了。
秀代码¶
如果想直接看全部代码, 请点击这里去往我的github.
PGGAN看似简单,但是在工程实践的时候还是挺纠结的,因为要涉及到网络结构的变化,而 Tensorflow 又是一种静态图计算,所以还有一些些兼容性问题。 代码量是偏多的。我在下面只会体现出 PGGAN 的最核心代码。def train(gan, ds, epoch):
的代码和 class PGGAN(keras.Model):
的代码并没有和其他GAN有太大的差别。 但是因为 Generator 和 Discriminator 变动比较大,我们就分开创建这两个类吧。
在 Generator 和 Discriminator 中,我们分别都要创建每个阶段要使用到的层,而有些层经历了训练阶段后,还要被丢弃,比如在上一节中提到的 (b) 部分的润滑剂。
所以每个阶段都有它的配套服务员
,中间有一些服务员
我介绍一下:
self.b[x]_rgb
每个阶段,作为生成器的最后一层 rgb 颜色的转换self.b[x]
卷积或者是反卷积self.p[x]
润滑剂,原图的直接映射
我觉得最让我头疼的就是写 call()
这个功能了,因为每个阶段要运算的东西都不一样,写起来挺麻烦的,而且我也感觉这种写法一点都不优雅,应该还有更好的写法, 如果你想到更好的写法,请及时在后面留言讨论~
简单来说,我这里 call()
做前向后,会先判断目前进行到第几阶段了,然后根据不同阶段,有不同的输出策略。 比如要使用哪个阶段的配套服务员
就由 current_layer
这个参数控制。有比如 inputs
中的 p
就来控制我的润滑剂
力度。 相应的,Discriminator 的设计也是这样的,因为代码实在有点多,如果你想深入研究 PGGAN, 欢迎直接看我Github全部代码。
最后再来回顾一下分阶段的训练结果,有没有感觉这是一种简单,但是写起来很麻烦的方法呀~
总结¶
PGGAN 用一种我们人类常用的分阶段学习方法,来让GAN生成相对比较高清的图片。这是一种非常有效的方法,在后续有很多生成高清图片的论文中都沿用了这套理论。