部署tf_keras模型遇到的坑

Author Avatar
patrickcty 12月 14, 2019

一次 load 多次 predict

由于 TensorFlow 使用的是静态图,因此默认并不能做到一次 load 然后就任意在其他的地方调用,因此需要以下方法来做到:

# load
model = load_model('path_to_model')
graph = tf.get_default_graph()

# predict, maybe in another thread
global graph
with graph.as_default():
    model.predict()

限制不要占满 GPU 显存

TensorFlow 默认的机制是会占用满 GPU,但是可以通过设置 allow_growth 来只使用必须的显存:

import tensorflow as tf
from tensorflow.python.keras.backend import set_session

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
set_session(sess)  # 必须要先设置 session 才能 load model

进行这样设置之后如果要在多个不同情况下调用模型则也必须每次都设置 session:

# predict, maybe in another thread
global graph
global sess
with graph.as_default():
    set_session(sess)
    model.predict()

忽视掉 TensorFlow 的一大堆初始化 log

log level 为 3 的话就只会输出错误信息了

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.get_logger().setLevel('ERROR')