class RunConfig(run_config_lib.RunConfig):
"""RunConfig with TPU support."""
- def __init__(self, tpu_config=None, evaluation_master=None, master='',
+ def __init__(self,
+ tpu_config=None,
+ evaluation_master=None,
+ master=None,
**kwargs):
"""Constructs a RunConfig.
"""
super(RunConfig, self).__init__(**kwargs)
self._tpu_config = tpu_config or TPUConfig()
- if evaluation_master is None:
- self._evaluation_master = master
- else:
+
+ # If user sets master and/or evaluation_master explicilty, including empty
+ # string '', take it. Otherwise, take the values set by parent class.
+ if master is not None:
+ self._master = master
+
+ if evaluation_master is not None:
self._evaluation_master = evaluation_master
- self._master = master
+ elif (not self._evaluation_master and
+ self.task_type != run_config_lib.TaskType.EVALUATOR):
+ # If the task type is EVALUATOR, it means some cluster manager sets the
+ # TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
+ #
+ # Otherwise, it means user executes the code without external cluster
+ # manager. For that, we optimize the user experience by setting
+ # evaluation_master to master, unless user overwrites it.
+ self._evaluation_master = self._master
@property
def evaluation_master(self):
import json
from tensorflow.contrib.tpu.python.tpu import tpu_config as tpu_config_lib
+from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import test
tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0))
+class TPURunConfigMasterTest(test.TestCase):
+
+ def test_default_values(self):
+ run_config = tpu_config_lib.RunConfig()
+ self.assertEqual('', run_config.master)
+ self.assertEqual('', run_config.evaluation_master)
+
+ def test_user_provided_master_and_evaluation_master(self):
+ run_config = tpu_config_lib.RunConfig(
+ master='_master_123', evaluation_master='_eval_master_123')
+ self.assertEqual('_master_123', run_config.master)
+ self.assertEqual('_eval_master_123', run_config.evaluation_master)
+
+ def test_evaluation_master_defaults_to_master(self):
+ run_config = tpu_config_lib.RunConfig(master='_master_123')
+ self.assertEqual('_master_123', run_config.master)
+ self.assertEqual('_master_123', run_config.evaluation_master)
+
+ def test_tf_config(self):
+ tf_config = {
+ 'session_master': '_master_123',
+ 'eval_session_master': '_eval_master_123'
+ }
+ with _set_tf_config_env_variable(tf_config):
+ run_config = tpu_config_lib.RunConfig()
+ self.assertEqual('_master_123', run_config.master)
+ self.assertEqual('_eval_master_123', run_config.evaluation_master)
+
+ def test_evaluation_master_defaults_to_master_in_tf_config(self):
+ tf_config = {
+ 'session_master': '_master_123',
+ }
+ with _set_tf_config_env_variable(tf_config):
+ run_config = tpu_config_lib.RunConfig()
+ self.assertEqual('_master_123', run_config.master)
+ self.assertEqual('_master_123', run_config.evaluation_master)
+
+ def test_respect_evaluation_master_in_tf_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host0:0'],
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.EVALUATOR,
+ 'index': 0
+ },
+ }
+ with _set_tf_config_env_variable(tf_config):
+ run_config = tpu_config_lib.RunConfig(master='_something')
+ self.assertEqual('', run_config.evaluation_master)
+
+ def test_user_overwrites_tf_config(self):
+ tf_config = {
+ 'session_master': '_master_123',
+ 'eval_session_master': '_eval_master_123'
+ }
+ with _set_tf_config_env_variable(tf_config):
+ run_config = tpu_config_lib.RunConfig(
+ master='_new_master_123', evaluation_master='_new_eval_master_123')
+ self.assertEqual('_new_master_123', run_config.master)
+ self.assertEqual('_new_eval_master_123', run_config.evaluation_master)
+
+ def test_user_overwrites_master_in_tf_config(self):
+ tf_config = {
+ 'session_master': '_master_123',
+ 'eval_session_master': '_eval_master_123'
+ }
+ with _set_tf_config_env_variable(tf_config):
+ run_config = tpu_config_lib.RunConfig(master='_new_master_123')
+ self.assertEqual('_new_master_123', run_config.master)
+ self.assertEqual('_eval_master_123', run_config.evaluation_master)
+
+
class TPUJobNameTest(test.TestCase):
def test_default_name(self):