Self-Attention GAN (SAGAN) 自注意力
学习资料:
- 我制作的GAN简介短片
- 我制作的GAN的问题漏洞短片
- 论文 Self-Attention Generative Adversarial Networks
- 本节代码
- 代码有我自己定义的依赖utils.py, visual.py, mnist_ds.py
怎么了¶
模型学着该看哪里,不该看哪里,这种能力对于模型针对某件事物的理解还是相当有好处的。因为模型开始有重点地观察,分析了。 不管是看一张图还是分析一句话,都可以让注意力在其中发挥作用。
现在到处都是注意力,有种注意力要一统江山的意思,不管是计算机视觉还是自然语言处理的模型里面,随处可见 attention 机制。 自从 Seq2Seq 这种 RNN 中开始学习计算机视觉(CV),开始搞注意力后, 自然语言处理 NLP 仿佛就爱上了注意力,比如 Transformer 完全就舍弃了 RNN 这一套, 靠着注意力大一统,并获得了当前最好的效果。后来 Transformer 这种注意力在又迁移会 CV 界,打开了 CV 的疯狂注意力时代。
在GAN里面同样也可以有注意力(SAGAN),这次要讲的就是用注意力让GAN生成得更好。用一句话来描述 SAGAN: 围绕着重点画画,画出来的画更有神了。 这次我们还是拿 mnist 来生成注意后的手写数字,下面这张动图就是最终结果了。
怎么训练¶
这种注意力,叫做自注意力 self-attention, 原则上就是不断地基于原始图片上的注意后再注意。我在NLP的教学当中有做个一个短视频, 原理是一样的,你可以直接看看。而这篇 SAGAN 也是这种堆砌注意力的网络。无论在 Generator 还是 Discriminator 上, 都嵌入了注意力,所以这种注意力更是一种可插拔的底层模块。
上面是一种计算注意力的方式,按照现在的分法,上面的 f(x)
, g(x)
, h(x)
有一个更通用的名字,现在在各种其他注意力模型中,我们都这样说:
- Query:
f(x)
- Key:
g(x)
- Value:
h(x)
那么 Query, Key, Value 又是什么意思呢?我之前做 NLP 教程的 Transformer 的时候, 是这么解释的,你看看有没有道理。
想象这是一个相亲画面,我有我心中有个喜欢女孩的样子,我会按照这个心目中的形象浏览各女孩的照片, 如果一个女生样貌很像我心中的样子,我就注意这个人, 并安排一段稍微长一点的时间阅读她的详细材料, 反之我就安排少一点时间看她的材料。这样我就能将注意力放在我认为满足条件的候选人身上了。 我心中女神的样子就是Query,我拿着它(Query)去和所有的候选人(Key)做对比,得到一个要注意的程度(attention), 根据这个程度判断我要花多久时间仔细阅读候选人的材料(Value)。 这就是Transformer的注意力方式。
所以 SAGAN 中的注意力模块计算方式也是上面这一种,但是如果你了解Transformer,并仔细看 SAGAN 的注意力计算公式,会发现他们也有些许不同。 下面第一个公式是用 softmax 来计算注意力矩阵,也就是哪里要稍微注意些,哪里不太注意,你可以理解成这是重点矩阵
。 然后下面第二个公式,把注意力矩阵印在h(x)
上,也就是在计算一个用注意力之后的影响结果。最终产生了一个注意后的结果。
但是如果知道 Transformer 的朋友,你会发现在这里还有一个 v()
,这是再干嘛呢?其实这里的 v 只是一个转换 channel 数的 1*1 conv 卷积, 目的是为了节省一点内存空间。如果你在 h(x)
那已经对齐了 channel 数,这个 v 就可以不要了。
另外,SAGAN还用了 hinge loss 代替原有的loss,它觉得 hinge loss 好一点,但其实应该也差不多。
秀代码¶
如果想直接看全部代码, 请点击这里去往我的github.
因为加入了 self-attention 模块,所以我们干脆就先定义一个 Attention
的 layer class 吧,到时候好重复调用它。 下面的代码比我原始的代码有省略,我只突出最重要的地方,所以如果你想看更细节的代码,请移步我的 Github.
在这里,我省略了构建网络的代码,但我想强调计算 attention 的过程。通过 f
, g
, h
, v
的交互计算,先得到一个 f
和 g
产生的注意力矩阵, 然后 softmax 一下它,给它分配100%的注意力空间,然后再将这个注意力施加到 h
上。这就是注意力的套路了。接着我们再将这个 Attention
塞入 SAGAN。
我们先看看要把 Attention 塞到哪吧。我在 Discriminator 里放了一个 Attention,在 Generator 里放了两个。 位置都在激活函数之后。其实我觉得放在激活函数前后都可以尝试一下,现在想想,我可能还更偏向放在激活之前。 因为激活前,batchNorm 之后的数据分布更加适合学习。
整体训练方法和步骤和其他的GAN差不多。我就不复述了,你可以到我Github看全部代码。 最后一个 epoch 的结果是这样,感觉也没特别好,哈哈。
总结¶
SAGAN 虽然是以前的paper了,但是用到了最近比较火的 Attention 技术,Attention 本质是好的,效果也应该会不错,所以把 Attention 用在 GAN 上,的确是一种不二选择。 但是可能还有一些点需要完善才能达到一个比较好的效果。