From 5e8aaa66af43b6b66e61ca7d589002eac6b4fb69 Mon Sep 17 00:00:00 2001 From: Yifei Feng Date: Wed, 21 Feb 2018 15:07:05 -0800 Subject: [PATCH] Don't assign device for the keras part of _saved_first_checkpoint. Fix #14504. PiperOrigin-RevId: 186526175 --- tensorflow/python/keras/_impl/keras/estimator.py | 24 +++++++++---------- .../python/keras/_impl/keras/estimator_test.py | 27 +++++++++++++++++++++- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/keras/_impl/keras/estimator.py index db0140c..0bf5bd4 100644 --- a/tensorflow/python/keras/_impl/keras/estimator.py +++ b/tensorflow/python/keras/_impl/keras/estimator.py @@ -222,18 +222,18 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects, Returns: The model_fn for a keras Estimator. """ - with ops.Graph().as_default() as g, g.device(estimator._device_fn): - random_seed.set_random_seed(estimator.config.tf_random_seed) - training_util.create_global_step() - model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model, - custom_objects) - - if isinstance(model, models.Sequential): - model = model.model - # Load weights and save to checkpoint if there is no checkpoint - latest_path = saver_lib.latest_checkpoint(estimator.model_dir) - if not latest_path: - with session.Session() as sess: + # Load weights and save to checkpoint if there is no checkpoint + latest_path = saver_lib.latest_checkpoint(estimator.model_dir) + if not latest_path: + with ops.Graph().as_default(): + random_seed.set_random_seed(estimator.config.tf_random_seed) + training_util.create_global_step() + model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model, + custom_objects) + if isinstance(model, models.Sequential): + model = model.model + # save to checkpoint + with session.Session(config=estimator._session_config) as sess: model.set_weights(keras_weights) # Make update ops and initialize all variables. if not model.train_function: diff --git a/tensorflow/python/keras/_impl/keras/estimator_test.py b/tensorflow/python/keras/_impl/keras/estimator_test.py index 9fc48b4..88dd14b 100644 --- a/tensorflow/python/keras/_impl/keras/estimator_test.py +++ b/tensorflow/python/keras/_impl/keras/estimator_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import json from math import log10 import os import tempfile @@ -62,7 +63,7 @@ def simple_functional_model(): return model -def get_resource_for_simple_model(is_sequential, is_evaluate): +def get_resource_for_simple_model(is_sequential=True, is_evaluate=False): model = simple_sequential_model( ) if is_sequential else simple_functional_model() if is_sequential: @@ -352,6 +353,30 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): model_dir=tempfile.mkdtemp(dir=self._base_dir), custom_objects=custom_objects) + def test_tf_config(self): + keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer='rmsprop', + metrics=['mse', keras.metrics.categorical_accuracy]) + + tf_config = json.dumps({ + 'cluster': { + run_config_lib.TaskType.PS: ['localhost:1234'], + run_config_lib.TaskType.WORKER: ['localhost:1236'], + run_config_lib.TaskType.MASTER: ['localhost:1238'] + }, + 'task': { + 'type': run_config_lib.TaskType.MASTER, + 'index': 0 + } + }) + with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): + with self.test_session(): + keras.estimator.model_to_estimator( + keras_model=keras_model, + model_dir=tempfile.mkdtemp(dir=self._base_dir)) + if __name__ == '__main__': test.main() -- 2.7.4