切换视频源:

Save 保存 提取

作者: Alice 编辑: 莫烦 2016-11-03

学习资料:

要点

今天学习如何保存神经网络,以方便日后可以直接提取使用。

保存的方式是我们可以先把神经网络的参数,比如说 weights 还有 bias 保存起来,再重新定义神经网络的结构,使用模型的时候需要把参数 set 到结构中去。

保存和提取的方法是利用 shared 变量的 get 功能,拿出变量值保存到文件中去, 下一次再定义 weights 和 bias 的时候,可以直接把保存好的值放到 shared variable 中去。

本文以 Classification 分类学习 那节的代码为例。

导入模块

在引入相关包时,需要用到 pickle, 这是 python 中用来储存文件的一个模块。

import numpy as np
import theano
import theano.tensor as T
import pickle

创建数据-建立模型-激活-训练

接下来的 创建数据-建立模型-激活模型-训练模型 都和分类那节课的内容是一样的。

def compute_accuracy(y_target, y_predict):
    correct_prediction = np.equal(y_predict, y_target)
    accuracy = np.sum(correct_prediction)/len(correct_prediction)
    return accuracy

rng = np.random

# set random seed
np.random.seed(100)

N = 400
feats = 784

# generate a dataset: D = (input_values, target_class)
D = (rng.randn(N, feats), rng.randint(size=N, low=0, high=2))

# Declare Theano symbolic variables
x = T.dmatrix("x")
y = T.dvector("y")

# initialize the weights and biases
w = theano.shared(rng.randn(feats), name="w")
b = theano.shared(0., name="b")

# Construct Theano expression graph
p_1 = 1 / (1 + T.exp(-T.dot(x, w) - b))
prediction = p_1 > 0.5
xent = -y * T.log(p_1) - (1-y) * T.log(1-p_1)
cost = xent.mean() + 0.01 * (w ** 2).sum()
gw, gb = T.grad(cost, [w, b])

# Compile
learning_rate = 0.1
train = theano.function(
          inputs=[x, y],
          updates=((w, w - learning_rate * gw), (b, b - learning_rate * gb)))
predict = theano.function(inputs=[x], outputs=prediction)

# Training
for i in range(500):
    train(D[0], D[1])

今天重点放在保存和提取模型的部分:

保存模型

把所有的参数放入 save 文件夹中,命名文件为 model.pickle,以 wb 的形式打开并把参数写入进去。

定义 model=[] 用来保存 weightsbias,这里用的是 list 结构保存,也可以用字典结构保存,提取值时用 get_value() 命令。

再用 pickle.dumpmodel 保存在 file 中。

可以通过 print(model[0][:10]) 打印出保存的 weights 的前 10 个数,方便后面提取模型时检查是否保存成功。还可以打印 accuracy 看准确率是否一样。

# save model
with open('save/model.pickle', 'wb') as file:
    model = [w.get_value(), b.get_value()]
    pickle.dump(model, file)
    print(model[0][:10])
    print("accuracy:", compute_accuracy(D[1], predict(D[0])))

"""
[-0.15707296  0.14590665 -0.08451091 -0.12594476 -0.13424085 -0.33887753
  0.12650858  0.20702686  0.0549835   0.29920542]
accuracy: 1.0
"""

执行上述代码后可以看到 save 文件夹中生成了一个 model.pickle 的文件。

提取模型

接下来提取模型时,提前把代码中 # Training# save model 两部分注释掉,即相当于只是通过 创建数据-建立模型-激活模型 构建好了新的模型结构,下面要通过调用存好的参数来进行预测。

rb 的形式读取 model.pickle 文件加载到 model 变量中去,

然后用 set_value 命令把 model 的第 0 位存进 w,第 1 位存进 b 中。

同样可以打印出 weights 的前 10 位和 accuracy,来对比之前的结果,可以发现结果完全一样。

# load model
with open('save/model.pickle', 'rb') as file:
    model = pickle.load(file)
    w.set_value(model[0])
    b.set_value(model[1])
    print(w.get_value()[:10])
    print("accuracy:", compute_accuracy(D[1], predict(D[0])))
    
"""
[-0.15707296  0.14590665 -0.08451091 -0.12594476 -0.13424085 -0.33887753
  0.12650858  0.20702686  0.0549835   0.29920542]
accuracy: 1.0
"""

以上就是保存和提取的过程。

降低知识传递的门槛

莫烦很常从互联网上学习知识,开源分享的人是我学习的榜样。 他们的行为也改变了我对教育的态度: 降低知识传递的门槛免费 奉献我的所学正是受这种态度的影响。 通过 【赞助莫烦】 能让我感到认同,我也更有理由坚持下去。

想当算法工程师拿高薪?转行AI无门道?莫烦也想祝你一臂之力,市面上机构繁杂, 经过莫烦的筛选,七月在线脱颖而出, 莫烦和他们合作,独家提供大额 【培训优惠券】, 让你更有机会接触丰富的教学资源、培训辅导体验, 祝你找/换工作/学习顺利~