[Overview] As TensorFlow becomes more and more popular, a growing number of industries want to integrate the large amount of existing TensorFlow code and models on GitHub into their own business systems, and how to use TensorFlow from common programming languages (Java, NodeJS, etc.) has become a fairly common question. 专知 member Hujun gives a detailed introduction to two ways of using TensorFlow from Java, focusing on how to call existing TensorFlow models through the official TensorFlow Java API.
1. Call a trained pb model directly through the official TensorFlow Java API
2. (Recommended) Use KerasServer to host the TensorFlow/Keras code and model
Although the official TensorFlow Java API can work directly with a trained pb model, in practice a fair amount of tedious cross-language glue code is still required. For example, even when Python text-classification code based on TensorFlow already exists, the TensorFlow Java API needs numericized text as input, so the preprocessing that is already implemented in the Python code, such as tokenization and the string-to-index mapping, has to be re-implemented in Java (and these steps in turn depend on data such as the vocabulary used by the Python code). Moreover, since Java has no numpy, building multi-dimensional arrays as model input still relies on loop-style code, which is quite tedious, as the sketch below illustrates.
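To make the pain concrete, here is a minimal sketch of what that glue code tends to look like in plain Java. The whitespace tokenizer, the vocabulary map, and the fixed sequence length are all assumptions for illustration, not the preprocessing of any particular Python project; the point is simply that the index matrix has to be filled by hand-written loops before it can be wrapped in a Tensor.

```java
import org.tensorflow.Tensor;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class ManualPreprocessing {
    public static void main(String[] args) {
        // Hypothetical vocabulary: in a real project this would have to be
        // exported from the Python side and kept in sync with it.
        Map<String, Integer> vocab = new HashMap<>();
        vocab.put("hello", 1);
        vocab.put("world", 2);

        String[] sentences = {"hello world", "world hello hello"};
        int maxLen = 4; // fixed sequence length the model expects (assumption)

        // Without numpy, the index matrix is filled with plain loops.
        float[][] input = new float[sentences.length][maxLen];
        for (int i = 0; i < sentences.length; i++) {
            String[] tokens = sentences[i].split(" ");
            for (int j = 0; j < maxLen && j < tokens.length; j++) {
                int index = vocab.getOrDefault(tokens[j], 0); // 0 = unknown/padding
                input[i][j] = index;
            }
        }

        // Wrap the hand-built array in a Tensor for feeding
        try (Tensor<?> t = Tensor.create(input)) {
            System.out.println(Arrays.toString(t.shape())); // [2, 4]
        }
    }
}
```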
KerasServer supports RESTful interaction, so TensorFlow/Keras can be called from any programming language. Because the KerasServer side provides a Python API, existing TensorFlow/Keras Python code and models can be turned directly into a KerasServer API and then called from Java/C/C++/C#/Python/NodeJS/browser JavaScript, without having to redo the tedious data preprocessing in those other languages.
For example, Java can submit the text to be classified directly to KerasServer, and KerasServer can use the existing Python code to tokenize and preprocess the strings.
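On the client side this reduces to a plain HTTP call. The sketch below is illustrative only: the host, port, `/predict` path, and JSON payload are assumptions rather than the documented KerasServer contract; it simply shows the Java side shipping raw text and leaving all preprocessing to the Python server.

```java
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;

public class KerasServerClient {
    public static void main(String[] args) throws Exception {
        // Endpoint and JSON layout are assumed for illustration; check the
        // actual KerasServer deployment for the real contract.
        URL url = new URL("http://localhost:8080/predict");
        HttpURLConnection conn = (HttpURLConnection) url.openConnection();
        conn.setRequestMethod("POST");
        conn.setRequestProperty("Content-Type", "application/json; charset=utf-8");
        conn.setDoOutput(true);

        // Raw text goes over the wire; tokenization happens on the Python side.
        String payload = "{\"text\": \"这是一条待分类的文本\"}";
        try (OutputStream os = conn.getOutputStream()) {
            os.write(payload.getBytes(StandardCharsets.UTF_8));
        }

        // Read back the prediction returned as JSON.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) {
            StringBuilder response = new StringBuilder();
            String line;
            while ((line = reader.readLine()) != null) {
                response.append(line);
            }
            System.out.println(response);
        }
    }
}
```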
This tutorial shows how to use the official TensorFlow Java API to call a model trained with TensorFlow (Python). The code for the tutorial can be found in 专知's GitHub project: https://github.com/ZhuanZhiCode/TensorFlow-Java-Examples
```python
# coding=utf-8
import tensorflow as tf

# Define the graph
x = tf.placeholder(tf.float32, name="x")
y = tf.get_variable("y", initializer=10.0)
z = tf.log(x + y, name="z")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Training code would go here; omitted.
    # xxxxxxxxxxxx

    # Print the nodes in the graph
    print([n.name for n in sess.graph.as_graph_def().node])

    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names=["z"])

    # Save the graph as a pb file
    with open('model.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
```
With the code above, the pb model has been saved. Next we load and run this model from Java.
A Maven project is used here, so the TensorFlow dependency needs to be added first.
```xml
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.5.0</version>
</dependency>
```
Running the model is similar to Python: import the graph, create a Session, and specify the inputs (feed) and outputs (fetch).
```java
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;
import java.io.IOException;

public class DemoImportGraph {

    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            // Import the graph
            byte[] graphBytes = IOUtils.toByteArray(new FileInputStream("model.pb"));
            graph.importGraphDef(graphBytes);

            // Create a Session from the graph
            try (Session session = new Session(graph)) {
                // Equivalent to sess.run(z, feed_dict={'x': 10.0}) in TensorFlow Python
                float z = session.runner()
                        // 'x' is the model input; 'z' is the model output
                        .feed("x", Tensor.create(10.0f))
                        .fetch("z").run().get(0).floatValue();
                System.out.println("z = " + z);
            }
        }
    }
}
```
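If you are unsure which node names to feed and fetch, the imported graph can also be inspected from Java, mirroring the Python-side print of node names above. A small sketch using the same model.pb (`Graph.operations()` is available in the org.tensorflow 1.x Java API):

```java
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Operation;

import java.io.FileInputStream;
import java.io.IOException;
import java.util.Iterator;

public class ListGraphNodes {
    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            graph.importGraphDef(IOUtils.toByteArray(new FileInputStream("model.pb")));
            // Print every operation name and type, mirroring the Python-side
            // print([n.name for n in sess.graph.as_graph_def().node])
            Iterator<Operation> ops = graph.operations();
            while (ops.hasNext()) {
                Operation op = ops.next();
                System.out.println(op.name() + " : " + op.type());
            }
        }
    }
}
```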
A second example. The TensorFlow (Python) model:

```python
import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

path = './model/'

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    # The output node needs a name attribute
    op = tf.add(xy, b, name='op_to_store')

    sess.run(tf.global_variables_initializer())

    # convert_variables_to_constants needs output_node_names as a list; it can hold several names
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])

    # Test the op
    feed_dict = {x: 10, y: 3}
    print(sess.run(op, feed_dict))

    # Write the serialized pb file
    with tf.gfile.FastGFile(path + 'model_3.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())
```
The Java code is as follows:

```java
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;
import java.io.IOException;

/**
 * Created on 2019-07-03
 *
 * @author :hao.li
 */
public class reload_3 {

    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            // Import the graph
            byte[] graphBytes = IOUtils.toByteArray(new FileInputStream(
                    "/Users/lixuewei/workspace/private/tensorflow-java/src/main/resources/model_3.pb"));
            graph.importGraphDef(graphBytes);

            // Create a Session from the graph
            try (Session session = new Session(graph)) {
                // Equivalent to sess.run(op, feed_dict={x: 10, y: 3}) in TensorFlow Python
                Tensor<?> tensor = session.runner()
                        .feed("x", Tensor.create(10))
                        .feed("y", Tensor.create(3))
                        .fetch("op_to_store").run().get(0);
                System.out.println(tensor.intValue());
            }
        }
    }
}
```
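When a feed or fetch name does not match the graph, the resulting error can be cryptic; inspecting the fetched Tensor's dtype and shape helps when debugging. A small variant of the example above, assuming model_3.pb has been copied to the working directory, prints both:

```java
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;
import java.io.IOException;
import java.util.Arrays;

public class InspectOutput {
    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            graph.importGraphDef(IOUtils.toByteArray(new FileInputStream("model_3.pb")));
            try (Session session = new Session(graph);
                 Tensor<?> result = session.runner()
                         .feed("x", Tensor.create(10))
                         .feed("y", Tensor.create(3))
                         .fetch("op_to_store").run().get(0)) {
                // dtype and shape help diagnose feed/fetch mismatches
                System.out.println(result.dataType());               // INT32
                System.out.println(Arrays.toString(result.shape())); // [] for a scalar
                System.out.println(result.intValue());               // 10 * 3 + 1 = 31
            }
        }
    }
}
```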
A third approach exports a SavedModel instead of a frozen pb and loads it in Java with SavedModelBundle. The export code:

```python
import tensorflow as tf
import numpy as np
import os

tf.app.flags.DEFINE_integer('training_iteration', 302, 'number of training iterations.')
tf.app.flags.DEFINE_integer('model_version', 1, 'version number of the model.')
tf.app.flags.DEFINE_string('work_dir', 'model/', 'Working directory.')
FLAGS = tf.app.flags.FLAGS

sess = tf.InteractiveSession()

x = tf.placeholder('float', shape=[None, 3], name="x")
y_ = tf.placeholder('float', shape=[None, 1])
w = tf.get_variable('w', shape=[3, 1], initializer=tf.truncated_normal_initializer)
b = tf.get_variable('b', shape=[1], initializer=tf.zeros_initializer)
sess.run(tf.global_variables_initializer())
y = tf.add(tf.matmul(x, w), b, name="y")
ms_loss = tf.reduce_mean((y - y_) ** 2)
train_step = tf.train.GradientDescentOptimizer(0.005).minimize(ms_loss)

train_x = np.random.randn(1000, 3)
# let the model learn the equation of y = x1 * 1 + x2 * 2 + x3 * 3
train_y = np.sum(train_x * np.array([1, 2, 3]) + np.random.randn(1000, 3) / 100, axis=1).reshape(-1, 1)

train_loss = []
for i in range(FLAGS.training_iteration):
    loss, _ = sess.run([ms_loss, train_step], feed_dict={x: train_x, y_: train_y})
    train_loss.append(loss)

export_path_base = FLAGS.work_dir
export_path = os.path.join(
    tf.compat.as_bytes(export_path_base),
    tf.compat.as_bytes(str(FLAGS.model_version)))
print('Exporting trained model to', export_path)

# SavedModelBuilder takes the directory the model will be saved into (export_path here)
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'input': tensor_info_x},
        outputs={'output': tensor_info_y},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

# This step is required: it attaches a tag to the model so that it can be
# looked up by tag when loaded again. The tag used here is "serve"
# (tf.saved_model.tag_constants.SERVING).
builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        'prediction': prediction_signature,
    },
    legacy_init_op=legacy_init_op)

builder.save()

print('Training error %g' % loss)
print('Done exporting!')
print('Done training!')
```
The Java utility that loads the SavedModel:

```java
import org.tensorflow.SavedModelBundle;

public class TensorflowUtils {

    public static SavedModelBundle loadmodel(String modelpath) {
        // Load the exported SavedModel with the same tag ("serve") it was saved with
        SavedModelBundle bundle = SavedModelBundle.load(modelpath, "serve");
        return bundle;
    }
}
```
Main:
```java
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

import java.util.Arrays;

public class Model {

    SavedModelBundle bundle = null;

    public void init() {
        // The exported model (version directory "1") is expected at the classpath root
        String classpath = this.getClass().getResource("/").getPath() + "1";
        bundle = TensorflowUtils.loadmodel(classpath);
    }

    public double getResult(float[][] arr) {
        Tensor tensor = Tensor.create(arr);
        Tensor<?> result = bundle.session().runner().feed("x", tensor).fetch("y").run().get(0);
        float[][] resultValues = (float[][]) result.copyTo(new float[1][1]);
        result.close();
        return resultValues[0][0];
    }

    public static void main(String[] args) {
        Model model = new Model();
        model.init();
        float[][] arr = new float[1][3];
        arr[0][0] = 1f;
        arr[0][1] = 0.5f;
        arr[0][2] = 2.0f;
        System.out.println(model.getResult(arr));
        System.out.println(Arrays.toString("他".getBytes()));
    }
}
```
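Two details are worth noting. First, the signature exported above registers the keys 'input' and 'output', but this Java code feeds and fetches by the underlying tensor names 'x' and 'y'. Second, Tensor and SavedModelBundle wrap native memory and should be closed. Below is a sketch of the same call with explicit cleanup via try-with-resources; the path model/1 assumes the program runs from the directory where the Python export was executed.

```java
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

public class ModelWithCleanup {
    public static void main(String[] args) {
        float[][] arr = {{1f, 0.5f, 2.0f}};
        // try-with-resources closes the native session/graph and the input tensor
        try (SavedModelBundle bundle = SavedModelBundle.load("model/1", "serve");
             Tensor<?> input = Tensor.create(arr)) {
            try (Tensor<?> result = bundle.session().runner()
                    .feed("x", input)
                    .fetch("y")
                    .run()
                    .get(0)) {
                float[][] out = result.copyTo(new float[1][1]);
                System.out.println(out[0][0]);
            }
        }
    }
}
```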
Running the example may fail with an error like `Cannot find TensorFlow native library for OS: darwin, architecture: x86_64.` This usually means the TensorFlow native (JNI) library could not be loaded for the current platform; making sure the `org.tensorflow:libtensorflow_jni` artifact, which ships the native binaries, is on the classpath (it is normally pulled in transitively by the `org.tensorflow:tensorflow` dependency shown earlier) typically resolves it.
For concrete usage, see the code.
This way of combining Flink with TensorFlow has the following problems: