Adjusts the TPU RunConfig to respect parent class master/evaluation master.
author Jianwei Xie <xiejw@google.com>
Wed, 3 Jan 2018 00:19:27 +0000 (16:19 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 3 Jan 2018 00:23:33 +0000 (16:23 -0800)
PiperOrigin-RevId: 180607766

tensorflow/contrib/tpu/python/tpu/tpu_config.py
tensorflow/contrib/tpu/python/tpu/tpu_config_test.py

index 1b6ce2dfdf09a046038b293810115b3e6e2f05ab..0c2580211ab7674d841ca1953c9327df9488bb8e 100644 (file)
@@ -99,7 +99,10 @@ class TPUConfig(
 class RunConfig(run_config_lib.RunConfig):
   """RunConfig with TPU support."""
 
-  def __init__(self, tpu_config=None, evaluation_master=None, master='',
+  def __init__(self,
+               tpu_config=None,
+               evaluation_master=None,
+               master=None,
                **kwargs):
     """Constructs a RunConfig.
 
@@ -113,11 +116,23 @@ class RunConfig(run_config_lib.RunConfig):
     """
     super(RunConfig, self).__init__(**kwargs)
     self._tpu_config = tpu_config or TPUConfig()
-    if evaluation_master is None:
-      self._evaluation_master = master
-    else:
+
+    # If user sets master and/or evaluation_master explicitly, including empty
+    # string '', take it. Otherwise, take the values set by parent class.
+    if master is not None:
+      self._master = master
+
+    if evaluation_master is not None:
       self._evaluation_master = evaluation_master
-    self._master = master
+    elif (not self._evaluation_master and
+          self.task_type != run_config_lib.TaskType.EVALUATOR):
+      # If the task type is EVALUATOR, it means some cluster manager sets the
+      # TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
+      #
+      # Otherwise, it means user executes the code without external cluster
+      # manager. For that, we optimize the user experience by setting
+      # evaluation_master to master, unless user overwrites it.
+      self._evaluation_master = self._master
 
   @property
   def evaluation_master(self):
index 618f2636184ac05e69376ddc581dde086aacfbe1..60884aa32f932413b49ea2193a145828489ea04c 100644 (file)
@@ -21,6 +21,7 @@ from __future__ import print_function
 import json
 
 from tensorflow.contrib.tpu.python.tpu import tpu_config as tpu_config_lib
+from tensorflow.python.estimator import run_config as run_config_lib
 from tensorflow.python.platform import test
 
 
@@ -43,6 +44,79 @@ class TPURunConfigTest(test.TestCase):
           tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0))
 
 
+class TPURunConfigMasterTest(test.TestCase):
+
+  def test_default_values(self):
+    run_config = tpu_config_lib.RunConfig()
+    self.assertEqual('', run_config.master)
+    self.assertEqual('', run_config.evaluation_master)
+
+  def test_user_provided_master_and_evaluation_master(self):
+    run_config = tpu_config_lib.RunConfig(
+        master='_master_123', evaluation_master='_eval_master_123')
+    self.assertEqual('_master_123', run_config.master)
+    self.assertEqual('_eval_master_123', run_config.evaluation_master)
+
+  def test_evaluation_master_defaults_to_master(self):
+    run_config = tpu_config_lib.RunConfig(master='_master_123')
+    self.assertEqual('_master_123', run_config.master)
+    self.assertEqual('_master_123', run_config.evaluation_master)
+
+  def test_tf_config(self):
+    tf_config = {
+        'session_master': '_master_123',
+        'eval_session_master': '_eval_master_123'
+    }
+    with _set_tf_config_env_variable(tf_config):
+      run_config = tpu_config_lib.RunConfig()
+      self.assertEqual('_master_123', run_config.master)
+      self.assertEqual('_eval_master_123', run_config.evaluation_master)
+
+  def test_evaluation_master_defaults_to_master_in_tf_config(self):
+    tf_config = {
+        'session_master': '_master_123',
+    }
+    with _set_tf_config_env_variable(tf_config):
+      run_config = tpu_config_lib.RunConfig()
+      self.assertEqual('_master_123', run_config.master)
+      self.assertEqual('_master_123', run_config.evaluation_master)
+
+  def test_respect_evaluation_master_in_tf_config(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host0:0'],
+        },
+        'task': {
+            'type': run_config_lib.TaskType.EVALUATOR,
+            'index': 0
+        },
+    }
+    with _set_tf_config_env_variable(tf_config):
+      run_config = tpu_config_lib.RunConfig(master='_something')
+      self.assertEqual('', run_config.evaluation_master)
+
+  def test_user_overwrites_tf_config(self):
+    tf_config = {
+        'session_master': '_master_123',
+        'eval_session_master': '_eval_master_123'
+    }
+    with _set_tf_config_env_variable(tf_config):
+      run_config = tpu_config_lib.RunConfig(
+          master='_new_master_123', evaluation_master='_new_eval_master_123')
+      self.assertEqual('_new_master_123', run_config.master)
+      self.assertEqual('_new_eval_master_123', run_config.evaluation_master)
+
+  def test_user_overwrites_master_in_tf_config(self):
+    tf_config = {
+        'session_master': '_master_123',
+        'eval_session_master': '_eval_master_123'
+    }
+    with _set_tf_config_env_variable(tf_config):
+      run_config = tpu_config_lib.RunConfig(master='_new_master_123')
+      self.assertEqual('_new_master_123', run_config.master)
+      self.assertEqual('_eval_master_123', run_config.evaluation_master)
+
+
 class TPUJobNameTest(test.TestCase):
 
   def test_default_name(self):