RNN LSTM 循环神经网络 (分类例子)
学习资料:
- 相关代码
- 为 TF 2017 打造的新版可视化教学代码
- 机器学习-简介系列 什么是RNN
- 机器学习-简介系列 什么是LSTM RNN
- 本代码基于网上这一份代码 code
设置 RNN 的参数¶
这次我们会使用 RNN 来进行分类的训练 (Classification). 会继续使用到手写数字 MNIST 数据集. 让 RNN 从每张图片的第一行像素读到最后一行, 然后再进行分类判断. 接下来我们导入 MNIST 数据并确定 RNN 的各种参数(hyper-parameters):
接着定义 x
, y
的 placeholder
和 weights
, biases
的初始状况.
定义 RNN 的主体结构¶
接着开始定义 RNN 主体结构, 这个 RNN 总共有 3 个组成部分 ( input_layer
, cell
, output_layer
). 首先我们先定义 input_layer
:
接着是 cell
中的计算, 有两种途径:
- 使用
tf.nn.rnn(cell, inputs)
(不推荐原因). 但是如果使用这种方法, 可以参考原因; - 使用
tf.nn.dynamic_rnn(cell, inputs)
(推荐). 这次的练习将使用这种方式.
因 Tensorflow 版本升级原因, state_is_tuple=True
将在之后的版本中变为默认. 对于 lstm
来说, state
可被分为(c_state, h_state)
.
如果使用tf.nn.dynamic_rnn(cell, inputs)
, 我们要确定 inputs
的格式. tf.nn.dynamic_rnn
中的 time_major
参数会针对不同 inputs
格式有不同的值.
- 如果
inputs
为 (batches, steps, inputs) ==>time_major=False
; - 如果
inputs
为 (steps, batches, inputs) ==>time_major=True
;
最后是 output_layer
和 return
的值. 因为这个例子的特殊性, 有两种方法可以求得 results
.
方式一: 直接调用final_state
中的 h_state
(final_state[1]
) 来进行运算:
方式二: 调用最后一个 outputs
(在这个例子中,和上面的final_state[1]
是一样的):
在 def RNN()
的最后输出 result
定义好了 RNN 主体结构后, 我们就可以来计算 cost
和 train_op
:
训练 RNN¶
训练时, 不断输出 accuracy
, 观看结果:
最终 accuracy
的结果如下:
0.1875
0.65625
0.726562
0.757812
0.820312
0.796875
0.859375
0.921875
0.921875
0.898438
0.828125
0.890625
0.9375
0.921875
0.9375
0.929688
0.953125
....