假如想要在ARM板上用
根据前面所说,
tensorflow lite
,那么意味着必需要把PC上的模型生成
tflite
文件,而后在ARM上导入这个
tflite
文件,经过解析这个文件来进行计算。
根据前面所说,
tensorflow
的全部计算都会在内部生成一个图,包括变量的初始化,输入定义等,那么即使不是通过训练的神经网络模型,只是简单的三角函数计算,也能够生成一个
tflite
模型用于在
tensorflow lite
上导入。因此,这里我就只作了简单的
sin()
计算来跑一编这个流程。
生成tflite
模型
这部分主要是调用TFLiteConverter
函数,直接生成tflite
文件,再也不经过pb
文件转化。
先上代码:python
import numpy as np import time import math import tensorflow as tf SIZE = 1000 X = np.random.rand(SIZE, 1) X = X*(math.pi/2.0) start = time.time() x1 = tf.placeholder(tf.float32, [SIZE, 1], name='x1-input') x2 = tf.placeholder(tf.float32, [SIZE, 1], name='x2-input') y1 = tf.sin(x1) y2 = tf.sin(x2) y = y1*y2 with tf.Session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) converter = tf.lite.TFLiteConverter.from_session(sess, [x1, x2], [y]) tflite_model = converter.convert() open("/home/alcht0/share/project/tensorflow-v1.12.0/converted_model.tflite", "wb").write(tflite_model) end = time.time() print("2nd ", str(end - start))
转化函数
主要遇到的问题是
主要遇到的问题是
tensorflow
的变化实在太快,这些个转化函数一直在变。位置也一直在变,如今参考
官方文档,是按上面代码中调用,不然就会报找不到
lite
之类的错误。我如今PC上的
tensorflow
Python
版本是1.13,因此
lite
已经在
contrib
外面了,若是是之前的版本,要按文档中下面这样调用。
TensorFlow Version | Python API |
1.12 | tf.contrib.lite.TFLiteConverter |
1.9-1.11 | tf.contrib.lite.TocoConverter |
1.7-1.8 | tf.contrib.lite.toco_convert |
输入参数shape
git
原本在本文件中为了给定的输入数据大小自由,x1
,x2
的shape
会写成[None, 1]
,可是若是这样写,转化成tflite
模型后会默认为[1,1]
,并不能自由接收数据大小,因此在这里要指定大小SIZE
:github
x1 = tf.placeholder(tf.float32, [SIZE, 1], name='x1-input')
导入tflite
模型
原本这部分应该是在ARM板子上作的,可是为了验证tflite
文件的可用性,我先在PC的Python
上试验。先上代码:api
import tensorflow as tf import numpy as np import math import time SIZE = 1000 X = np.random.rand(SIZE, 1, ).astype(np.float32) X = X*(math.pi/2.0) start = time.time() interpreter = tf.lite.Interpreter(model_path="/home/alcht0/share/project/tensorflow-v1.12.0/converted_model.tflite") interpreter.allocate_tensors() input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() interpreter.set_tensor(input_details[0]['index'], X) interpreter.set_tensor(input_details[1]['index'], X) interpreter.invoke() output_data = interpreter.get_tensor(output_details[0]['index']) end = time.time() print("1st ", str(end - start))
首先根据
用
输入参数类型
tflite
文件生成解析器,而后用
allocate_tensors()
分配内存。将输入经过
set_tensor
传入,而后调用
invoke()
来真正运行。最后获得输出。
用
Python
跑的时候能够很清楚的看到
input_details
的数据结构。官方的例子是只传入一个数据,因此只须要取
input_details[0]
,而我传入了2个输入,因此须要设置2个。同时能够看到
input_details
的2个数据的名字都是我在以前设置的
x1-input
和
x2-input
,这样很是好理解。
这里有个坑是输入参数的类型必定要注意。我在生成模型的时候定义的输入参数类型是
tf.float32
,而在导入的时候若是直接是
X = np.random.rand(SIZE, 1, )
的话,会报错:
ValueError: Cannot set tensor: Got tensor of type 0 but expected type 1 for input 3
这里把经过astype(np.float32)
把输入参数指定为float32
就OK了。网络
- 操做不支持的坑
能够从前面的代码里看到我写了两个sin()
,其实一开始是一个sin()
一个cos()
的,可是好像默认的tflite
模型不支持cos()
操做,没法生成,因此我只好暂时先只写sin()
,后面再研究怎么把cos()
加上。