SRGAN 超清生成
学习资料:
- 我制作的GAN简介短片
- 我制作的看图
说画
的GAN - 论文 Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
- 本节代码
- 代码有我自己定义的依赖utils.py, visual.py, mnist_ds.py
怎么了¶
美女的照片太糊了?做 PPT 的时候,能找到的图片都是低清的。你是不是也遇到过这样的事情,想要更高清的图片。 GAN 是不是也能在这方面来拯救你?
所以这一节的内容十分简单粗暴,目的就是生成高清图。不管你是不是想要做一些不可描述的事情。 这节的内容,做法,都值得你参考。如果你对使用 GAN 来做图片生成图片的技术还不太了解,我强烈建议你先看看我这个短片介绍。 如果用一句话来概括今天这个 SRGAN 模型在做的事情:低清图片转高清图片
就这么简单。
我们先来看看论文中的实验结果。
同样,在我的这节练习中,我们还是使用 mnist 数据来做,首先我会将 28*28 的手写数字图片变成更模糊的版本,然后让模型来学着重构,高清化它。
怎么训练¶
为了让模型可以对低清图像高清化,我们传统的思维就是把低清图像当成 GAN 的输入数据,然后让 GAN 帮你生成一种尺寸更大的高清图。 思路正确,不过效果可能并不太好,所以 SRGAN 在生成效果上再多思考了一下。
上面这张图是论文中的图,他想表达的是,从低清到高清的过程,可能比你想象的要复杂一点,因为低清意味着信息量少,高清是信息量高。从一个信息量少的数据,到信息量高的数据, 其实说白了,就是要靠猜,我猜你最有可能缺少了哪些信息,然后补全这些信息。
而且上图还表达了,这张低清的图,有可能可以对应上很多的可行答案
。如果我们像下面这种方案来规范化 GAN 的高清生成,计算生成图片和真实图片的像素误差,很可能让 GAN 放弃学习, 因为这会让 GAN 太专注在单个像素的对比上了,所以 GAN 自己也会学得很纠结,久而久之就学废了。
作者不想让 GAN 学废了,但又想让 GAN 对生成的图片和高清原图有些对比,用这个对比误差修正网络。那么他们使用了一个很聪明的方式。 既然直接在像素上比较,会吃力不讨好,那我能不能在高层特征上对比呢?比如用一个 CNN 卷积的特征图上。 这真是一个聪明的方式。 因为我们在乎的是生成的图像,它肉眼可见的范围内,是不是像高清图像,而不是单个像素上像不像。所以就有了下面的图例解释。
在论文中,使用了下面图中的 Generator 和 Discriminator 的网络架构,没有太多特别的,就是在 Generator 中,它使用了 ResNet, 做了一些 Skip Connection, 这样有助于梯度传导,学习效果会好一点。
在 loss 定义上,因为我们在做超清化生成,所以这个生成效果也会有特定的 loss。
这里面的 content loss 就是我在 CNN 特征图上对比生成图片和原图的相似度。不是直接在生成图像上的像素对比。 下面的公式中,W
和 H
就是这张特征图的长宽。论文中,是选择了一个 VGG 来生成这些特征图,而且用这个 VGG 得到的超清(生成图)/高清(原图)上的特征图做 L2 的相似度计算。
用这个 SR loss 来评估生成的超清图和原始的高清图是不是一样的。在总 loss 那还有一个一般 GAN 都会有的 Adversarial loss, 评估 Discriminator 看到的是不是真实图片。
有了上面的理解,接下来我们来撸一撸代码,手动实现一下。
秀代码¶
如果想直接看全部代码, 请点击这里去往我的Github. 我在这里说一下代码当中的重点部分。
因为要处理低清图片,所以我给原始的 mnist(28*28)的图片做了 downsampling
低清化处理,变成了 (7*7) 的尺寸。 然后把高清图片 hr_img
当做标签,低清 lr_img
当做 GAN 的 input,给输入进训练。
先看看我对 Discriminator 的定义,我为了偷懒,没有像论文中用了 VGG 去拿特征图,因为我想减轻训练压力, 就直接复用了我的 Discriminator 来输出特征图,做特征图上的相似度计算。 所以你不会在我的代码中发现 VGG 的影子。 反正最终的生成效果还是挺不错的,说明这种方法是可以在你的项目中尝试使用的。
再来看看 Generator 的模型定义,虽然有点长,但是简单来看,其实和论文中的模型有一些类似,但是会小很多。因为我们的 mnist 数据也不大嘛,搞那么大模型干嘛~ 在 Generator 里,我定义了一个 pro_process 网络,是用来对齐 Feature Map 的数量的。后面还接了一个 Encoder + Decoder, 用来生成超清图片的。没有什么额外的设定了。
Discriminator 的 loss 计算就是正常的 GAN 中 Discriminator 的 loss 计算,算一个真假区分就好了。
但是在 Generator 的 loss 计算的时候,稍微复杂一点,因为有两种 loss,一种是特征图的对比 loss,一种是 Discriminator 传过来的真假图片 loss。
最后一个训练批次的结果如下图,看到的效果还是可以的。
总结¶
图片的高清化的应用场景还是很多的,比如在网络传输视频的时候,如果都是高清画质传输,这样很占带宽的。如果你可以只传输低清画质的视频,然后再浏览器中高清化处理, 这样就对下载没有压力了,而且很快就能加载完,而且又能看到超清化的画质。基于 SRGAN 的研究还有很多,感兴趣的话,你可以再在网上搜搜这一个流派。