【Tensorflow1.0】训练结果的保存与加载

训练完成之后咱们就能够直接使用训练好的模板进行预测了python

可是每次在预测以前都要进行训练,不是一个常规操做,毕竟有些复杂的模型须要训练好几天甚至更久git

因此将训练好的模型进行保存,当有须要的时候从新加载这个模型进行预测或者继续训练,这才是一个常规操做github

咱们依然使用最简单的例子进行说明,这里沿用Tensorflow入门——实现最简单的线性回归模型的预测 这个例子进行session

====================================================dom

模型的保存优化

在tensorflow中保存模型使用的是tf.train.Saver对象,咱们须要在保存以前先实例化这个对象this

saver = tf.train.Saver()

对于模型的保存,其实就是保存整个session对象,再给定一个path就实现了模型的保存(对应的path须要存在,若是不存在会报错)spa

saver.save(sess, SAVE_PATH + 'model')

保存完成之后,能够看到对应的目录下面生成了4个文件.net

model.meta中保存的是模型,而这个模型仅仅是计算流和参数的定义,能够认为是一个未经训练的模型rest

model.index和model.data-00000-of-00001中保存的是参数值,也就是真正训练的结果

checkpoint中保存的是最后几回保存的信息,从文件名就能够看出它是一个检查点,记录了其余几个文件之间的关系,这是一个txt文件,咱们能够打开看一下(在这个例子中咱们只保存了一次,若是保存屡次的话这个文件中会记录屡次保存结果的信息)

下面是运行的log

epoch= 0 _loss= 6029.333 _w= [0.005] _n= [0.005]
epoch= 5000 _loss= 10.897877 _w= [4.2031364] _n= [-1.905781]
epoch= 10000 _loss= 112.455055 _w= [4.7837024] _n= [-11.81817]
epoch= 15000 _loss= 6.2376847 _w= [5.1548934] _n= [-19.740992]
epoch= 20000 _loss= 2.9357195 _w= [5.2787647] _n= [-22.662355]
epoch= 25000 _loss= 0.022824269 _w= [5.3112087] _n= [-23.141117]
epoch= 30000 _loss= 1.3711997 _w= [5.326612] _n= [-23.255548]
epoch= 35000 _loss= 0.005477888 _w= [5.3088646] _n= [-23.289743]
epoch= 40000 _loss= 2.8727396 _w= [5.315157] _n= [-23.191956]
epoch= 45000 _loss= 0.009563584 _w= [5.300157] _n= [-23.18857]
训练完成,开始预测。。。
x= 0.1610020536371326 y预测= [-22.44688] y实际= -22.401859054114084
x= 7.379937860774309 y预测= [16.030691] y实际= 16.075068797927063
x= 5.1744928042152685 y预测= [4.2754745] y实际= 4.320046646467379
x= 10.26990231423617 y预测= [31.434462] y实际= 31.478579334878784
x= 23.219346463697207 y预测= [100.45616] y实际= 100.49911665150611
x= 7.101197776563807 y预测= [14.544985] y实际= 14.589384149085088
x= 3.097841295090581 y预测= [-6.7932644] y实际= -6.7485058971672025
x= 6.474682013005717 y预测= [11.205599] y实际= 11.250055129320469
x= 13.811264369891983 y预测= [50.310234] y实际= 50.35403909152427
x= 29.260954830177415 y预测= [132.65846] y实际= 132.70088924484563

====================================================

模型的加载

由于保存时分红了模型和参数值两部分进行保存,因此在加载模型的时候也须要将模型和参数值(训练结果)两步分开进行加载

上面讲到了meta文件是模型,checkpoint是参数值,这里分别使用tf.train下的import_meta_graph和latest_checkpoint方法来加载

saver = tf.train.import_meta_graph(SAVE_PATH + 'model.meta')
saver.restore(sess, tf.train.latest_checkpoint(SAVE_PATH))

