Only set session in model_to_estimator if _SESSION has not been set.
authorYifei Feng <yifeif@google.com>
Mon, 9 Apr 2018 19:09:34 +0000 (12:09 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 9 Apr 2018 19:12:18 +0000 (12:12 -0700)
Fix #18193.

PiperOrigin-RevId: 192164669

tensorflow/python/keras/_impl/keras/estimator.py
tensorflow/python/keras/_impl/keras/estimator_test.py

index 5d370eb..8043242 100644 (file)
@@ -26,6 +26,7 @@ from tensorflow.python.estimator import estimator as estimator_lib
 from tensorflow.python.estimator import export as export_lib
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.estimator import run_config as run_config_lib
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -465,11 +466,21 @@ def model_to_estimator(keras_model=None,
   estimator = estimator_lib.Estimator(
       keras_model_fn, model_dir=model_dir, config=config)
 
+  old_session = K._SESSION
   # Pass the config into keras backend's default session.
   sess = session.Session(config=estimator._session_config)
   K.set_session(sess)
+  try:
+    keras_weights = keras_model.get_weights()
+  except errors.FailedPreconditionError as e:
+    if old_session is None:
+      raise e
+    logging.warning(
+        'The Keras backend session has already been '
+        'set. The _session_config passed to model_to_estimator is not used.')
+    K.set_session(old_session)
+    keras_weights = keras_model.get_weights()
 
-  keras_weights = keras_model.get_weights()
   if keras_model._is_graph_network:
     # TODO(yifeif): move checkpoint initialization to scaffold.init_fn
     _save_first_checkpoint(keras_model,
index e076dc2..27b7ec7 100644 (file)
@@ -512,6 +512,26 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
                      ._config.gpu_options.per_process_gpu_memory_fraction,
                      gpu_options.per_process_gpu_memory_fraction)
 
+  def test_pretrained_weights(self):
+    keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
+    keras_model.compile(
+        loss='categorical_crossentropy',
+        optimizer=rmsprop.RMSPropOptimizer(1e-3),
+        metrics=['mse', keras.metrics.categorical_accuracy])
+
+    keras_model.train_on_batch(
+        np.random.random((10,) + _INPUT_SIZE), np.random.random((10,
+                                                                 _NUM_CLASS)))
+    weights = keras_model.get_weights()
+    keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
+    keras_model.set_weights(weights)
+    keras_model.compile(
+        loss='categorical_crossentropy',
+        optimizer=rmsprop.RMSPropOptimizer(1e-3),
+        metrics=['mse', keras.metrics.categorical_accuracy])
+    keras.estimator.model_to_estimator(
+        keras_model=keras_model, config=self._config)
+
 
 if __name__ == '__main__':
   test.main()