Batch Normalization 批标准化 - Tensorflow | 莫烦Python
切换视频源:

Batch Normalization 批标准化

作者: 莫烦 编辑: 莫烦 2016-12-05

学习资料:

什么是 Batch Normalization

请参考我制作的 Batch normalization 简介视频 Batch normalization 是一种解决深度神经网络层数太多, 而没办法有效前向传递(forward propagate)的问题. 因为每一层的输出值都会有不同的 均值(mean) 和 方差(deviation), 所以输出数据的分布也不一样, 如下图, 从左到右是每一层的输入数据分布, 上排的没有 Batch normalization, 下排的有 Batch normalization.

5_13_01.png

我们以前说过, 为了更有效的学习数据, 我们会对数据预处理, 进行 normalization (请参考我制作的 为什么要特征标准化). 而现在请想象, 我们可以把 每层输出的值 都看成 后面一层所接收的数据. 对每层都进行一次 normalization 会不会更好呢? 这就是 Batch normalization 方法的由来.

搭建网络

输入需要的模块和定义网络的结构

使用 build_net() 功能搭建神经网络:

创建数据

创造数据并可视化数据:

5_13_02.png

Batch Normalization 代码

为了实现 Batch Normalization, 我们要对每一层的代码进行修改, 给 built_netadd_layer 都加上 norm 参数, 表示是否是 Batch Normalization 层:

然后每层的 Wx_plus_b 需要进行一次 batch normalize 的步骤, 这样输出到 activationWx_plus_b 就已经被 normalize 过了:

如果你是使用 batch 进行每次的更新, 那每个 batch 的 mean/var 都会不同, 所以我们可以使用 moving average 的方法记录并慢慢改进 mean/var 的值. 然后将修改提升后的 mean/var 放入 tf.nn.batch_normalization(). 而且在 test 阶段, 我们就可以直接调用最后一次修改的 mean/var 值进行测试, 而不是采用 test 时的 fcmean/fcvar.

那如何确定我们是在 train 阶段还是在 test 阶段呢, 我们可以修改上面的算法, 想办法传入 on_train 参数, 你也可以把 on_train 定义成全局变量. (注意: github 的代码中没有这一段, 想做 test 的同学们需要自己修改)

同样, 我们也可以在输入数据 xs 时, 给它做一个 normalization, 同样, 如果是最 batch data 来训练的话, 要重复上述的记录修改 mean/var 的步骤:

然后我们把在建立网络的循环中的这一步加入 norm 这个参数:

对比有无 BN

搭建两个神经网络, 一个没有 BN, 一个有 BN:

训练神经网络:

代码中的 plot_his() 不会在这里讲解, 请自己在全套代码中查看.

5_13_03.gif

可以看出, 没有用 BN 的时候, 每层的值迅速全部都变为 0, 也可以说, 所有的神经元都已经死了. 而有 BN, relu 过后, 每层的值都能有一个比较好的分布效果, 大部分神经元都还活着. (看不懂了? 没问题, 再去看一遍我制作的 Batch normalization 简介视频).

Relu 激励函数的图在这里:

5_13_04.png

我们也看看使用 relu cost 的对比:

5_13_05.png

因为没有使用 NB 的网络, 大部分神经元都死了, 所以连误差曲线都没了.

如果使用不同的 ACTIVATION 会怎么样呢? 不如把 relu 换成 tanh:

5_13_06.gif

可以看出, 没有 NB, 每层的值迅速全部都饱和, 都跑去了 -1/1 这个饱和区间, 有 NB, 即使前一层因变得相对饱和, 但是后面几层的值都被 normalize 到有效的不饱和区间内计算. 确保了一个活的神经网络.

tanh 激励函数的图在这里:

5_13_07.gif

最后我们看一下使用 tanh 的误差对比:

5_13_08.png


降低知识传递的门槛

莫烦很常从互联网上学习知识,开源分享的人是我学习的榜样。 他们的行为也改变了我对教育的态度: 降低知识传递的门槛免费 奉献我的所学正是受这种态度的影响。 通过 【赞助莫烦】 能让我感到认同,我也更有理由坚持下去。

    Tensorflow