这样,以前保存起来的模型就被咱们从新加载成功了,可是在预测或者继续训练以前,咱们须要从新定义相关的变量

可是也不是凭空的从新定义,由于这些参数已经在以前保存的模型中定义过了,咱们只须要从已经加载的模型中将相关参数的定义给找出来就能够了

为了找回参数的定义,咱们须要稍微修改一下模型,将这些须要在从新加载阶段找回的参数定义给上命名(若是是用来预测,咱们须要找回X和OUT,若是是用来继续训练,咱们须要找回X、OUT、loss),因此这里咱们将模型中相关的参数都给上命名

X = tf.placeholder(tf.float32, name='X')
Y = tf.placeholder(tf.float32, name='Y')

W = tf.Variable(tf.zeros([1]), name='W')
B = tf.Variable(tf.zeros([1]), name='B')
OUT = tf.add(tf.multiply(X, W), B, name='OUT')

loss = tf.reduce_mean(tf.square(Y - OUT), name='loss')
optimizer = tf.train.AdamOptimizer(0.005).minimize(loss)

在找回参数以前,须要获取计算图对象(关于计算图的概念,如今能够没必要先了解)

graph = tf.get_default_graph()

而后经过get_all_collection_keys,来查看这个模型中的内容

print(graph.get_all_collection_keys())

能够看到一共有三项,分别是train_op:优化器,trainable_variables:可训练的变量,variables:全部变量

['train_op', 'trainable_variables', 'variables']

咱们再经过get_collection方法把这些对象也打印出来看一下

print(graph.get_collection('train_op'))
print(graph.get_collection('trainable_variables'))
print(graph.get_collection('variables'))

可是从中发现,咱们须要找回的参数都不在这里

