RNN 循环神经网络 (分类)
作者: 莫烦 发布于: 2017-05-10
编辑: 学习资料:
要点¶
循环神经网络让神经网络有了记忆, 对于序列话的数据,循环神经网络能达到更好的效果. 如果你对循环神经网络还没有特别了解, 请观看几分钟的短动画, RNN 动画简介 和 LSTM 动画简介 能让你生动理解 RNN. 接着我们就一步一步做一个分析手写数字的 RNN 吧.
MNIST手写数据¶
黑色的地方的值都是0, 白色的地方值大于0.
同样, 我们除了训练数据, 还给一些测试数据, 测试看看它有没有训练好.
RNN模型¶
和以前一样, 我们用一个 class 来建立 RNN 模型. 这个 RNN 整体流程是
(input0, state0)
->LSTM
->(output0, state1)
;(input1, state1)
->LSTM
->(output1, state2)
;- ...
(inputN, stateN)
->LSTM
->(outputN, stateN+1)
;outputN
->Linear
->prediction
. 通过LSTM
分析每一时刻的值, 并且将这一时刻和前面时刻的理解合并在一起, 生成当前时刻对前面数据的理解或记忆. 传递这种理解给下一时刻分析.
训练¶
我们将图片数据看成一个时间上的连续数据, 每一行的像素点都是这个时刻的输入, 读完整张图片就是从上而下的读完了每行的像素点. 然后我们就可以拿出 RNN 在最后一步的分析值判断图片是哪一类了. 下面的代码省略了计算 accuracy
的部分, 你可以在我的 github 中看到全部代码.
最后我们再来取10个数据, 看看预测的值到底对不对:
所以这也就是在我 github 代码 中的每一步的意义啦.