Self-Attention GAN (SAGAN) 自注意力

作者: 莫烦 编辑: 莫烦 2021-03-31

学习资料:

怎么了

模型学着该看哪里,不该看哪里,这种能力对于模型针对某件事物的理解还是相当有好处的。因为模型开始有重点地观察,分析了。 不管是看一张图还是分析一句话,都可以让注意力在其中发挥作用。

paper res

现在到处都是注意力,有种注意力要一统江山的意思,不管是计算机视觉还是自然语言处理的模型里面,随处可见 attention 机制。 自从 Seq2Seq 这种 RNN 中开始学习计算机视觉(CV),开始搞注意力后, 自然语言处理 NLP 仿佛就爱上了注意力,比如 Transformer 完全就舍弃了 RNN 这一套, 靠着注意力大一统,并获得了当前最好的效果。后来 Transformer 这种注意力在又迁移会 CV 界,打开了 CV 的疯狂注意力时代。

在GAN里面同样也可以有注意力(SAGAN),这次要讲的就是用注意力让GAN生成得更好。用一句话来描述 SAGAN: 围绕着重点画画,画出来的画更有神了。 这次我们还是拿 mnist 来生成注意后的手写数字,下面这张动图就是最终结果了。

results

怎么训练

这种注意力,叫做自注意力 self-attention, 原则上就是不断地基于原始图片上的注意后再注意。我在NLP的教学当中有做个一个短视频, 原理是一样的,你可以直接看看。而这篇 SAGAN 也是这种堆砌注意力的网络。无论在 Generator 还是 Discriminator 上, 都嵌入了注意力,所以这种注意力更是一种可插拔的底层模块。

self attention

上面是一种计算注意力的方式,按照现在的分法,上面的 f(x), g(x), h(x) 有一个更通用的名字,现在在各种其他注意力模型中,我们都这样说:

  • Query: f(x)
  • Key: g(x)
  • Value: h(x)

那么 Query, Key, Value 又是什么意思呢?我之前做 NLP 教程Transformer 的时候, 是这么解释的,你看看有没有道理。

想象这是一个相亲画面,我有我心中有个喜欢女孩的样子,我会按照这个心目中的形象浏览各女孩的照片, 如果一个女生样貌很像我心中的样子,我就注意这个人, 并安排一段稍微长一点的时间阅读她的详细材料, 反之我就安排少一点时间看她的材料。这样我就能将注意力放在我认为满足条件的候选人身上了。 我心中女神的样子就是Query,我拿着它(Query)去和所有的候选人(Key)做对比,得到一个要注意的程度(attention), 根据这个程度判断我要花多久时间仔细阅读候选人的材料(Value)。 这就是Transformer的注意力方式。

key query value

所以 SAGAN 中的注意力模块计算方式也是上面这一种,但是如果你了解Transformer,并仔细看 SAGAN 的注意力计算公式,会发现他们也有些许不同。 下面第一个公式是用 softmax 来计算注意力矩阵,也就是哪里要稍微注意些,哪里不太注意,你可以理解成这是重点矩阵。 然后下面第二个公式,把注意力矩阵印在h(x)上,也就是在计算一个用注意力之后的影响结果。最终产生了一个注意后的结果。

softmax

attented

但是如果知道 Transformer 的朋友,你会发现在这里还有一个 v(),这是再干嘛呢?其实这里的 v 只是一个转换 channel 数的 1*1 conv 卷积, 目的是为了节省一点内存空间。如果你在 h(x) 那已经对齐了 channel 数,这个 v 就可以不要了。

另外,SAGAN还用了 hinge loss 代替原有的loss,它觉得 hinge loss 好一点,但其实应该也差不多。

hinge loss

秀代码

如果想直接看全部代码, 请点击这里去往我的github.

因为加入了 self-attention 模块,所以我们干脆就先定义一个 Attention 的 layer class 吧,到时候好重复调用它。 下面的代码比我原始的代码有省略,我只突出最重要的地方,所以如果你想看更细节的代码,请移步我的 Github.

