tensorflow模型保存与跨平台上线

1、tensorflow模型跨平台的方案

tensorflow模型的跨平台上线的备选方案一般有三种:PMML方式、tensorflow serving方式以及跨语言API方式。

  • PMML方式:与普通机器学习模型通过PMML上线方式一致,唯一的区别是转化生成PMML文件需要用一个Java库jpmml-tensorflow来完成,生成PMML文件后,跨语言加载模型和其他PMML模型文件基本类似。
  • tensorflow serving:该方式是tensorflow官方推荐的模型上线预测方式,它需要一个专门的tensorflow服务器,用来提供预测的API服务。如果你的模型和对应的应用是比较大规模的,那么使用tensorflow serving是比较好的使用方式。但是它也有一个缺点,就是比较笨重,如果你要使用tensorflow serving,那么需要自己搭建serving集群并维护这个集群。所以为了一个小的应用去做这个工作,有时候会觉得麻烦。
  • 跨语言API:该方式是本文要讨论的方式,它会用tensorflow自己的Python API生成模型文件,然后用tensorflow的客户端库比如Java或C++库来做模型的在线预测

2、tensorflow模型保存的方式和跨平台加载

tensorflow训练模型保存模型有两种方式:

  • 以checkpoint方式保存模型文件及其参数(多个文件);
  • 以pb固化图的方式保存模型文件和参数(只有一个文件)。

一般大部分训练模型的时候都是以ckpt的方式保存模型,本文不详细说明。接下来主要讲解如何将模型以pb的形式保存,且通过跨平台的方式来调用训练好的模型预测。可以通过以下两种方式将图保存为pb的形式:

1
2
3
4
5
6
7
8
9
10
output_graph_def = tf.graph_util.convert_variables_to_constants(sess=sess,
input_graph_def=input_graph_def,
output_node_names=["output"])

#方式1
tf.train.write_graph(output_graph_def,".",output_graph_path,as_text=False)

#方式2
with tf.gfile.GFile("./rf03.pb","wb") as f:
f.write(output_graph_def.SerializeToString())

input_graph_def是指的训练过程中的图,不带参数,以上两种方式固化图的结果是一致的。

2.1、模型训练保存完整的demo

以下给出一个完整的列子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from sklearn.datasets.samples_generator import make_classification
import tensorflow as tf
X1, y1 = make_classification(n_samples=4000, n_features=6, n_redundant=0,
n_clusters_per_class=1, n_classes=3)
def train():
#参数部分
learning_rate = 0.01
training_epochs = 600
batch_size = 100

#模型部分
x = tf.placeholder(tf.float32, [None, 6], name='input') # 6 features
y = tf.placeholder(tf.float32, [None, 3]) # 3 classes

W = tf.Variable(tf.zeros([6, 3]))
b = tf.Variable(tf.zeros([3]))

# softmax回归
pred = tf.nn.softmax(tf.matmul(x, W) + b, name="softmax")
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

prediction_labels = tf.argmax(pred, axis=1, name="output")


#训练部分
#如果模型文件存在继续训练就需要restore
restore_flag=False
saver = tf.train.Saver(tf.global_variables())
sess = tf.Session()
if restore_flag:
saver.restore(sess, tf.train.latest_checkpoint("./ckpt/"))
else:
init = tf.global_variables_initializer()
sess.run(init)



y2 = tf.one_hot(y1, 3)
y2 = sess.run(y2)

for epoch in range(training_epochs):

_, c = sess.run([optimizer, cost], feed_dict={x: X1, y: y2})
if (epoch + 1) % 10 == 0:
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c))

print("优化完毕!")

