Seq2Seq Attention 注意力机制
学习资料:
- 本节的全部代码
- 代码依赖的 utils.py 和 visual.py 在这里找到
- 我制作的 自然语言处理注意力机制 短片简介
- CNN attention 相关论文:Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
- NLP attention 相关论文1: Effective Approaches to Attention-based Neural Machine Translation
- NLP attention 相关论文2: Neural Machine Translation by Jointly Learning to Align and Translate
怎么了¶
不知道你身边有没有说话不经过大脑的人,反正我是有。每每听他们说话总是前言不搭后语。 所以俗话说三思而后行
,那么我们在说话之前的三思是在思考什么呢?在我看来就是对说话内容的反复斟酌,又或者是挑重点要点说。 其实这也是NLP中注意力的概念。计算机如果也是听到一句话后一股脑直接开始回答,那么很可能回答也前言不搭后语,特别是当听到的那句话太长, 他们回答起来也特别困难,因为很容易忘记听到的内容。注意力就是用来解决这种失忆
或者找不到重点
的问题。 我的这个短片简介会让你更加直观得观看到这样的问题。
我觉得下面这张图不仅有趣,而且成功地反映出了注意力的威力!只要你注意到不同的位置,整句话的意思也就变了~
怎么注意¶
我们先用视觉上的注意力方案就是,因为可以很直观地看到模型是怎么注意到重要内容的。下面的例子来自这篇论文, 它讲述的是让模型看一张海鸥的图片,然后生成一段对这张图片的文字描述。在生成描述中的每一个词的时候,这个模型都会注意到图片中不同部分。
比如在生成bird
这个词的时候,第一排图片对应的位置关注到的是图片中鸟的位置,意味着模型看着鸟
,生成了bird
。 另外在生成water
时这是观察到了除鸟以外的水面。
在这批图片中,当要生成画横线的词,模型也会在图片中留意对应的不同位置,这就是计算机视觉中的注意力。但是自然语言的注意力又是怎样的呢? 类比过来,视觉中,模型注意局部的图片信息,那么语言中,模型注意局部的文字信息就好啦。
上图是一个情感分析的简单例子,如果我要区分一句话的积极程度
,我可能会注意到这句话中某些词语的积极程度,因为这些词对于我判断积极句子起到了很大作用。 这样的attention注意力就很好理解了。箭头约粗的部分,模型的注意力就越集中。如果是翻译的情景,那注意力就像下面这样,我也尝试用更详细的图来解释一下。
Attention 的方案有很多种,上图描述的和我们这个教程代码中的Attention机制 tfa.seq2seq.LuongAttention()
都是源自这篇论文的方法。 这个方法在 decoder 用爱
预测时,他会用爱
这里的 state,结合上 encoder 中所有词的信息,生成一个注意力权重,比如模型会更加注意 encoder 中的爱
和莫烦
。 然后再结合这个权重,再次把权重施加到 encoder 中的每个 state,这样就有了对于不同 state step 的不一样关注度,注意力的结果(context)也在这里产生。 接着,把 context 和 decoder 那边的信息再次结合,最终输出注意后的答案。
补充一些数学信息,有兴趣的同学可以看一看。论文中提出了三种计算attention score的方式。 这里的 是 decoder 见到输入生成的 hidden state, 是 encoder 见到输入时生成的 hidden state。 意思是将decoder每预测一个词, 我都拿着这个decoder现在的信息去和encoder所有信息做注意力的计算。三个公式的不同点就是是否要引入更多的学习参数和变量。
翻译¶
在这节内容中,我还是以翻译为例。延续前几次用到日期翻译的例子, 我们知道在翻译的模型中,实际上是要构建一个Encoder,一个Decoder。这节内容我们就是让Decoder在生成语言的时候,也注意到Encoder的对应部分。
# 中文的 "年-月-日" -> "day/month/year"
"98-02-26" -> "26/Feb/1998"
今天的例子就是将中文的日期形式转换成英文的格式,同时我们也会输出类似这样的注意力图,让我们知道在模型生成某些字的时候,它究竟依据的是哪里的信息。 先剧透一下,x轴是中文的年月日,y轴是要生成的英文日月年。
秀代码¶
先来看训练过程(只想看全套代码的请点这里), 整个训练过程一如既往的简单,生成数据,建立模型,训练模型。数据的生成是我提前写好的utils.DateData()
功能,不需要掌握,你可以直接调用。
最后你能看到它的整个训练过程。最开始预测成渣渣,但是后面预测结果会好很多。你看刚训练几轮其实效果就已经很不错了,可见注意力的强大。
t: 0 | loss: 3.29403 | input: 89-05-25 | target: 25/May/1989 | inference: 00000000000
t: 70 | loss: 0.41608 | input: 03-09-13 | target: 13/Sep/2003 | inference: 13/Jan/2000<EOS>
t: 140 | loss: 0.01793 | input: 92-06-01 | target: 01/Jun/1992 | inference: 01/Jun/1992<EOS>
t: 210 | loss: 0.00309 | input: 23-01-28 | target: 28/Jan/2023 | inference: 28/Jan/2023<EOS>
...
t: 910 | loss: 0.00003 | input: 11-09-13 | target: 13/Sep/2011 | inference: 13/Sep/2011<EOS>
t: 980 | loss: 0.00003 | input: 06-08-10 | target: 10/Aug/2006 | inference: 10/Aug/2006<EOS>
模型构建时,在encoder部分其实和seq2seq的方法是没有什么变化的。 所以下面的代码和seq2seq也没有什么不同之处。
但是在decoder生成后文的时候,decoder要和encoder的结果有很多的联动,会和encoder的o, h, c
结果交叉在一起。 我们可以定义一个set_attention(x)
, 让模型关注到encoder后的信息。同时来让decoder记录从encoder来的信息和获取decoder的初始化state。
首先我们需要在__init__()
中先定义好attention的方法和decoder的处理单元。 现在我们已经为decoding做好的准备,来自encoder的信息在set_attention()
中被加工好,处理好了。 接下来就是我们训练encoder+decoder的过程。
train_logits()
训练的步骤很简单,分几步。
- 拿到encoder的attention信息和state;
- 筛选出标签;
- 把标签在decoder中embed;
- 拿着所有缓存的 encoded state (attention memory) 和 encoder最后一步产生 state,放入decoder预测,得到所有的由加了注意力的output。
最后我们加上传统的训练方案,根据logits加label进行误差计算,并反向更新梯度,整个训练过程就完成了。
但是在预测的时候,和训练就不一样了。因为预测时没有标签信息,我们只能基于预测出来的词,再接着做后续的计算。
所以在预测时,我们使用到了一个Python的循环,不断生成self.max_pred_len
这个多个数的预测结果, 然后用numpy array
收集起来,作为最后的预测结果。
最后我还写了一个可视化的功能,具体代码参考这里. 在翻译的时候attention的结果就如下。颜色深的地方就代表越注意。
思考¶
为什么我们在训练的时候,不把每次decoder出来的词当做下次预测的输入?这个问题我在seq2seq 提到过,不过我在这里做一下具体说明。
对比上面两张图,上图是训练时的样子,decoder的输入是真实标签信息。下图是预测时的样子,decoder的输入是上一步预测的值。 首先我们说预测的模式,因为缺少真实的label,我们只能拿着上一步预测的值,当成下一步预测的输入,这点非常明确清晰,没毛病。 但是在训练的时候我们为什么不也这么做呢?拿着标签来预测会不会让学习不连贯,效果不好呢?
其实判断一个训练的好坏,并不只是判断训练是否符合逻辑,当然如果用预测(inference)的方式做训练是符合逻辑的,而且的确是可以训练出来一个好结果的。 但是一般我们却不这么用,为什么呢?其实原因就是训练太慢了,如果现在我们训练一个小孩走路,你有两个方案。
- 他摔倒了,把它扶起来继续走路;
- 他摔倒了,我不管,让它从原地爬起来自己再走。
如果我们的目标是训练这个小孩的自理能力,那么我肯定选第二种,但是我们只想训练他走路的能力,不care他能不能自理,那我肯定选简单粗暴的第一种。 模型也一样,我只想训练他的生成话术的能力,但是我不care它的自动纠错能力,那么我还是直接用true label训练来的更快。
总结¶
看过了seq2seq,了解了encoder decoder。我们还在seq2seq上尝试了attention,让模型更加关注自身的结果,让模型更加鲁棒。 下一个我们就要进入到一个将注意力发挥到极致的算法,transformer。