Pix2Pix 图生图
学习资料:
- 我制作的GAN简介短片
- 我制作的看图
说画
的GAN - 论文 Image-to-Image Translation with Conditional Adversarial Networks
- 本节代码
- 代码有我自己定义的依赖utils.py, visual.py, mnist_ds.py
怎么了¶
pix2pix 这种技术非常直观,明摆着告诉你,这就是一种用图片生成图片的技术。具体来说,之前的 CCGAN 是拿着一张图片生成另一张图片。那么今天的 Pix2Pix 或者说 Img2Img Net 论文 有什么不一样呢?
其实并没有太多的不同,我觉得最大的不同就是他们两个论文中,实验场景和模型结构的不同。你说 CCGAN 可以拿来做 Img2Img/Pix2Pix 吧,也是可以的。 就好比 GAN 和 WGAN 的关系,同样都是 GAN, 都能从噪点生成图片,只是他们的网络结构不同而已。让我用一句话来概括 Pix2Pix: 图片之间的风格转换。
下面来几张 Pix2Pix 论文里比较经典的从草图生成实物的,包包鞋子生成图。
如果你对这些生成图片的 GAN 感兴趣,你可以从我这个短视频中快速了解他们的技术。 这次教学中,为了减少训练时间,统一横向对比 CCGAN 模型, 我还是用 mnist 数据做了这样一个遮盖生成的实验。
怎么训练¶
如果你之前看过我的 CCGAN 的教学,下面的模型架构你应该很容易就能理解。 所以我们对比一下,如果我们想用图片生成图片,下面这种 CCGAN 的模式就是最简单的一种模式了。
对比上面 CCGAN,Pix2Pix 的模型是下面这种结构。
- Generator:我用一张原图片(猫的草图),得到生成图
- Discriminator:同时拿着原图和生成图,看看这张生成图是不是通过原图生成的?
虽然这么做可以将原图的数据风格,转换成生成图的风格,但是要求你得有一一对应的原图数据。 后面我们再介绍 CycleGAN 的时候, 就没有这个限制,所以我还挺喜欢 CycleGAN 的思路的。
下图是 Pix2Pix paper 中,它定义的模型架构,你可以参考着看。
除了这种模型结构,Pix2Pix 中也提到,它改动了一些模型的底层结构,达到了更好的效果。比如其中的 Generator 的 U-Net。 U-Net 是另外一篇 paper 中提出来的。 目的是为了有一种 skip-connect 的感觉,更有效地传递梯度。这种 skip-connect 的想法在很多 CNN 模型中都有实践过,比如 ResNet。
还有一个细节是在 Discriminator 中,如果我们直接输出一个总的真假判断,可能对模型的训练负担会很大,那么能不能输出局部的真假判断呢? 只要我的 CNN filter 扫过的所有区域,我都预测对了,那么拼凑起来,预测对也是八九不离十的。简单说就是 我将一个大步骤拆解成很多小步骤,每个小步骤都好的话,也就意味着我的大步骤大概率也是好的。
最后一点,作者认为,Generator 除了直接用 Discriminator 告诉它的真假误差来更新,还需要自己努力一下,左手拿着真图,右手拿着生成图,一个个像素对比,看生成的效果如何。 所以 Generator 的 loss 是两个 loss 的叠加。
- 普通的,从 Discriminator 传过来的 loss
- 自己
内卷
一下,自己给自己的图片对比 L1 loss
为什么用 L1 loss 呢?论文里说,L1 不会 blurry,图片没有 L2 糊。这算是经验总结吧,没什么好说的。 所以总的 Generator loss 就变成了这样:
有了上面的理解,接下来我们来撸一撸代码,手动实现一下。
秀代码¶
如果想直接看全部代码, 请点击这里去往我的Github. 我在这里说一下代码当中的重点部分。
同样,我们先看看 train
这个入口是如何定义的,这次的训练方法比较简单,和普通 GAN 训练起来没啥区别。
在定义我们 Pix2Pix 模型的时候,我们关注的就是怎么拿到 G 和 D,我们先看看 D 是怎么定义的吧,因为上面提到 D 会用到 Patch Net,我们会看看 Patch Net 是怎么写的。
从 input 我们能看出另一件事,因为 Pix2Pix, Discriminator 要一手拿生成图,一手拿原图,所以实际上是拿了两张图一起输入,它需要判断的是, 我用原图生成的图片有没有生成好。那么我们就要想一种方式把原图和生成图拼接起来,一起打包输入到 Discriminator 里。 我的代码前几步就是在干这件事。
后面当我们要采用 Patch GAN 的方式来输出一个 2D 概率矩阵的时候,其实就是不用 flat 层,转而再做一次 conv,将 feature map 压缩到 1 张, 这张 feature map 里面的值就是汇集了所有 patch 的判断,对图片中每一个 patch 分别判断是否是 real。
下面我们再看 Generator,就简单一些。不过要注意的是,Pix2Pix 的 Generator 用了 U-Net。
在训练的时候呢,和普通 GAN 不同的是我的:
- Generator 要训练生成图片与真实图片的 L1 相似度
- Discriminator 要用 Patch GAN 的模式来训练
最后一个训练批次的结果如下图,看到的效果还是可以的。
问题¶
其实在上文我也提过,Pix2Pix 对数据集的要求还是挺高的,因为它需要你预先就准备好一一对应的图片数据,不然没发做风格转换。 拿草图生成实物的例子,你要同时拥有草图和实物对应的数据集,这样你才能训练这个模型,不然 Pix2Pix 对应不起来这些图片。
而后面要说的 CycleGAN 就不需要这种一一对应关系~
总结¶
Pix2Pix GAN 为我们打开了图图生成的新窗口,我们后续要做的风格迁移,转换等工作,都离不开这些思考。 虽然这个特定的 Pix2Pix 模型有很多局限性,但是里面的一些模型结果(U-Net),训练方案(Patch GAN),都被后辈们拿去借鉴, 基于他们做了很多的创新工作。