RNN Classifier 循环神经网络
作者: 莫烦 发布于: 2016-10-30
编辑: 学习资料:
方法介绍¶
这次我们用循环神经网络(RNN, Recurrent Neural Networks)进行分类(classification),采用MNIST数据集,主要用到SimpleRNN
层。
MNIST里面的图像分辨率是28×28,为了使用RNN,我们将图像理解为序列化数据。 每一行作为一个输入单元,所以输入数据大小INPUT_SIZE = 28
; 先是第1行输入,再是第2行,第3行,第4行,...,第28行输入, 这就是一张图片也就是一个序列,所以步长TIME_STEPS = 28
。
训练数据要进行归一化处理,因为原始数据是8bit灰度图像所以需要除以255。
搭建模型¶
首先添加RNN层,输入为训练数据,输出数据大小由CELL_SIZE
定义。
然后添加输出层,激励函数选择softmax
设置优化方法,loss
函数和metrics
方法之后就可以开始训练了。 每次训练的时候并不是取所有的数据,只是取BATCH_SIZE
个序列,或者称为BATCH_SIZE
张图片,这样可以大大降低运算时间,提高训练效率。
训练¶
输出test
上的loss
和accuracy
结果
有兴趣的话可以修改BATCH_SIZE
和CELL_SIZE
的值,试试这两个参数对训练时间和精度的影响。