Wasserstein Divergence for GANs (WGAN-div) 计算W散度

作者: 莫烦 编辑: 莫烦 2021-03-30

学习资料:

怎么了

WGAN 开创了一条新的GAN训练模式,但是简单粗暴的 Clip weights 方法让它无法发挥出全部实力。WGAN-gp 提出了一种替换 clip weights 的方式, 让WGAN焕发生机,我们从上回自己训练出来的 mnist 数据就能看出,WGAN-gp 比经典 WGAN 的结果好太多了。 但是 WGAN-gp 有没有问题呢?答案是有的,要不然 WGAN-div 就不会被提出来了。

wgan gp

在 WGAN-gp 中,为了满足 1-Lipschitz 约束,训练出好效果,作者采用了真假数据的插值方法,来模拟全空间的均匀分布 (至于为什么有这种结论,我没看太懂。。我选择相信作者,哈哈哈)。 WGAN-div 的作者说,这种做法是一种机械性的,很难靠有限的采样,模拟出这种全空间分布。

with a finite number of training iterations on limited input samples, it is very difficult to guarantee the k-Lipschitz constraint for the whole input domain.

objective function

WGAN-div的作者比较懒,没有具体解释 WGAN-div 要写成上面那样,在论文中的 公式10,只提到:

According to [19], by solving a family of minimization problems given p > 0

然后下面就摆上了那个公式,大哥,你好歹也解释一下啊。这个知乎真的去翻了一下引用的[19] 这篇论文,它没看懂,所以我也不抱希望去看懂证明过程了。还是直接拿着结论用吧,毕竟谁要我这么相信作者呢,哈哈。

所以一句话来描述 WGAN-div: 它不相信 WGAN-gp 可以在有效的训练中达到 1-Lipschitz 约束,所以换了一种叫做 W-div 的目标(公式截图的后半部分)。

我用 WGAN-div 训练 mnist 的结果如下,其实感觉训练出来的效果和 WGAN-gp 也差不多。但是WGAN-div论文中对比了其他的生成结果, 作者说效果更好,我们再下一节看看具体的论文效果吧。

results

训练效果

除了mnist,我没有拿WGAN-div做其他实验,所以我们展示一下论文中的训练对比吧,有一个更直观的了解。 有看过我 WGAN-gp介绍 的朋友知道,我对比过下面这张图。 经典WGAN在这张图上的效果更差。在对比 WGAN-gp 和 WGAN-div 的结果,也可以发现 WGAN-div 拟合效果更好。

comparison1

用GAN来生成人脸的测试中,WGAN-div 也拿到最最好的 FID 分数。FID 分数是一个拿来判断生成数据分布和真实数据分布的相像程度, 数值越小越像。

comparison2

怎么训练

WGAN-div 的伪代码如下,他的训练步骤并没有差 WGAN-gp 太多,所以在代码实现的时候, 我们可以直接继承之前写的 WGAN-gp 代码。

pseuducode

可以发现 WGAN-div 的 div div 和 WGAN-gp 的 gp gp 就只差了一丢丢。训练方法一模一样。有了 WGAN-gp 的基础下面一节写代码的工作量其实也不大。

秀代码

如果想直接看全部代码, 请点击这里去往我的github.

上面提到,WGAN-div 和 WGAN-gp 是十分相像了,我们这次会选择直接继承 WGAN-gp 来写,因为改动的代码少太多了。 所以如果你还不知道 WGAN-gp,我强烈建议你先看完我写的 WGAN-gp教程。 在 WGAN-div 中,我们只需要更改计算 gp 的方式即可。

from wgan_gp import WGANgp

class WGANdiv(WGANgp):
    def __init__(self, latent_dim, p, lambda_, img_shape):
        super().__init__(latent_dim, lambda_, img_shape)
        self.p = p

    # Wasserstein Divergence
    def gp(self, real_img, fake_img):
        e = tf.random.uniform((len(real_img), 1, 1, 1), 0, 1)
        noise_img = e * real_img + (1.-e)*fake_img      # extend distribution space
        with tf.GradientTape() as tape:
            tape.watch(noise_img)
            o = self.d(noise_img)
        g = tape.gradient(o, noise_img)         # image gradients
        # the following is different from WGANgp
        gp = tf.pow(tf.reduce_sum(tf.square(g), axis=[1, 2, 3]), self.p)
        return tf.reduce_mean(gp)

下面为了保持论文中计算 w_distance 的一致性,在原本继承过来的 w_distance() 中,我把它们的方向转换了一下,正号变成负号了。

class WGANdiv(WGANgp):
    ...
    @staticmethod
    def w_distance(fake, real=None):
        # sign is reversed in WGANdiv
        if real is None:
            return - tf.reduce_mean(fake)
        else:
            return -(tf.reduce_mean(fake) - tf.reduce_mean(real))

为了更清楚一点,我把 WGAN-gp 的计算loss方法再搬过来给你看看。下面的 self.lambda_ 其实就是论文中的 k, 为了继承方便,我偷懒没有把 self.lambda_ 改成 k

# 这段是从我的 WGAN-gp 中 copy 来的
class WGANgp(WGAN):
    ...
    def train_d(self, real_img):
        ...    
        gp = self.gp(real_img, fake_img)                        
        loss = w_loss + self.lambda_ * gp     

最后的训练结果和 WGAN-gp 差不了太多,但是我还是相信 WGAN-div 应该会稍微好一点点。 这也有待更多实验的验证。

res

总结

WGAN 为 GAN 的研究埋下了一个大坑,后续有很多人都投入到这个填坑的游戏中了。WGAN-gp 和 WGAN-div 都是填坑者。 我个人的话,会先在自己的项目中使用 WGAN-gp,如果效果达不到要求的话,再改为 WGAN-div 试试,把WGAN-div作为一个后补吧, 因为他们的效果真的比较相近。


降低知识传递的门槛

莫烦很常从互联网上学习知识,开源分享的人是我学习的榜样。 他们的行为也改变了我对教育的态度: 降低知识传递的门槛免费 奉献我的所学正是受这种态度的影响。 通过 【赞助莫烦】 能让我感到认同,我也更有理由坚持下去。