这是当微信小程序赶上TensorFlow系列文章的第四篇文章,阅读本文,你将了解到:node
若是你想要了解更多关于本项目,能够参考这个系列的前三篇文章:python
关于Tensorflow SavedModel格式模型的处理,能够参考前面的文章:git
截至到目前为止,咱们实现了一个简单的微信小程序,使用开源的Simple TensorFlow Serving部署了服务端。但这种实现方案还存在一个重大问题:小程序和服务端通讯传递的图像数据是(299, 299, 3)二进制数组的JSON化表示,这种二进制数据JSON化的最大缺点是数据量太大,一个简单的299 x 299的图像,这样表示大约有3 ~ 4 M。其实HTTP传输二进制数据经常使用的方案是对二进制数据进行base64编码,通过base64编码,虽然数据量比二进制也会大一些,但相比JSON化的表示,仍是小不少。github
因此如今的问题是,如何让服务器端接收base64编码的图像数据?web
为了解决这一问题,咱们仍是先看看模型的输入输出,看看其签名是怎样的?这里的签名,并不是是为了保证模型不被修改的那种电子签名。个人理解是相似于编程语言中模块的输入输出信息,好比函数名,输入参数类型,输出参数类型等等。借助于Tensorflow提供的saved_model_cli.py工具,咱们能够清楚的查看模型的签名:编程
python ./tensorflow/python/tools/saved_model_cli.py show --dir /data/ai/workspace/aiexamples/AIDog/serving/models/inception_v3/ --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['image'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 299, 299, 3)
name: Placeholder:0
The given SavedModel SignatureDef contains the following output(s):
outputs['prediction'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 120)
name: final_result:0
Method name is: tensorflow/serving/predict
复制代码
从中咱们能够看出模型的输入参数名为image,其shape为(-1, 299, 299, 3),这里-1表明能够批量输入,一般咱们只输入一张图像,因此这个维度一般是1。输出参数名为prediction,其shape为(-1, 120),-1和输入是对应的,120表明120组狗类别的几率。json
如今的问题是,咱们可否在模型的输入前面增长一层,进行base64及解码处理呢?小程序
也许你认为能够在服务器端编写一段代码,进行base64字符串解码,而后再转交给Simple Tensorflow Serving进行处理,或者修改Simple TensorFlow Serving的处理逻辑,但这种修改方案增长了服务器端的工做量,使得服务器部署方案再也不通用,放弃!微信小程序
其实在上一篇文章《如何合并两个TensorFlow模型》中咱们已经讲到了如何链接两个模型,这里再稍微重复一下,首先是编写一个base64解码、png解码、图像缩放的模型:api
base64_str = tf.placeholder(tf.string, name='input_string')
input_str = tf.decode_base64(base64_str)
decoded_image = tf.image.decode_png(input_str, channels=input_depth)
# Convert from full range of uint8 to range [0,1] of float32.
decoded_image_as_float = tf.image.convert_image_dtype(decoded_image,
tf.float32)
decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
resize_shape = tf.stack([input_height, input_width])
resize_shape_as_int = tf.cast(resize_shape, dtype=tf.int32)
resized_image = tf.image.resize_bilinear(decoded_image_4d,
resize_shape_as_int)
tf.identity(resized_image, name="DecodePNGOutput")
复制代码
接下来加载retrain模型:
with tf.Graph().as_default() as g2:
with tf.Session(graph=g2) as sess:
input_graph_def = saved_model_utils.get_meta_graph_def(
FLAGS.origin_model_dir, tag_constants.SERVING).graph_def
tf.saved_model.loader.load(sess, [tag_constants.SERVING], FLAGS.origin_model_dir)
g2def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
["final_result"],
variable_names_whitelist=None,
variable_names_blacklist=None)
复制代码
这里调用了graph_util.convert_variables_to_constants将模型中的变量转化为常量,也就是所谓的冻结图(freeze graph)操做。
利用tf.import_graph_def方法,咱们能够导入图到现有图中,注意第二个import_graph_def,其input是第一个graph_def的输出,经过这样的操做,就将两个计算图链接起来,最后保存起来。代码以下:
with tf.Graph().as_default() as g_combined:
with tf.Session(graph=g_combined) as sess:
x = tf.placeholder(tf.string, name="base64_string")
y, = tf.import_graph_def(g1def, input_map={"input_string:0": x}, return_elements=["DecodePNGOutput:0"])
z, = tf.import_graph_def(g2def, input_map={"Placeholder:0": y}, return_elements=["final_result:0"])
tf.identity(z, "myOutput")
tf.saved_model.simple_save(sess,
FLAGS.model_dir,
inputs={"image": x},
outputs={"prediction": z})
复制代码
若是你不知道retrain出来的模型的input节点是啥(注意不能使用模型部署的signature信息)?可使用以下代码遍历graph的节点名称:
for n in g2def.node:
print(n.name)
复制代码
注意,咱们能够将链接以后的模型保存在./models/inception_v3/2/目录下,原来的./models/inception_v3/1/也不用删除,这样两个版本的模型能够同时提供服务,方便从V1模型平滑过渡到V2版本模型。
咱们修改一下原来的test_client.py代码,增长一个model_version参数,这样就能够决定与哪一个版本的模型进行通讯:
with open(file_name, "rb") as image_file:
encoded_string = str(base64.urlsafe_b64encode(image_file.read()), "utf-8")
if enable_ssl :
endpoint = "https://127.0.0.1:8500"
else:
endpoint = "http://127.0.0.1:8500"
json_data = {"model_name": model_name,
"model_version": model_version,
"data": {"image": encoded_string}
}
result = requests.post(endpoint, json=json_data)
复制代码
通过一个多星期的研究和反复尝试,终于解决了图像数据的base64编码通讯问题。难点在于虽然模型是编写retrain脚本从新训练的,但这段代码不是那么好懂,想要在retrain时增长输入层也是尝试失败。最后从Tensorflow模型转Tensorflow Lite模型时的freezing graph获得灵感,将图中的变量固化为常量,才解决了合并模型变量加载的问题。虽然网上提供了一些恢复变量的方法,但实际用起来并无论用,多是Tensorflow发展太快,之前的一些方法已通过时了。
本文的完整代码请参阅:github.com/mogoweb/aie…
点击阅读原文能够直达在github上的项目。
到目前为止,关键的问题已经都解决,接下来就须要继续完善微信小程序的展示,以及如何提供识别率,敬请关注个人微信公众号:云水木石,获取最新动态。