From 355fb5e14b325a1d106c4046f478da4bda350205 Mon Sep 17 00:00:00 2001 From: Yifei Feng Date: Mon, 5 Mar 2018 13:47:30 -0800 Subject: [PATCH] Fix the issue where gpu_option is not respected for keras estimator. Set keras backend session with the given config before any get_session call creates a new session. Fix #14776. PiperOrigin-RevId: 187916300 --- tensorflow/python/keras/_impl/keras/estimator.py | 6 +++++- tensorflow/python/keras/_impl/keras/estimator_test.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/keras/_impl/keras/estimator.py index 0bf5bd4..5697771 100644 --- a/tensorflow/python/keras/_impl/keras/estimator.py +++ b/tensorflow/python/keras/_impl/keras/estimator.py @@ -296,10 +296,14 @@ def model_to_estimator(keras_model=None, 'Given keras model has not been compiled yet. Please compile first ' 'before creating the estimator.') - keras_weights = keras_model.get_weights() keras_model_fn = _create_keras_model_fn(keras_model, custom_objects) est = estimator_lib.Estimator( keras_model_fn, model_dir=model_dir, config=config) + # Pass the config into keras backend's default session. + with session.Session(config=est._session_config) as sess: + K.set_session(sess) + + keras_weights = keras_model.get_weights() # TODO(yifeif): move checkpoint initialization to scaffold.init_fn _save_first_checkpoint(keras_model, est, custom_objects, keras_weights) return est diff --git a/tensorflow/python/keras/_impl/keras/estimator_test.py b/tensorflow/python/keras/_impl/keras/estimator_test.py index 88dd14b..a9de5dd 100644 --- a/tensorflow/python/keras/_impl/keras/estimator_test.py +++ b/tensorflow/python/keras/_impl/keras/estimator_test.py @@ -24,6 +24,7 @@ import tempfile import numpy as np +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import test_util @@ -377,6 +378,22 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir)) + def test_gpu_config(self): + keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer='rmsprop', + metrics=['mse', keras.metrics.categorical_accuracy]) + + gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3) + sess_config = config_pb2.ConfigProto(gpu_options=gpu_options) + self._config._session_config = sess_config + keras.estimator.model_to_estimator( + keras_model=keras_model, config=self._config) + self.assertEqual(keras.backend.get_session() + ._config.gpu_options.per_process_gpu_memory_fraction, + gpu_options.per_process_gpu_memory_fraction) + if __name__ == '__main__': test.main() -- 2.7.4