correct_prediction = tf.equal(prediction_labels, tf.argmax(y2, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
acc = sess.run(accuracy, feed_dict={x: X1, y: y2})
print(acc)


saver.save(sess,"./ckpt/model")

#图固化
graph = tf.graph_util.convert_variables_to_constants(sess,input_graph_def= sess.graph_def,output_node_names=["output"])
tf.train.write_graph(graph, '.', 'rf.pb', as_text=False)

with tf.gfile.GFile("rf01.pb","wb") as f:
f.write(graph.SerializeToString())



if __name__=="__main__":
train()

2.2、python的api加载pb文件并预测

使用python加载.pb文件预测demo

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import tensorflow as tf
import numpy as np

def load_graph(frozen_graph_file_name):
with tf.gfile.GFile(frozen_graph_file_name,"rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
input_map=None,
return_elements=None,
name="prefix",
op_dict=None,producer_op_list=None
)
return graph

def show_nodes_graph(graph):
for op in graph.get_operations():
print(op.name,op.values)

def get_input():
inputs = [[0.0 for i in range(0,6)] for j in range(0,4)]
for i in range(0,4):
for j in range(0,6):
if i<2:
inputs[i][j]=2*i-5*j-6
else:
inputs[i][j] = 2 * i + 5 * j - 6
return np.array(inputs)

def predict(path="./ckpt/rf02.pb"):
graph = load_graph(path)
inputs = get_input()

with tf.Session(graph=graph) as sess:
outputs = sess.run("prefix/output:0",feed_dict={"prefix/input:0":inputs})
print(type(outputs))
print(list(outputs))


if __name__=="__main__":
# graph = load_graph("./ckpt/rf02.pb")
# show_nodes_graph(graph)
predict("./ckpt/rf02.pb")

2.3、Java的api加载pb文件并预测

Java的api加载.pb文件并预测
首先maven的依赖中加入以下内容:

1
2
3
4
5
6
<!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.10.0</version>
</dependency>

接下来Java加载模型和预测代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
package com.tensorflowonline;
import org.tensorflow.*;
import org.tensorflow.Graph;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
public class OnlineDemo {

public static void main(String[] args){
String modelPath="/Users/xiachi/PycharmProjects/credit_card/WorkSpace/learn_dir/online_code/rf03.pb";
byte[] graphDef = loadTensoflowModel(modelPath);
float[][] inputs = new float[4][6];


for(int i=0;i<4;i++){
for(int j=0;j<6;j++){
if(i<2){
inputs[i][j]=2*i-5*j-6;
}else{
inputs[i][j]=2*i+5*j-6;
}
}
}

Tensor<Float> input = covertArrayToTensor(inputs);
Graph graph = new Graph();
graph.importGraphDef(graphDef);

Session session = new Session(graph);
Tensor result = session.runner().feed("input",input).fetch("output").run().get(0);

long[] rshape = result.shape();
int rs = (int) rshape[0];
long realResult[] = new long[rs];

result.copyTo(realResult);
for(long a :realResult){
System.out.println(a);
}
}




private static Tensor<Float> covertArrayToTensor(float[][] inputs){
return Tensors.create(inputs);
}

private static byte[] loadTensoflowModel(String path) {
try {
return Files.readAllBytes(Paths.get(path));
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
}

3、将checkpoint文件转为pb文件

以上给出了在训练的时候保存模型为pb文件,且给出了python和java两种调用pb预测的方式。如果我们已经训练完了,其实不用在运行源代码那么麻烦,只需要通过以下方式直接将checkpoint形式的模型文件转为pb形式的文件。
checkpoint模型文件中一般有如图以下几个文件:
ckpt模型文件

  • .meta文件保存的是图结构,通俗地讲就是神经网络的网络结构。一般而言网络结构是不会发生改变,所以可以只保存一个就行了;
  • .data文件是数据文件,保存的是网络的权值,偏置,操作等等;
  • .index是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等;
  • checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;

拿到了这些文件,然后通过以下步骤完成pb文件的保存:
1、通过传入CKPT模型的路径得到模型的图和变量数据
2、通过 import_meta_graph 导入模型中的图
3、通过 saver.restore 从模型中恢复图中各个变量的数据
4、通过graph_util.convert_variables_to_constants将模型持久化

完整的demo如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import tensorflow as tf
"""
1、通过传入CKPT模型的路径得到模型的图和变量数据
2、通过 import_meta_graph 导入模型中的图
3、通过 saver.restore 从模型中恢复图中各个变量的数据
4、通过graph_util.convert_variables_to_constants将模型持久化
"""


output_graph_path = "./ckpt/rf02.pb"
checkpoint_prefix=tf.train.latest_checkpoint("./ckpt/")

def freeze_graph():
output_node_name="ouput"
meta_file_path = checkpoint_prefix+".meta"
saver = tf.train.import_meta_graph(meta_file_path,clear_devices=True)
graph = tf.get_default_graph() #获得默认图
input_graph_def = graph.as_graph_def() #返回一个序列化的图代表当前图

with tf.Session() as sess:
saver.restore(sess,checkpoint_prefix)

output_graph_def = tf.graph_util.convert_variables_to_constants(sess=sess,
input_graph_def=input_graph_def,
output_node_names=["output"])

#使用tf.train.write_graph固化
# tf.train.write_graph(output_graph_def,".","./rf02.pb",as_text=False)

#使用tf.gfile来固化模型文件
with tf.gfile.GFile(output_graph_path,"wb") as f:
f.write(output_graph_def.SerializeToString())


if __name__=="__main__":
freeze_graph()

4、总结

对于tensorflow来说,模型上线一般选择tensorflow serving或者client API库来上线,前者适合于较大的模型和应用场景,后者则适合中小型的模型和应用场景。因此算法工程师使用在产品之前需要做好选择和评估。

5、参考