CNN 卷积神经网络
学习资料:
要点¶
卷积神经网络目前被广泛地用在图片识别上, 已经有层出不穷的应用, 如果你对卷积神经网络还没有特别了解, 我制作的 卷积神经网络 动画简介 能让你花几分钟就了解什么是卷积神经网络. 接着我们就一步一步做一个分析手写数字的 CNN 吧.
下面是一个 CNN 最后一层的学习过程, 我们先可视化看看:
MNIST手写数据¶
黑色的地方的值都是0, 白色的地方值大于0.
同样, 我们除了训练数据, 还给一些测试数据, 测试看看它有没有训练好.
CNN模型¶
和以前一样, 我们用一个 class 来建立 CNN 模型. 这个 CNN 整体流程是 卷积(Conv2d
) -> 激励函数(ReLU
) -> 池化, 向下采样 (MaxPooling
) -> 再来一遍 -> 展平多维的卷积成的特征图 -> 接入全连接层 (Linear
) -> 输出
训练¶
下面我们开始训练, 将 x
y
都用 Variable
包起来, 然后放入 cnn
中计算 output
, 最后再计算误差. 下面代码省略了计算精确度 accuracy
的部分, 如果想细看 accuracy
代码的同学, 请去往我的 github 看全部代码.
最后我们再来取10个数据, 看看预测的值到底对不对:
可视化训练(视频中没有)¶
这是做完视频后突然想要补充的内容, 因为可视化可以帮助理解, 所以还是有必要提一下. 可视化的代码主要是用 matplotlib
和 sklearn
来完成的, 因为其中我们用到了 T-SNE
的降维手段, 将高维的 CNN 最后一层输出结果可视化, 也就是 CNN forward 代码中的 x = x.view(x.size(0), -1)
这一个结果.
可视化的代码不是重点, 我们就直接展示可视化的结果吧.
所以这也就是在我 github 代码 中的每一步的意义啦.