adding ps_strategy to run_config to enable different placement strateā€¦ (#15640)
authorSiu Kei, Muk <muksiukei@gmail.com>
Mon, 16 Apr 2018 02:23:20 +0000 (10:23 +0800)
committerJonathan Hseu <vomjom@vomjom.net>
Mon, 16 Apr 2018 02:23:20 +0000 (19:23 -0700)
* adding ps_strategy to run_config to enable different placement strategy in estimator

* 1. Moved estimator._device_fn to RunConfig as @property
2. Made RunConfig.device_fn to return custom device function if one is specified, otherwise the result from `tf.train.replica_device_setter` call is used
3. Added some basic unit tests, may need further tests.

* 1. Removing ps_strategy.
2. Modified estimator to take overriden device_fn from  if set.
3. Removed ps_strategy related unit tests.

* Adding manual initialization of _device_fn in legacy RunConfig class

* Updated estimator golden API through
1. bazel build //tensorflow/tools/api/tests:api_compatibility_test
2. bazel-bin/tensorflow/tools/api/tests/api_compatibility_test --update_goldens True

* fixing code styles

tensorflow/contrib/learn/python/learn/estimators/run_config.py
tensorflow/python/estimator/estimator.py
tensorflow/python/estimator/run_config.py
tensorflow/python/estimator/run_config_test.py
tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt

index 8c85c43..14ee2ba 100644 (file)
@@ -299,6 +299,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
     # so instead of breaking compatibility with that assumption, we
     # just manually initialize this field:
     self._train_distribute = None
+    self._device_fn = None
 
     gpu_options = config_pb2.GPUOptions(
         per_process_gpu_memory_fraction=gpu_memory_fraction)
index 8890f74..901f047 100644 (file)
@@ -216,7 +216,8 @@ class Estimator(object):
     else:
       self._session_config = self._config.session_config
 
-    self._device_fn = _get_replica_device_setter(self._config)
+    self._device_fn = self._config.device_fn or \
+                      _get_replica_device_setter(self._config)
 
     if model_fn is None:
       raise ValueError('model_fn must be provided to Estimator.')
index dab442a..8162b24 100644 (file)
@@ -27,11 +27,13 @@ import six
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import server_lib
+from tensorflow.python.estimator import util
 from tensorflow.python.util import compat_internal
 from tensorflow.python.util.tf_export import tf_export
 
 
 _USE_DEFAULT = object()
+_VALID_DEVICE_FN_ARGS = set(['op'])
 
 # A list of the property names in RunConfig that the user is allowed to change.
 _DEFAULT_REPLACEABLE_LIST = [
@@ -44,7 +46,8 @@ _DEFAULT_REPLACEABLE_LIST = [
     'keep_checkpoint_max',
     'keep_checkpoint_every_n_hours',
     'log_step_count_steps',
-    'train_distribute'
+    'train_distribute',
+    'device_fn'
 ]
 
 _SAVE_CKPT_ERR = (
@@ -279,6 +282,11 @@ def _validate_properties(run_config):
   _validate('tf_random_seed', lambda seed: isinstance(seed, six.integer_types),
             message='tf_random_seed must be integer.')
 
+  _validate('device_fn', lambda device_fn: six.callable(device_fn) and
+            set(util.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
+            message='device_fn must be callable with exactly'
+                    ' one argument "op".')
+
 
 class TaskType(object):
   MASTER = 'master'
@@ -302,7 +310,8 @@ class RunConfig(object):
                keep_checkpoint_max=5,
                keep_checkpoint_every_n_hours=10000,
                log_step_count_steps=100,
-               train_distribute=None):
+               train_distribute=None,
+               device_fn=None):
     """Constructs a RunConfig.
 
     All distributed training related properties `cluster_spec`, `is_chief`,
@@ -430,6 +439,10 @@ class RunConfig(object):
         `tf.contrib.distribute.DistributionStrategy`. If specified,
         then Estimator will distribute the user's model during training,
         according to the policy specified by that strategy.
+      device_fn: A callable invoked for every `Operation` that takes the
+        `Operation` and returns the device string. If `None`, defaults to
+        the device function returned by `tf.train.replica_device_setter`
+        with round-robin strategy.
 
     Raises:
       ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
@@ -466,7 +479,8 @@ class RunConfig(object):
         keep_checkpoint_max=keep_checkpoint_max,
         keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
         log_step_count_steps=log_step_count_steps,
-        train_distribute=train_distribute)
+        train_distribute=train_distribute,
+        device_fn=device_fn)
 
     self._init_distributed_setting_from_environment_var(tf_config)
 
@@ -569,6 +583,16 @@ class RunConfig(object):
     return self._cluster_spec
 
   @property
+  def device_fn(self):
+    """Returns the device_fn.
+
+    If device_fn is not `None`, it overrides the default
+    device function used in `Estimator`.
+    Otherwise the default one is used.
+    """
+    return self._device_fn
+
+  @property
   def evaluation_master(self):
     return self._evaluation_master
 
@@ -697,7 +721,8 @@ class RunConfig(object):
       - `keep_checkpoint_max`,
       - `keep_checkpoint_every_n_hours`,
       - `log_step_count_steps`,
-      - `train_distribute`.
+      - `train_distribute`,
+      - `device_fn`.
 
     In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`
     can be set (should not be both).
index a3eef4c..c8b1260 100644 (file)
@@ -42,6 +42,7 @@ _SESSION_CONFIG_ERR = 'session_config must be instance of ConfigProto'
 _KEEP_CKPT_MAX_ERR = 'keep_checkpoint_max should be >= 0'
 _KEEP_CKPT_HOURS_ERR = 'keep_checkpoint_every_n_hours should be > 0'
 _TF_RANDOM_SEED_ERR = 'tf_random_seed must be integer'
+_DEVICE_FN_ERR = 'device_fn must be callable with exactly one argument "op".'
 _ONE_CHIEF_ERR = 'The "cluster" in TF_CONFIG must have only one "chief" node.'
 _ONE_MASTER_ERR = 'The "cluster" in TF_CONFIG must have only one "master" node.'
 _INVALID_TASK_TYPE_FOR_EVAL_MASTER = (
@@ -83,6 +84,7 @@ class RunConfigTest(test.TestCase):
     self.assertEqual(5, config.keep_checkpoint_max)
     self.assertEqual(10000, config.keep_checkpoint_every_n_hours)
     self.assertIsNone(config.service)
+    self.assertIsNone(config.device_fn)
 
   def test_model_dir(self):
     empty_config = run_config_lib.RunConfig()
@@ -93,6 +95,7 @@ class RunConfigTest(test.TestCase):
 
   def test_replace_with_allowed_properties(self):
     session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+    device_fn = lambda op: "/cpu:0"
 
     config = run_config_lib.RunConfig().replace(
         tf_random_seed=11,
@@ -100,13 +103,15 @@ class RunConfigTest(test.TestCase):
         save_checkpoints_secs=14,
         session_config=session_config,
         keep_checkpoint_max=16,
-        keep_checkpoint_every_n_hours=17)
+        keep_checkpoint_every_n_hours=17,
+        device_fn=device_fn)
     self.assertEqual(11, config.tf_random_seed)
     self.assertEqual(12, config.save_summary_steps)
     self.assertEqual(14, config.save_checkpoints_secs)
     self.assertEqual(session_config, config.session_config)
     self.assertEqual(16, config.keep_checkpoint_max)
     self.assertEqual(17, config.keep_checkpoint_every_n_hours)
+    self.assertEqual(device_fn, config.device_fn)
 
   def test_replace_none_value(self):
     config = run_config_lib.RunConfig().replace(
@@ -117,7 +122,8 @@ class RunConfigTest(test.TestCase):
         save_checkpoints_steps=None,
         session_config=None,
         keep_checkpoint_max=None,
-        keep_checkpoint_every_n_hours=None)
+        keep_checkpoint_every_n_hours=None,
+        device_fn=None)
     self.assertIsNone(config.tf_random_seed)
     self.assertIsNone(config.model_dir)
     self.assertIsNone(config.save_summary_steps)
@@ -126,6 +132,7 @@ class RunConfigTest(test.TestCase):
     self.assertIsNone(config.session_config)
     self.assertIsNone(config.keep_checkpoint_max)
     self.assertIsNone(config.keep_checkpoint_every_n_hours)
+    self.assertIsNone(config.device_fn)
 
   def test_replace_with_disallowallowed_properties(self):
     config = run_config_lib.RunConfig()
@@ -166,9 +173,12 @@ class RunConfigTest(test.TestCase):
       config.replace(keep_checkpoint_every_n_hours=0)
     with self.assertRaisesRegexp(ValueError, _TF_RANDOM_SEED_ERR):
       config.replace(tf_random_seed=1.0)
+    with self.assertRaisesRegexp(ValueError, _DEVICE_FN_ERR):
+      config.replace(device_fn=lambda x, y: 0)
 
   def test_init_with_allowed_properties(self):
     session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+    device_fn = lambda op: "/cpu:0"
 
     config = run_config_lib.RunConfig(
         tf_random_seed=11,
@@ -176,13 +186,15 @@ class RunConfigTest(test.TestCase):
         save_checkpoints_secs=14,
         session_config=session_config,
         keep_checkpoint_max=16,
-        keep_checkpoint_every_n_hours=17)
+        keep_checkpoint_every_n_hours=17,
+        device_fn=device_fn)
     self.assertEqual(11, config.tf_random_seed)
     self.assertEqual(12, config.save_summary_steps)
     self.assertEqual(14, config.save_checkpoints_secs)
     self.assertEqual(session_config, config.session_config)
     self.assertEqual(16, config.keep_checkpoint_max)
     self.assertEqual(17, config.keep_checkpoint_every_n_hours)
+    self.assertEqual(device_fn, config.device_fn)
 
   def test_init_none_value(self):
     config = run_config_lib.RunConfig(
@@ -193,7 +205,8 @@ class RunConfigTest(test.TestCase):
         save_checkpoints_steps=None,
         session_config=None,
         keep_checkpoint_max=None,
-        keep_checkpoint_every_n_hours=None)
+        keep_checkpoint_every_n_hours=None,
+        device_fn=None)
     self.assertIsNone(config.tf_random_seed)
     self.assertIsNone(config.model_dir)
     self.assertIsNone(config.save_summary_steps)
@@ -202,6 +215,7 @@ class RunConfigTest(test.TestCase):
     self.assertIsNone(config.session_config)
     self.assertIsNone(config.keep_checkpoint_max)
     self.assertIsNone(config.keep_checkpoint_every_n_hours)
+    self.assertIsNone(config.device_fn)
 
   def test_init_invalid_values(self):
     with self.assertRaisesRegexp(ValueError, _MODEL_DIR_ERR):
@@ -220,6 +234,8 @@ class RunConfigTest(test.TestCase):
       run_config_lib.RunConfig(keep_checkpoint_every_n_hours=0)
     with self.assertRaisesRegexp(ValueError, _TF_RANDOM_SEED_ERR):
       run_config_lib.RunConfig(tf_random_seed=1.0)
+    with self.assertRaisesRegexp(ValueError, _DEVICE_FN_ERR):
+      run_config_lib.RunConfig(device_fn=lambda x: "/cpu:0")
 
 
 class RunConfigDistributedSettingTest(test.TestCase):
index 05e603e..c8da55d 100644 (file)
@@ -7,6 +7,10 @@ tf_class {
     mtype: "<type \'property\'>"
   }
   member {
+    name: "device_fn"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "evaluation_master"
     mtype: "<type \'property\'>"
   }
@@ -84,7 +88,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\'], "
+    argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\'], "
   }
   member_method {
     name: "replace"