Pix2Pix 图生图 - 生成模型 GAN 网络 | 莫烦Python

Pix2Pix 图生图

作者: 莫烦 编辑: 莫烦 2021-06-04

学习资料:

怎么了

pix2pix 这种技术非常直观,明摆着告诉你,这就是一种用图片生成图片的技术。具体来说,之前的 CCGAN 是拿着一张图片生成另一张图片。那么今天的 Pix2Pix 或者说 Img2Img Net 论文 有什么不一样呢?

cat

其实并没有太多的不同,我觉得最大的不同就是他们两个论文中,实验场景和模型结构的不同。你说 CCGAN 可以拿来做 Img2Img/Pix2Pix 吧,也是可以的。 就好比 GANWGAN 的关系,同样都是 GAN, 都能从噪点生成图片,只是他们的网络结构不同而已。让我用一句话来概括 Pix2Pix: 图片之间的风格转换。

下面来几张 Pix2Pix 论文里比较经典的从草图生成实物的,包包鞋子生成图。

bag shoes map

如果你对这些生成图片的 GAN 感兴趣,你可以从我这个短视频中快速了解他们的技术。 这次教学中,为了减少训练时间,统一横向对比 CCGAN 模型, 我还是用 mnist 数据做了这样一个遮盖生成的实验。

results

怎么训练

如果你之前看过我的 CCGAN 的教学,下面的模型架构你应该很容易就能理解。 所以我们对比一下,如果我们想用图片生成图片,下面这种 CCGAN 的模式就是最简单的一种模式了。

my struct trail

对比上面 CCGAN,Pix2Pix 的模型是下面这种结构。

  1. Generator:我用一张原图片(猫的草图),得到生成图
  2. Discriminator:同时拿着原图和生成图,看看这张生成图是不是通过原图生成的?

my struct

虽然这么做可以将原图的数据风格,转换成生成图的风格,但是要求你得有一一对应的原图数据。 后面我们再介绍 CycleGAN 的时候, 就没有这个限制,所以我还挺喜欢 CycleGAN 的思路的。

下图是 Pix2Pix paper 中,它定义的模型架构,你可以参考着看。

paper struct

除了这种模型结构,Pix2Pix 中也提到,它改动了一些模型的底层结构,达到了更好的效果。比如其中的 Generator 的 U-Net。 U-Net 是另外一篇 paper 中提出来的。 目的是为了有一种 skip-connect 的感觉,更有效地传递梯度。这种 skip-connect 的想法在很多 CNN 模型中都有实践过,比如 ResNet。

unet

还有一个细节是在 Discriminator 中,如果我们直接输出一个总的真假判断,可能对模型的训练负担会很大,那么能不能输出局部的真假判断呢? 只要我的 CNN filter 扫过的所有区域,我都预测对了,那么拼凑起来,预测对也是八九不离十的。简单说就是 我将一个大步骤拆解成很多小步骤,每个小步骤都好的话,也就意味着我的大步骤大概率也是好的。

patch net

最后一点,作者认为,Generator 除了直接用 Discriminator 告诉它的真假误差来更新,还需要自己努力一下,左手拿着真图,右手拿着生成图,一个个像素对比,看生成的效果如何。 所以 Generator 的 loss 是两个 loss 的叠加。

  1. 普通的,从 Discriminator 传过来的 loss
  2. 自己内卷一下,自己给自己的图片对比 L1 loss

l1 loss

为什么用 L1 loss 呢?论文里说,L1 不会 blurry,图片没有 L2 糊。这算是经验总结吧,没什么好说的。 所以总的 Generator loss 就变成了这样:

G loss

有了上面的理解,接下来我们来撸一撸代码,手动实现一下。

秀代码

如果想直接看全部代码, 请点击这里去往我的Github. 我在这里说一下代码当中的重点部分。

同样,我们先看看 train 这个入口是如何定义的,这次的训练方法比较简单,和普通 GAN 训练起来没啥区别。

在定义我们 Pix2Pix 模型的时候,我们关注的就是怎么拿到 G 和 D,我们先看看 D 是怎么定义的吧,因为上面提到 D 会用到 Patch Net,我们会看看 Patch Net 是怎么写的。

从 input 我们能看出另一件事,因为 Pix2Pix, Discriminator 要一手拿生成图,一手拿原图,所以实际上是拿了两张图一起输入,它需要判断的是, 我用原图生成的图片有没有生成好。那么我们就要想一种方式把原图和生成图拼接起来,一起打包输入到 Discriminator 里。 我的代码前几步就是在干这件事。

后面当我们要采用 Patch GAN 的方式来输出一个 2D 概率矩阵的时候,其实就是不用 flat 层,转而再做一次 conv,将 feature map 压缩到 1 张, 这张 feature map 里面的值就是汇集了所有 patch 的判断,对图片中每一个 patch 分别判断是否是 real。

下面我们再看 Generator,就简单一些。不过要注意的是,Pix2Pix 的 Generator 用了 U-Net

在训练的时候呢,和普通 GAN 不同的是我的:

  1. Generator 要训练生成图片与真实图片的 L1 相似度
  2. Discriminator 要用 Patch GAN 的模式来训练

最后一个训练批次的结果如下图,看到的效果还是可以的。

res

问题

其实在上文我也提过,Pix2Pix 对数据集的要求还是挺高的,因为它需要你预先就准备好一一对应的图片数据,不然没发做风格转换。 拿草图生成实物的例子,你要同时拥有草图和实物对应的数据集,这样你才能训练这个模型,不然 Pix2Pix 对应不起来这些图片。

而后面要说的 CycleGAN 就不需要这种一一对应关系~

总结

Pix2Pix GAN 为我们打开了图图生成的新窗口,我们后续要做的风格迁移,转换等工作,都离不开这些思考。 虽然这个特定的 Pix2Pix 模型有很多局限性,但是里面的一些模型结果(U-Net),训练方案(Patch GAN),都被后辈们拿去借鉴, 基于他们做了很多的创新工作。


降低知识传递的门槛

莫烦很常从互联网上学习知识,开源分享的人是我学习的榜样。 他们的行为也改变了我对教育的态度: 降低知识传递的门槛免费 奉献我的所学正是受这种态度的影响。 通过 【赞助莫烦】 能让我感到认同,我也更有理由坚持下去。