【TensorFlow源码系列】【零】使用TensorFlow C++ 接口进行模型推理

#include <string>
#include <vector>
#include <iostream>
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/tensor.h"

//using namespace std;
//using namespace tensorflow;

int main(int argc,char **argv)
{
	// 1. 建立session
	Session * session;
	Status status = NewSession(SessionOptions(),&session);
    
	// 2. 模型路径
	string model_path = "mnist.pb";
	
	// 3. 将pb原始模型导入到GraphDef中
	GraphDef graphdef;
    status = ReadBinaryProto(Env::Default(),model_path,&graphdef);
	
	if(!status.ok()){
		
		return 0;
	}
	
	// 4. 将原始模型加载到session中
	status = session->Create(graphdef);
	
	if(!status.ok()){
		
		return 0;
	}
	
	// 5. 建立输入输出tensor
	std::vector<std::pair<std::string,tensorflow::Tensor>> inputs;
	std::vector<tensorflow::Tensor> outputs;
	
	tensorflow::Tensor input_tensor(DT_FLOAT,tensorflow::TesorShape({1,28,28,1}));
	
	// 6. 获取输入tensor指针,向里面填写数据
	auto plane_tensor = input_tensor.tensor<float,4>();
	
	for(int n = 0; n < 1; ++n)
		for(int h = 0 ; h < 28; ++h)
			for(int w = 0; w < 28; ++w)
				for(int c = 0; c < 1; ++c){
					plane_tensor(n,h,w,c) = 1.0f;
				}
	inputs.push_back({"inputs",input_tensor});
	
	// 7. 运行模型,须要传递输入tensor,输出tensor,输出tensor的name ---softmax
	status = session->Run(inputs,{"softmax"},{},&outputs);
	if(!status.ok()){
		
		return 0;
	}
	
	// 8. 计算完成后,将计算结果从outputtensor中取出来
	auto out_tensor = out_tensor[0].tensor<float,2>();
	for(int n = 0; n < 1; ++n)
		for(int h = 0 ; h < 10; ++h){
					std::cout<<out_tensor(n,h)<<std::endl;
				}
	
	return 0;
}

后续源码分析,会基于这个主体流程做分析。ios

相关文章
相关标签/搜索