当一个TensorFlow模型训练出来的时候,为了投入到实际应用,因此就须要部署到服务器上。因为我本次所作的项目是一个javaweb的图像识别项目。全部我就想去寻找一下java调用TensorFlow训练模型的办法。前端
因为TensorFlow好久没更新的缘故,网上的博客大都是18/19年的,而且是基于TensorFlow1.0的,对于如今使用的TensorFlow2.0不太友好。java
下面我简述一下TensorFlow1.0时期的方法:python
须要将训练的.h5模型转换成.pb模型,而且须要本身定义.pb模型的输入输出参数。(pb模型是一种基于动态图的模型)web
pb的生成代码冗长、并且对初学者真滴不太友好json
相比之下.h5模型的生成代码就一行flask
此外,这个生成pb模型的代码是否能照搬使用,仍是一个问题,而且还可能报一些奇奇怪怪的错误。api
查阅资料发现java上的TensorFlow的jar包都是TensorFlow1.0的服务器
现状:app
而且maven官网上的TensorFlow2.0的api已经更名成了tensorflow-core-api,而且网上相关方面的教程十分难找。因为网上都是导入的1.0的包,本身导入2.0的包以后,详细的调用教程能够说是没有。从上面也能够看出来TensorFlow对java的调用也不怎么重视了。因此这又给学习的途中徒增了不少困难。框架
用java直接调用训练好的模型很困难,那么咱们想办法让java调用python脚本,让python脚本去调用.h5模型会不会更简单呢?
代码以下
package com.guard.service; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; public class api_service { public String recognize(String path){ //此处的path是图片路径 Process proc; String res = null; try { System.out.println("接受到的参数"+path); String[] cmd = new String[] { "python", "E:\\machine_learning\\predict.py", path}; proc = Runtime.getRuntime().exec(cmd); BufferedReader in = new BufferedReader(new InputStreamReader(proc.getInputStream())); String line = null; while ((line = in.readLine()) != null) { System.out.println(line); res = line; } in.close(); proc.waitFor(); } catch (IOException e) { e.printStackTrace(); } catch (InterruptedException e) { e.printStackTrace(); } System.out.println(res+">>>>>>>>>>>"); return res; } }
可是咱们能够看出,这个实际上是用java在win上跑了这样一个指令
虽然这个确实是一个好办法,可是这个路径参数须要事先知道服务器上的路径,而且在协做开发的时候,每一个人的路径和环境就不一样,虽然该方法能用,可是我认为还不够好。
咱们能够直接用python的flask框架,直接生成一个api接口,就能够远程直接调用TensorFlow训练好的模型进行结果预测。
我的认为,这种方法相较于用java调用命令行,这种方法仍是更加直观的
而且flask仅仅须要加个@app.route的注解就能实现,可谓是十分方便
下面是模型调用代码
model.py
import glob import sys import os import cv2 import numpy as np import tensorflow as tf import image_processing def model_ues(path): # 缩放图片大小为100*100 w = 100 h = 100 # 测试图像的地址 (改成本身的) # path_test = "resource/test24.jpg" api_token = "fklasjfljasdlkfjlasjflasjfljhasdljflsdjflkjsadljfljsda" path_test = image_processing.download_img(path,api_token) # 建立保存图像的空列表 imgs = [] img = cv2.imread(path_test) img = cv2.resize(img, (w, h)) # 将每张通过处理的图像数据保存在以前建立的imgs空列表当中 imgs.append(img) imgs = np.asarray(imgs, np.float32) # print("shape of data:",imgs.shape) # 导入模型 model = tf.keras.models.load_model(r"resource/rice_0.93.h5") # 建立图像标签列表 rice_dict = {0: 'Rice blast', 1: 'Rice fleck', 2: 'Rice koji disease', 3: 'Sheath blight'} # 将图像导入模型进行预测 prediction = model.predict_classes(imgs) # prediction = np.argmax(model.predict(imgs), axis=-1) # 绘制预测图像 for i in range(np.size(prediction)): # 打印每张图像的预测结果 print(rice_dict[prediction[i]]) return rice_dict[prediction[0]]
为了实现图片外连接受,下面是图片下载脚本
image_processing.py
# coding: utf8 import requests import random def download_img(img_url, api_token): print (img_url) header = {"Authorization": "Bearer " + api_token} # 设置http header,视状况加须要的条目,这里的token是用来鉴权的一种方式 r = requests.get(img_url, headers=header, stream=True) print(r.status_code) # 返回状态码 file_img = 'resource/img.png' # file_img = 'resource/' print(file_img) if r.status_code == 200: open(file_img, 'wb').write(r.content) # 将内容写入图片 print("done") del r return file_img # if __name__ == '__main__': # # 下载要的图片 # img_url = "https://z3.ax1x.com/2021/07/27/W5l6Qe.png" # api_token = "fklasjfljasdlkfjlasjflasjfljhasdljflsdjflkjsadljfljsda" # download_img(img_url, api_token)
主程序脚本
app.py
from flask import Flask,render_template, url_for, request, json,jsonify import model app = Flask(__name__) #设置编码 app.config['JSON_AS_ASCII'] = False @app.route('/test') def hello_world(): return "hello world" @app.route('/predict', methods=['GET', 'POST']) def form_data(): my_path = request.form['path'] print(my_path) str = model.model_ues(my_path) print("http://127.0.0.1:5000/predict") return jsonify({'result':str,'msg':'200'}) if __name__ == '__main__': app.run()
虽然咱们可以经过postman进行测试接受到回传的结果,可是咱们要怎么用java实现呢??
1.使用postman生成大体代码框架(postman生成的代码可能不能直接运行)
这里我选用的是java-okhttp的方法,但其实使用Unirest写出来的代码更加简洁易懂。
public class Get_result { public String getResult(String path) throws IOException { // String path = "https://i.loli.net/2021/07/29/badDNR2OCironUf.jpg"; OkHttpClient client = new OkHttpClient().newBuilder() .build(); MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded"); RequestBody body = RequestBody.create(mediaType, "path="+path); Request request = new Request.Builder() .url("http://127.0.0.1:8000/predict") .method("POST", body) .addHeader("Content-Type", "application/x-www-form-urlencoded") .build(); Response response = client.newCall(request).execute(); String result = response.body().string(); System.out.println(result); } }
{ "msg": "200", "result": "Rice fleck" }
获取到json数据以后,就须要对json数据进行解析
java上的解析原理是,先按照json编写一个类,以后用Gson对接受到的数据按照这个类进行规范化
(这里能够用GsonFormatPlus插件来自动生成这个实体类)
//Rice_result.java---为该json的实体类 package com.guard.tool; import lombok.Data; import lombok.NoArgsConstructor; @NoArgsConstructor @Data public class Rice_result { private String msg; private String result; }
下面是数据解析代码(和上面的okhttp获取json数据的代码连起来看)
//json数据解析 Gson gson = new Gson(); java.lang.reflect.Type type = new TypeToken<Rice_result>(){}.getType(); Rice_result rice_result = gson.fromJson(result, type); System.out.println(rice_result); if("200".equals(rice_result.getMsg())){ // System.out.println(rice_result.getResult()); return Rice_result.convertdata(rice_result.getResult()); }else { // System.out.println("获取结果出错!!"); return "获取结果出错!!"; }
这样的话就能够进行json数据的解析了。
因为须要使用java发送post请求给flask的预测端口,那么就须要把本地上传的数据作成图链,把图链做为数据传给flask的预测端口,从而来接收结果。
因为前端js的知识大多遗忘,这里就选用了用java来发送一个post请求,得到回传的信息。
这里我使用的是sm.ms的图床(该图床无需登陆,且速度快,算得上是一个好的选择)
//sm.ms的使用方法,建议看官方文档 package com.guard.tool; import com.google.gson.Gson; import com.google.gson.reflect.TypeToken; import okhttp3.*; import java.io.File; import java.io.IOException; public class CloudUpload { public String toUrl(String path) throws IOException { // String file_path = "E:/machine_learning/test8.jpg"; String file_path = path; OkHttpClient client = new OkHttpClient().newBuilder() .build(); MediaType mediaType = MediaType.parse("multipart/form-data"); RequestBody body = new MultipartBody.Builder().setType(MultipartBody.FORM) .addFormDataPart("smfile",file_path, RequestBody.create(MediaType.parse("application/octet-stream"), new File(file_path))) .addFormDataPart("format","json") .build(); Request request = new Request.Builder() .url("https://sm.ms/api/v2/upload") .method("POST", body) .addHeader("Content-Type", "multipart/form-data") .addHeader("Authorization", "TlxzRSaVJj0o7HFZOd9sgdf4Jl60RA00") //这里的user-agent和Cookie须要本身打开网站,到网站的页面去拿取 .addHeader("user-agent","Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36") .addHeader("Cookie", "SMMSrememberme=42417%3A10e8e9cb5281082b493fdee73381aeb2dca0bd3d; PHPSESSID=1gjog2em3ogof23vrqi79vd41m; SM_FC=runWNk3mPIiL8mzl%2FrlEfzM940LRKjLm182cm2qDrm4%3D") .build(); Response response = client.newCall(request).execute(); String result = response.body().string(); System.out.println(result); // String result = response.body().string(); Gson gson = new Gson(); java.lang.reflect.Type type = new TypeToken<Image_data>(){}.getType(); Image_data imge_data = gson.fromJson(result, type); System.out.println(imge_data); if (imge_data.getSuccess()){ System.out.println(imge_data.getData().getUrl()); return imge_data.getData().getUrl(); } else{ System.out.println("图片已经上传过一次!!"); System.out.println(imge_data.getImages()); return imge_data.getImages(); } } }
回传的json结果--这个就须要使用上面的插件来进行处理
{ "success": true, "code": "success", "message": "Upload success.", "data": { "file_id": 0, "width": 192, "height": 454, "filename": "test25.jpg", "storename": "xICPNzFsfth5uJk.png", "size": 124993, "path": "/2021/08/01/xICPNzFsfth5uJk.png", "hash": "2exIdQGvBru46RKMyNjg3DhCTO", "url": "https://i.loli.net/2021/08/01/xICPNzFsfth5uJk.png", "delete": "https://sm.ms/delete/2exIdQGvBru46RKMyNjg3DhCTO", "page": "https://sm.ms/image/xICPNzFsfth5uJk" }, "RequestId": "9BFE9DEB-8370-44C8-A8AF-AAB2DB753A18" }
以上就是我此次在小组编写<基于CNN图像分类的水稻病虫害识别>这个项目中的收获。在此记录下学习路上踩过的一些坑和一些解决方法。