在《Tensorflow SavedModel模型的保存与加载》一文中,咱们谈到SavedModel格式的优势是与语言无关、容易部署和加载。那问题来了,若是别人发布了一个SavedModel模型,咱们该如何去了解这个模型,如何去加载和使用这个模型呢?python
理想的状态是模型发布者编写出完备的文档,给出示例代码。但在不少状况下,咱们只是获得了训练好的模型,而没有齐全的文档,这个时候咱们可否从模型自己上得到一些信息呢?好比模型的输入输出、模型的结构等等。git
答案是能够的。github
这里的签名,并不是是为了保证模型不被修改的那种电子签名。个人理解是相似于编程语言中模块的输入输出信息,好比函数名,输入参数类型,输出参数类型等等。咱们以《Tensorflow SavedModel模型的保存与加载》里的代码为例,从语句:web
signature = predict_signature_def(inputs={'myInput': x},
outputs={'myOutput': y})
复制代码
咱们能够看到模型的输入名为myInput,输出名为myOutput。若是咱们没有源码呢?编程
Tensorflow提供了一个工具,若是你下载了Tensorflow的源码,能够找到这样一个文件,./tensorflow/python/tools/saved_model_cli.py,你能够加上-h参数查看该脚本的帮助信息:浏览器
usage: saved_model_cli.py [-h] [-v] {show,run,scan} ...
saved_model_cli: Command-line interface for SavedModel
optional arguments:
-h, --help show this help message and exit
-v, --version show program's version number and exit commands: valid commands {show,run,scan} additional help 复制代码
指定SavedModel模所在的位置,咱们就能够显示SavedModel的模型信息:bash
python $TENSORFLOW_DIR/tensorflow/python/tools/saved_model_cli.py show --dir ./model/ --all
复制代码
结果为:编程语言
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['predict']:
The given SavedModel SignatureDef contains the following input(s):
inputs['myInput'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 784)
name: myInput:0
The given SavedModel SignatureDef contains the following output(s):
outputs['myOutput'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 10)
name: Softmax:0
Method name is: tensorflow/serving/predict
复制代码
从这里咱们能够清楚的看到模型的输入/输出的名称、数据类型、shape以及方法名称。有了这些信息,咱们就能够很容易写出推断方法。函数
了解tensflow的人可能知道TensorBoard是一个很是强大的工具,可以显示不少模型信息,其中包括计算图。问题是,TensorBoard须要模型训练时的log,若是这个SavedModel模型是别人训练好的呢?办法也不是没有,咱们能够写一段代码,加载这个模型,而后输出summary info,代码以下:工具
import tensorflow as tf
import sys
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat
with tf.Session() as sess:
model_filename ='./model/saved_model.pb'
with gfile.FastGFile(model_filename, 'rb') as f:
data = compat.as_bytes(f.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
if 1 != len(sm.meta_graphs):
print('More than one graph found. Not sure which to write')
sys.exit(1)
g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
LOGDIR='./logdir'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
train_writer.flush()
train_writer.close()
复制代码
代码中,将汇总信息输出到logdir,接着启动TensorBoard,加上上面的logdir:
tensorboard --logdir ./logdir
复制代码
在浏览器中输入地址: http://127.0.0.1:6006/ ,就能够看到以下的计算图:
按照前面两种方法,咱们能够对Tensorflow SavedModel格式的模型有比较全面的了解,即便模型训练者并无给出文档。有了这些模型信息,相信你写出使用模型进行推断更加容易。
本文的完整代码请参考:github.com/mogoweb/aie…
但愿这篇文章对您有帮助,感谢阅读!