Fix the issue where gpu_option is not respected for keras estimator.
authorYifei Feng <yifeif@google.com>
Mon, 5 Mar 2018 21:47:30 +0000 (13:47 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Mar 2018 21:52:51 +0000 (13:52 -0800)
Set keras backend session with the given config before any get_session call creates a new session.
Fix #14776.

PiperOrigin-RevId: 187916300

tensorflow/python/keras/_impl/keras/estimator.py
tensorflow/python/keras/_impl/keras/estimator_test.py

index 0bf5bd4..5697771 100644 (file)
@@ -296,10 +296,14 @@ def model_to_estimator(keras_model=None,
         'Given keras model has not been compiled yet. Please compile first '
         'before creating the estimator.')
 
-  keras_weights = keras_model.get_weights()
   keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
   est = estimator_lib.Estimator(
       keras_model_fn, model_dir=model_dir, config=config)
+  # Pass the config into keras backend's default session.
+  with session.Session(config=est._session_config) as sess:
+    K.set_session(sess)
+
+  keras_weights = keras_model.get_weights()
   # TODO(yifeif): move checkpoint initialization to scaffold.init_fn
   _save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
   return est
index 88dd14b..a9de5dd 100644 (file)
@@ -24,6 +24,7 @@ import tempfile
 
 import numpy as np
 
+from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.estimator import run_config as run_config_lib
 from tensorflow.python.estimator.inputs import numpy_io
 from tensorflow.python.framework import test_util
@@ -377,6 +378,22 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
             keras_model=keras_model,
             model_dir=tempfile.mkdtemp(dir=self._base_dir))
 
+  def test_gpu_config(self):
+    keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
+    keras_model.compile(
+        loss='categorical_crossentropy',
+        optimizer='rmsprop',
+        metrics=['mse', keras.metrics.categorical_accuracy])
+
+    gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
+    sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
+    self._config._session_config = sess_config
+    keras.estimator.model_to_estimator(
+        keras_model=keras_model, config=self._config)
+    self.assertEqual(keras.backend.get_session()
+                     ._config.gpu_options.per_process_gpu_memory_fraction,
+                     gpu_options.per_process_gpu_memory_fraction)
+
 
 if __name__ == '__main__':
   test.main()