RNN 循环神经网络 (回归)
学习资料:
要点¶
循环神经网络让神经网络有了记忆, 对于序列话的数据,循环神经网络能达到更好的效果. 如果你对循环神经网络还没有特别了解, 请观看几分钟的短动画, RNN 动画简介 和 LSTM 动画简介 能让你生动理解 RNN. 上次我们提到了用 RNN 的最后一个时间点输出来判断之前看到的图片属于哪一类, 这次我们来真的了, 用 RNN 来及时预测时间序列.
训练数据¶
我们要用到的数据就是这样的一些数据, 我们想要用 sin
的曲线预测出 cos
的曲线.
RNN模型¶
这一次的 RNN, 我们对每一个 r_out
都得放到 Linear
中去计算出预测的 output
, 所以我们能用一个 for loop 来循环计算. 这点是 Tensorflow 望尘莫及的! 除了这点, 还有一些动态的过程都可以在这个教程中查看, 看看我们的 PyTorch 和 Tensorflow 到底哪家强.
其实熟悉 RNN 的朋友应该知道, forward
过程中的对每个时间点求输出还有一招使得计算量比较小的. 不过上面的内容主要是为了呈现 PyTorch 在动态构图上的优势, 所以我用了一个 for loop
来搭建那套输出系统. 下面介绍一个替换方式. 使用 reshape 的方式整批计算.
训练¶
下面的代码就能实现动图的效果啦~开心, 可以看出, 我们使用 x
作为输入的 sin
值, 然后 y
作为想要拟合的输出, cos
值. 因为他们两条曲线是存在某种关系的, 所以我们就能用 sin
来预测 cos
. rnn
会理解他们的关系, 并用里面的参数分析出来这个时刻 sin
曲线上的点如何对应上 cos
曲线上的点.
所以这也就是在我 github 代码 中的每一步的意义啦.