From e606e9133e96caf00d60e2ac0eb3f308fd0a4758 Mon Sep 17 00:00:00 2001 From: Yifei Feng Date: Mon, 9 Apr 2018 12:09:34 -0700 Subject: [PATCH] Only set session in model_to_estimator if _SESSION has not been set. Fix #18193. PiperOrigin-RevId: 192164669 --- tensorflow/python/keras/_impl/keras/estimator.py | 13 ++++++++++++- .../python/keras/_impl/keras/estimator_test.py | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/keras/_impl/keras/estimator.py index 5d370eb..8043242 100644 --- a/tensorflow/python/keras/_impl/keras/estimator.py +++ b/tensorflow/python/keras/_impl/keras/estimator.py @@ -26,6 +26,7 @@ from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import export as export_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib @@ -465,11 +466,21 @@ def model_to_estimator(keras_model=None, estimator = estimator_lib.Estimator( keras_model_fn, model_dir=model_dir, config=config) + old_session = K._SESSION # Pass the config into keras backend's default session. sess = session.Session(config=estimator._session_config) K.set_session(sess) + try: + keras_weights = keras_model.get_weights() + except errors.FailedPreconditionError as e: + if old_session is None: + raise e + logging.warning( + 'The Keras backend session has already been ' + 'set. The _session_config passed to model_to_estimator is not used.') + K.set_session(old_session) + keras_weights = keras_model.get_weights() - keras_weights = keras_model.get_weights() if keras_model._is_graph_network: # TODO(yifeif): move checkpoint initialization to scaffold.init_fn _save_first_checkpoint(keras_model, diff --git a/tensorflow/python/keras/_impl/keras/estimator_test.py b/tensorflow/python/keras/_impl/keras/estimator_test.py index e076dc2..27b7ec7 100644 --- a/tensorflow/python/keras/_impl/keras/estimator_test.py +++ b/tensorflow/python/keras/_impl/keras/estimator_test.py @@ -512,6 +512,26 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): ._config.gpu_options.per_process_gpu_memory_fraction, gpu_options.per_process_gpu_memory_fraction) + def test_pretrained_weights(self): + keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(1e-3), + metrics=['mse', keras.metrics.categorical_accuracy]) + + keras_model.train_on_batch( + np.random.random((10,) + _INPUT_SIZE), np.random.random((10, + _NUM_CLASS))) + weights = keras_model.get_weights() + keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model() + keras_model.set_weights(weights) + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(1e-3), + metrics=['mse', keras.metrics.categorical_accuracy]) + keras.estimator.model_to_estimator( + keras_model=keras_model, config=self._config) + if __name__ == '__main__': test.main() -- 2.7.4