1、tensorflow模型跨平台的方案
tensorflow模型的跨平台上线的备选方案一般有三种:PMML方式、tensorflow serving方式以及跨语言API方式。
- PMML方式:与普通机器学习模型通过PMML上线方式一致,唯一的区别是转化生成PMML文件需要用一个Java库jpmml-tensorflow来完成,生成PMML文件后,跨语言加载模型和其他PMML模型文件基本类似。
- tensorflow serving:该方式是tensorflow官方推荐的模型上线预测方式,它需要一个专门的tensorflow服务器,用来提供预测的API服务。如果你的模型和对应的应用是比较大规模的,那么使用tensorflow serving是比较好的使用方式。但是它也有一个缺点,就是比较笨重,如果你要使用tensorflow serving,那么需要自己搭建serving集群并维护这个集群。所以为了一个小的应用去做这个工作,有时候会觉得麻烦。
- 跨语言API:该方式是本文要讨论的方式,它会用tensorflow自己的Python API生成模型文件,然后用tensorflow的客户端库比如Java或C++库来做模型的在线预测
2、tensorflow模型保存的方式和跨平台加载
tensorflow训练模型保存模型有两种方式:
- 以checkpoint方式保存模型文件及其参数(多个文件);
- 以pb固化图的方式保存模型文件和参数(只有一个文件)。
一般大部分训练模型的时候都是以ckpt的方式保存模型,本文不详细说明。接下来主要讲解如何将模型以pb的形式保存,且通过跨平台的方式来调用训练好的模型预测。可以通过以下两种方式将图保存为pb的形式:1
2
3
4
5
6
7
8
9
10output_graph_def = tf.graph_util.convert_variables_to_constants(sess=sess,
input_graph_def=input_graph_def,
output_node_names=["output"])
#方式1
tf.train.write_graph(output_graph_def,".",output_graph_path,as_text=False)
#方式2
with tf.gfile.GFile("./rf03.pb","wb") as f:
f.write(output_graph_def.SerializeToString())
input_graph_def
是指的训练过程中的图,不带参数,以上两种方式固化图的结果是一致的。
2.1、模型训练保存完整的demo
以下给出一个完整的列子:
1 | from sklearn.datasets.samples_generator import make_classification |
2.2、python的api加载pb文件并预测
使用python加载.pb
文件预测demo1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46import tensorflow as tf
import numpy as np
def load_graph(frozen_graph_file_name):
with tf.gfile.GFile(frozen_graph_file_name,"rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
input_map=None,
return_elements=None,
name="prefix",
op_dict=None,producer_op_list=None
)
return graph
def show_nodes_graph(graph):
for op in graph.get_operations():
print(op.name,op.values)
def get_input():
inputs = [[0.0 for i in range(0,6)] for j in range(0,4)]
for i in range(0,4):
for j in range(0,6):
if i<2:
inputs[i][j]=2*i-5*j-6
else:
inputs[i][j] = 2 * i + 5 * j - 6
return np.array(inputs)
def predict(path="./ckpt/rf02.pb"):
graph = load_graph(path)
inputs = get_input()
with tf.Session(graph=graph) as sess:
outputs = sess.run("prefix/output:0",feed_dict={"prefix/input:0":inputs})
print(type(outputs))
print(list(outputs))
if __name__=="__main__":
# graph = load_graph("./ckpt/rf02.pb")
# show_nodes_graph(graph)
predict("./ckpt/rf02.pb")
2.3、Java的api加载pb文件并预测
Java的api加载.pb
文件并预测
首先maven的依赖中加入以下内容:1
2
3
4
5
6<!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.10.0</version>
</dependency>
接下来Java加载模型和预测代码如下1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58package com.tensorflowonline;
import org.tensorflow.*;
import org.tensorflow.Graph;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
public class OnlineDemo {
public static void main(String[] args){
String modelPath="/Users/xiachi/PycharmProjects/credit_card/WorkSpace/learn_dir/online_code/rf03.pb";
byte[] graphDef = loadTensoflowModel(modelPath);
float[][] inputs = new float[4][6];
for(int i=0;i<4;i++){
for(int j=0;j<6;j++){
if(i<2){
inputs[i][j]=2*i-5*j-6;
}else{
inputs[i][j]=2*i+5*j-6;
}
}
}
Tensor<Float> input = covertArrayToTensor(inputs);
Graph graph = new Graph();
graph.importGraphDef(graphDef);
Session session = new Session(graph);
Tensor result = session.runner().feed("input",input).fetch("output").run().get(0);
long[] rshape = result.shape();
int rs = (int) rshape[0];
long realResult[] = new long[rs];
result.copyTo(realResult);
for(long a :realResult){
System.out.println(a);
}
}
private static Tensor<Float> covertArrayToTensor(float[][] inputs){
return Tensors.create(inputs);
}
private static byte[] loadTensoflowModel(String path) {
try {
return Files.readAllBytes(Paths.get(path));
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
}
3、将checkpoint文件转为pb文件
以上给出了在训练的时候保存模型为pb文件,且给出了python和java两种调用pb预测的方式。如果我们已经训练完了,其实不用在运行源代码那么麻烦,只需要通过以下方式直接将checkpoint
形式的模型文件转为pb
形式的文件。
checkpoint
模型文件中一般有如图以下几个文件:
.meta
文件保存的是图结构,通俗地讲就是神经网络的网络结构。一般而言网络结构是不会发生改变,所以可以只保存一个就行了;.data
文件是数据文件,保存的是网络的权值,偏置,操作等等;.index
是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等;- checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;
拿到了这些文件,然后通过以下步骤完成pb文件的保存:
1、通过传入CKPT模型的路径得到模型的图和变量数据
2、通过 import_meta_graph 导入模型中的图
3、通过 saver.restore 从模型中恢复图中各个变量的数据
4、通过graph_util.convert_variables_to_constants将模型持久化
完整的demo如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36import tensorflow as tf
"""
1、通过传入CKPT模型的路径得到模型的图和变量数据
2、通过 import_meta_graph 导入模型中的图
3、通过 saver.restore 从模型中恢复图中各个变量的数据
4、通过graph_util.convert_variables_to_constants将模型持久化
"""
output_graph_path = "./ckpt/rf02.pb"
checkpoint_prefix=tf.train.latest_checkpoint("./ckpt/")
def freeze_graph():
output_node_name="ouput"
meta_file_path = checkpoint_prefix+".meta"
saver = tf.train.import_meta_graph(meta_file_path,clear_devices=True)
graph = tf.get_default_graph() #获得默认图
input_graph_def = graph.as_graph_def() #返回一个序列化的图代表当前图
with tf.Session() as sess:
saver.restore(sess,checkpoint_prefix)
output_graph_def = tf.graph_util.convert_variables_to_constants(sess=sess,
input_graph_def=input_graph_def,
output_node_names=["output"])
#使用tf.train.write_graph固化
# tf.train.write_graph(output_graph_def,".","./rf02.pb",as_text=False)
#使用tf.gfile来固化模型文件
with tf.gfile.GFile(output_graph_path,"wb") as f:
f.write(output_graph_def.SerializeToString())
if __name__=="__main__":
freeze_graph()
4、总结
对于tensorflow来说,模型上线一般选择tensorflow serving或者client API库来上线,前者适合于较大的模型和应用场景,后者则适合中小型的模型和应用场景。因此算法工程师使用在产品之前需要做好选择和评估。