class Attention(keras.layers.Layer):
    def __init__(self, gamma=0.01, trainable=True):
        super().__init__(trainable=trainable)
        self._gamma = gamma

    def call(self, inputs, **kwargs):
        f = self.f(inputs)    # [n, w, h, c] -> [n, w*h, c//8]
        g = self.g(inputs)    # [n, w, h, c] -> [n, w*h, c//8]
        h = self.h(inputs)    # [n, w, h, c] -> [n, w*h, c//8]
        s = tf.matmul(f, g, transpose_b=True)   # [n, w*h, c//8] @ [n, c//8, w*h] = [n, w*h, w*h]
        self.attention = tf.nn.softmax(s, axis=-1)
        context_wh = tf.matmul(self.attention, h)  # [n, w*h, w*h] @ [n, w*h, c//8] = [n, w*h, c//8]
        s = inputs.shape        # [n, w, h, c]
        cs = context_wh.shape   # [n, w*h, c//8]
        context = tf.reshape(context_wh, [-1, s[1], s[2], cs[-1]])    # [n, w, h, c//8]
        o = self.v(self.gamma * context) + inputs   # residual
        return o

在这里,我省略了构建网络的代码,但我想强调计算 attention 的过程。通过 f, g, h, v 的交互计算,先得到一个 fg 产生的注意力矩阵, 然后 softmax 一下它,给它分配100%的注意力空间,然后再将这个注意力施加到 h 上。这就是注意力的套路了。接着我们再将这个 Attention 塞入 SAGAN。

我们先看看要把 Attention 塞到哪吧。我在 Discriminator 里放了一个 Attention,在 Generator 里放了两个。 位置都在激活函数之后。其实我觉得放在激活函数前后都可以尝试一下,现在想想,我可能还更偏向放在激活之前。 因为激活前,batchNorm 之后的数据分布更加适合学习。

class SAGAN(keras.Model):
    def _get_discriminator(self):
        model = keras.Sequential([
            keras.layers.GaussianNoise(0.01, input_shape=self.img_shape),
            keras.layers.Conv2D(16, 4, strides=2, padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
            Attention(self.gamma),    # <---
            keras.layers.Dropout(0.3),

            keras.layers.Conv2D(32, 4, strides=2, padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
            keras.layers.Dropout(0.3),

            keras.layers.Flatten(),
            keras.layers.Dense(1),
        ], name="discriminator")
        return model

    def _get_generator(self):
        model = keras.Sequential([
            # [n, latent] -> [n, 7 * 7 * 128] -> [n, 7, 7, 128]
            keras.layers.Dense(7 * 7 * 128, input_shape=(self.latent_dim,)),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            keras.layers.Reshape((7, 7, 128)),

            # -> [n, 14, 14, 64]
            keras.layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            Attention(self.gamma),      # <---

            # -> [n, 28, 28, 32]
            keras.layers.Conv2DTranspose(32, (4, 4), strides=(2, 2), padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            Attention(self.gamma),      # <---
            # -> [n, 28, 28, 1]
            keras.layers.Conv2D(1, (4, 4), padding='same', activation=keras.activations.tanh)
        ], name="generator")
        return model

整体训练方法和步骤和其他的GAN差不多。我就不复述了,你可以到我Github看全部代码。 最后一个 epoch 的结果是这样,感觉也没特别好,哈哈。

sagan res

总结

SAGAN 虽然是以前的paper了,但是用到了最近比较火的 Attention 技术,Attention 本质是好的,效果也应该会不错,所以把 Attention 用在 GAN 上,的确是一种不二选择。 但是可能还有一些点需要完善才能达到一个比较好的效果。


降低知识传递的门槛

莫烦很常从互联网上学习知识,开源分享的人是我学习的榜样。 他们的行为也改变了我对教育的态度: 降低知识传递的门槛免费 奉献我的所学正是受这种态度的影响。 通过 【赞助莫烦】 能让我感到认同,我也更有理由坚持下去。