[<tf.Operation 'Adam' type=NoOp>]
[<tf.Variable 'W:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'B:0' shape=(1,) dtype=float32_ref>]
[<tf.Variable 'W:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'B:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'beta1_power:0' shape=() dtype=float32_ref>, <tf.Variable 'beta2_power:0' shape=() dtype=float32_ref>, <tf.Variable 'W/Adam:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'W/Adam_1:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'B/Adam:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'B/Adam_1:0' shape=(1,) dtype=float32_ref>]

继续经过get_operations方法来查看全部的操做数

print(graph.get_operations())

从如下内容中咱们发现了须要找回的参数X、Y、OUT等

[<tf.Operation 'X' type=Placeholder>, <tf.Operation 'Y' type=Placeholder>, <tf.Operation 'zeros' type=Const>, <tf.Operation 'W' type=VariableV2>, <tf.Operation 'W/Assign' type=Assign>, <tf.Operation 'W/read' type=Identity>, <tf.Operation 'zeros_1' type=Const>, <tf.Operation 'B' type=VariableV2>, <tf.Operation 'B/Assign' type=Assign>, <tf.Operation 'B/read' type=Identity>, <tf.Operation 'Mul' type=Mul>, <tf.Operation 'OUT' type=Add>, <tf.Operation 'sub' type=Sub>, <tf.Operation 'Square' type=Square>, <tf.Operation 'Rank' type=Rank>, <tf.Operation 'range/start' type=Const>, <tf.Operation 'range/delta' type=Const>, <tf.Operation 'range' type=Range>, <tf.Operation 'loss' type=Mean>, <tf.Operation 'gradients/Shape' type=Const>, <tf.Operation 'gradients/grad_ys_0' type=Const>, <tf.Operation 'gradients/Fill' type=Fill>, <tf.Operation 'gradients/loss_grad/Shape' type=Shape>, <tf.Operation 'gradients/loss_grad/Size' type=Size>, <tf.Operation 'gradients/loss_grad/add' type=Add>, <tf.Operation 'gradients/loss_grad/mod' type=FloorMod>, <tf.Operation 'gradients/loss_grad/Shape_1' type=Shape>, <tf.Operation 'gradients/loss_grad/range/start' type=Const>, <tf.Operation 'gradients/loss_grad/range/delta' type=Const>, <tf.Operation 'gradients/loss_grad/range' type=Range>, <tf.Operation 'gradients/loss_grad/Fill/value' type=Const>, <tf.Operation 'gradients/loss_grad/Fill' type=Fill>, <tf.Operation 'gradients/loss_grad/DynamicStitch' type=DynamicStitch>, <tf.Operation 'gradients/loss_grad/Maximum/y' type=Const>, <tf.Operation 'gradients/loss_grad/Maximum' type=Maximum>, <tf.Operation 'gradients/loss_grad/floordiv' type=FloorDiv>, <tf.Operation 'gradients/loss_grad/Reshape' type=Reshape>, <tf.Operation 'gradients/loss_grad/Tile' type=Tile>, <tf.Operation 'gradients/loss_grad/Shape_2' type=Shape>, <tf.Operation 'gradients/loss_grad/Shape_3' type=Const>, <tf.Operation 'gradients/loss_grad/Const' type=Const>, <tf.Operation 'gradients/loss_grad/Prod' type=Prod>, <tf.Operation 'gradients/loss_grad/Const_1' type=Const>, <tf.Operation 'gradients/loss_grad/Prod_1' type=Prod>, <tf.Operation 'gradients/loss_grad/Maximum_1/y' type=Const>, <tf.Operation 'gradients/loss_grad/Maximum_1' type=Maximum>, <tf.Operation 'gradients/loss_grad/floordiv_1' type=FloorDiv>, <tf.Operation 'gradients/loss_grad/Cast' type=Cast>, <tf.Operation 'gradients/loss_grad/truediv' type=RealDiv>, <tf.Operation 'gradients/Square_grad/Const' type=Const>, <tf.Operation 'gradients/Square_grad/Mul' type=Mul>, <tf.Operation 'gradients/Square_grad/Mul_1' type=Mul>, <tf.Operation 'gradients/sub_grad/Shape' type=Shape>, <tf.Operation 'gradients/sub_grad/Shape_1' type=Shape>, <tf.Operation 'gradients/sub_grad/BroadcastGradientArgs' type=BroadcastGradientArgs>, <tf.Operation 'gradients/sub_grad/Sum' type=Sum>, <tf.Operation 'gradients/sub_grad/Reshape' type=Reshape>, <tf.Operation 'gradients/sub_grad/Sum_1' type=Sum>, <tf.Operation 'gradients/sub_grad/Neg' type=Neg>, <tf.Operation 'gradients/sub_grad/Reshape_1' type=Reshape>, <tf.Operation 'gradients/sub_grad/tuple/group_deps' type=NoOp>, <tf.Operation 'gradients/sub_grad/tuple/control_dependency' type=Identity>, <tf.Operation 'gradients/sub_grad/tuple/control_dependency_1' type=Identity>, <tf.Operation 'gradients/OUT_grad/Shape' type=Shape>, <tf.Operation 'gradients/OUT_grad/Shape_1' type=Const>, <tf.Operation 'gradients/OUT_grad/BroadcastGradientArgs' type=BroadcastGradientArgs>, <tf.Operation 'gradients/OUT_grad/Sum' type=Sum>, <tf.Operation 'gradients/OUT_grad/Reshape' type=Reshape>, <tf.Operation 'gradients/OUT_grad/Sum_1' type=Sum>, <tf.Operation 'gradients/OUT_grad/Reshape_1' type=Reshape>, <tf.Operation 'gradients/OUT_grad/tuple/group_deps' type=NoOp>, <tf.Operation 'gradients/OUT_grad/tuple/control_dependency' type=Identity>, <tf.Operation 'gradients/OUT_grad/tuple/control_dependency_1' type=Identity>, <tf.Operation 'gradients/Mul_grad/Shape' type=Shape>, <tf.Operation 'gradients/Mul_grad/Shape_1' type=Const>, <tf.Operation 'gradients/Mul_grad/BroadcastGradientArgs' type=BroadcastGradientArgs>, <tf.Operation 'gradients/Mul_grad/Mul' type=Mul>, <tf.Operation 'gradients/Mul_grad/Sum' type=Sum>, <tf.Operation 'gradients/Mul_grad/Reshape' type=Reshape>, <tf.Operation 'gradients/Mul_grad/Mul_1' type=Mul>, <tf.Operation 'gradients/Mul_grad/Sum_1' type=Sum>, <tf.Operation 'gradients/Mul_grad/Reshape_1' type=Reshape>, <tf.Operation 'gradients/Mul_grad/tuple/group_deps' type=NoOp>, <tf.Operation 'gradients/Mul_grad/tuple/control_dependency' type=Identity>, <tf.Operation 'gradients/Mul_grad/tuple/control_dependency_1' type=Identity>, <tf.Operation 'beta1_power/initial_value' type=Const>, <tf.Operation 'beta1_power' type=VariableV2>, <tf.Operation 'beta1_power/Assign' type=Assign>, <tf.Operation 'beta1_power/read' type=Identity>, <tf.Operation 'beta2_power/initial_value' type=Const>, <tf.Operation 'beta2_power' type=VariableV2>, <tf.Operation 'beta2_power/Assign' type=Assign>, <tf.Operation 'beta2_power/read' type=Identity>, <tf.Operation 'W/Adam/Initializer/zeros' type=Const>, <tf.Operation 'W/Adam' type=VariableV2>, <tf.Operation 'W/Adam/Assign' type=Assign>, <tf.Operation 'W/Adam/read' type=Identity>, <tf.Operation 'W/Adam_1/Initializer/zeros' type=Const>, <tf.Operation 'W/Adam_1' type=VariableV2>, <tf.Operation 'W/Adam_1/Assign' type=Assign>, <tf.Operation 'W/Adam_1/read' type=Identity>, <tf.Operation 'B/Adam/Initializer/zeros' type=Const>, <tf.Operation 'B/Adam' type=VariableV2>, <tf.Operation 'B/Adam/Assign' type=Assign>, <tf.Operation 'B/Adam/read' type=Identity>, <tf.Operation 'B/Adam_1/Initializer/zeros' type=Const>, <tf.Operation 'B/Adam_1' type=VariableV2>, <tf.Operation 'B/Adam_1/Assign' type=Assign>, <tf.Operation 'B/Adam_1/read' type=Identity>, <tf.Operation 'Adam/learning_rate' type=Const>, <tf.Operation 'Adam/beta1' type=Const>, <tf.Operation 'Adam/beta2' type=Const>, <tf.Operation 'Adam/epsilon' type=Const>, <tf.Operation 'Adam/update_W/ApplyAdam' type=ApplyAdam>, <tf.Operation 'Adam/update_B/ApplyAdam' type=ApplyAdam>, <tf.Operation 'Adam/mul' type=Mul>, <tf.Operation 'Adam/Assign' type=Assign>, <tf.Operation 'Adam/mul_1' type=Mul>, <tf.Operation 'Adam/Assign_1' type=Assign>, <tf.Operation 'Adam' type=NoOp>, <tf.Operation 'init' type=NoOp>, <tf.Operation 'save/filename/input' type=Const>, <tf.Operation 'save/filename' type=PlaceholderWithDefault>, <tf.Operation 'save/Const' type=PlaceholderWithDefault>, <tf.Operation 'save/SaveV2/tensor_names' type=Const>, <tf.Operation 'save/SaveV2/shape_and_slices' type=Const>, <tf.Operation 'save/SaveV2' type=SaveV2>, <tf.Operation 'save/control_dependency' type=Identity>, <tf.Operation 'save/RestoreV2/tensor_names' type=Const>, <tf.Operation 'save/RestoreV2/shape_and_slices' type=Const>, <tf.Operation 'save/RestoreV2' type=RestoreV2>, <tf.Operation 'save/Assign' type=Assign>, <tf.Operation 'save/Assign_1' type=Assign>, <tf.Operation 'save/Assign_2' type=Assign>, <tf.Operation 'save/Assign_3' type=Assign>, <tf.Operation 'save/Assign_4' type=Assign>, <tf.Operation 'save/Assign_5' type=Assign>, <tf.Operation 'save/Assign_6' type=Assign>, <tf.Operation 'save/Assign_7' type=Assign>, <tf.Operation 'save/restore_all' type=NoOp>]

恢复参数:

这里须要注意的是在后面须要加上“:0”,表明第0个参数(这个涉及到另外一个概念,之后再细讲)

X = graph.get_tensor_by_name('X:0')
Y = graph.get_tensor_by_name('Y:0')
W = graph.get_tensor_by_name('W:0')
B = graph.get_tensor_by_name('B:0')
OUT = graph.get_tensor_by_name('OUT:0')
loss = graph.get_tensor_by_name('loss:0')

恢复优化器:

optimizer = graph.get_collection('train_op')

仍然将以前代码中的预测和训练相关的逻辑拷过来执行一下

Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from D:/test/tf1/xw_b/model
从新加载,开始预测。。。
x= 26.764991404677083 y预测= [[119.67893]] y实际= 119.39740418692885
x= 25.85141169466281 y预测= [[114.797356]] y实际= 114.52802433255279
x= 17.046457082367727 y预测= [[67.749466]] y实际= 67.59761624901998
x= 5.918111849660451 y预测= [[8.286896]] y实际= 8.283536158690204
x= 7.409698341670607 y预测= [[16.256956]] y实际= 16.233692161104333
x= 15.469762867798304 y预测= [[59.324646]] y实际= 59.19383608536495
x= 11.519144276233455 y预测= [[38.215134]] y实际= 38.13703899232431
x= 27.85137286496477 y预测= [[125.48383]] y实际= 125.18781737026221
x= 26.50150532742774 y预测= [[118.271034]] y实际= 117.99302339518984
x= 15.664275922154658 y预测= [[60.364]] y实际= 60.23059066508432
继续训练
epoch= 0 _loss= 16.00476 _w= [5.3422985] _n= [-23.3365]
epoch= 5000 _loss= 19.420956 _w= [5.3203373] _n= [-23.186474]
epoch= 10000 _loss= 0.30325127 _w= [5.3471537] _n= [-23.290209]
epoch= 15000 _loss= 3.018042 _w= [5.32293] _n= [-23.245607]
epoch= 20000 _loss= 12.473472 _w= [5.309146] _n= [-23.24814]
epoch= 25000 _loss= 17.09799 _w= [5.3170156] _n= [-23.342768]
epoch= 30000 _loss= 18.25596 _w= [5.3193855] _n= [-23.225794]
epoch= 35000 _loss= 0.32235628 _w= [5.339825] _n= [-23.196495]
epoch= 40000 _loss= 2.6598516 _w= [5.304051] _n= [-23.248428]
epoch= 45000 _loss= 6.564373 _w= [5.328891] _n= [-23.212101]
继续训练完成,开始预测。。。
x= 24.14983880390778 y预测= [[105.329315]] y实际= 105.45864082482846
x= 8.654129156050717 y预测= [[22.795414]] y实际= 22.86650840175032
x= 17.410606725772045 y预测= [[69.434525]] y实际= 69.53853384836499
x= 17.55599000188004 y预测= [[70.20888]] y实际= 70.31342671002061
x= 24.43148021367975 y预测= [[106.82939]] y实际= 106.95978953891309
x= 20.286380740475614 y预测= [[84.751595]] y实际= 84.86640934673503
x= 2.8131286438423353 y预测= [[-8.3151655]] y实际= -8.266024328320354
x= 11.781139561484927 y预测= [[39.450626]] y实际= 39.53347386271466
x= 4.611147529065006 y预测= [[1.2615166]] y实际= 1.3174163299164796
x= 6.625783852577516 y预测= [[11.991955]] y实际= 12.055427934238164

使用恢复之后的模型直接进行预测,匹配程度也很是高,而进行继续训练也没问题

====================================================

完整代码以下,在python3.6.八、tensorflow1.13环境下成功运行

https://github.com/yukiti2007/sample/blob/master/python/tensorflow/wx_b_save.py

import random

import tensorflow as tf

SAVE_PATH = "D:/test/tf1/xw_b/"


def create_data(for_train=False):
    w = 5.33
    b = -23.26
    x = random.random() * 30
    y = w * x + b

    if for_train:
        noise = (random.random() - 0.5) * 10
        y += noise

    return x, y


def train():
    X = tf.placeholder(tf.float32, name='X')
    Y = tf.placeholder(tf.float32, name='Y')

    W = tf.Variable(tf.zeros([1]), name='W')
    B = tf.Variable(tf.zeros([1]), name='B')
    OUT = tf.add(tf.multiply(X, W), B, name='OUT')

    loss = tf.reduce_mean(tf.square(Y - OUT), name='loss')
    optimizer = tf.train.AdamOptimizer(0.005).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(50000):
            x_data, y_data = create_data(True)
            _, _loss, _w, _b = sess.run([optimizer, loss, W, B], feed_dict={X: x_data, Y: y_data})
            if 0 == epoch % 5000:
                print("epoch=", epoch, "_loss=", _loss, "_w=", _w, "_n=", _b)

        print("训练完成,开始预测。。。")
        for step in range(10):
            x_data, y_data = create_data(False)
            prediction_value = sess.run(OUT, feed_dict={X: x_data})
            print("x=", x_data, "y预测=", prediction_value, "y实际=", y_data)

        saver = tf.train.Saver()
        saver.save(sess, SAVE_PATH + 'model')


def predict():
    sess = tf.Session()
    saver = tf.train.import_meta_graph(SAVE_PATH + 'model.meta')
    saver.restore(sess, tf.train.latest_checkpoint(SAVE_PATH))

    graph = tf.get_default_graph()
    X = graph.get_tensor_by_name('X:0')
    Y = graph.get_tensor_by_name('Y:0')
    W = graph.get_tensor_by_name('W:0')
    B = graph.get_tensor_by_name('B:0')
    OUT = graph.get_tensor_by_name('OUT:0')
    loss = graph.get_tensor_by_name('loss:0')
    optimizer = graph.get_collection('train_op')
    # print(graph.get_all_collection_keys())
    # print(graph.get_collection('train_op'))
    # print(graph.get_collection('trainable_variables'))
    # print(graph.get_collection('variables'))
    # print(graph.get_operations())

    print("从新加载,开始预测。。。")
    for step in range(10):
        x_data, y_data = create_data(False)
        prediction_value = sess.run(OUT, feed_dict={X: [[x_data]]})
        print("x=", x_data, "y预测=", prediction_value, "y实际=", y_data)

    print("继续训练")
    for epoch in range(50000):
        x_data, y_data = create_data(True)
        _, _loss, _w, _b = sess.run([optimizer, loss, W, B], feed_dict={X: x_data, Y: y_data})
        if 0 == epoch % 5000:
            print("epoch=", epoch, "_loss=", _loss, "_w=", _w, "_n=", _b)

    print("继续训练完成,开始预测。。。")
    for step in range(10):
        x_data, y_data = create_data(False)
        prediction_value = sess.run(OUT, feed_dict={X: [[x_data]]})
        print("x=", x_data, "y预测=", prediction_value, "y实际=", y_data)


if __name__ == "__main__":
    train()
    predict()
相关文章
相关标签/搜索