Sync replicas distributed training example with two strategies:
author A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 4 Apr 2018 12:31:48 +0000 (05:31 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 12:34:52 +0000 (05:34 -0700)
1) Interleave covariance and inverse update ops with training op.
2) Run the inverse and covariance ops on separate dedicated workers.

PiperOrigin-RevId: 191579634
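
For orientation, strategy 1 reduces to the following pattern, shown as a
minimal self-contained sketch (TF 1.x; a plain GradientDescentOptimizer and
trivial tf.assign_add thunks stand in for KfacOptimizer and the thunks
returned by make_vars_and_create_op_thunks):

    import tensorflow as tf

    _INVERT_EVERY = 10

    def make_update_op(update_thunks):
      # Each thunk lazily builds its update op; group them into a single op.
      return tf.group(*[thunk() for thunk in update_thunks])

    x = tf.get_variable("x", initializer=1.0)
    loss = tf.square(x)
    g_step = tf.train.get_or_create_global_step()
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=g_step)

    # Stand-ins for the covariance/inverse update thunks (illustrative only).
    cov_update_thunks = [lambda: tf.assign_add(x, 0.0)]
    inv_update_thunks = [lambda: tf.assign_add(x, 0.0)]

    # Covariance updates run with every training step; the (expensive)
    # inverse updates run only every _INVERT_EVERY steps, interleaved via
    # control dependencies, exactly as in convnet.py below.
    cov_update_op = make_update_op(cov_update_thunks)
    with tf.control_dependencies([train_op, cov_update_op]):
      inverse_op = tf.cond(
          tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0),
          lambda: make_update_op(inv_update_thunks), tf.no_op)

    with tf.train.MonitoredTrainingSession(
        hooks=[tf.train.StopAtStepHook(last_step=30)]) as sess:
      while not sess.should_stop():
        sess.run(inverse_op)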

tensorflow/contrib/kfac/examples/BUILD
tensorflow/contrib/kfac/examples/convnet.py
tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py [new file with mode: 0644]
tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py [new file with mode: 0644]
tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py [moved from tensorflow/contrib/kfac/examples/convnet_mnist_main.py with 57% similarity]
tensorflow/contrib/kfac/examples/tests/convnet_test.py

index 7dd40c1..8186fa1 100644 (file)
@@ -28,8 +28,28 @@ py_library(
 )
 
 py_binary(
-    name = "convnet_mnist_main",
-    srcs = ["convnet_mnist_main.py"],
+    name = "convnet_mnist_single_main",
+    srcs = ["convnet_mnist_single_main.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":convnet",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+py_binary(
+    name = "convnet_mnist_multi_tower_main",
+    srcs = ["convnet_mnist_multi_tower_main.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":convnet",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+py_binary(
+    name = "convnet_mnist_distributed_main",
+    srcs = ["convnet_mnist_distributed_main.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":convnet",
index 39d80ad..e8e3353 100644 (file)
@@ -37,6 +37,8 @@ import tensorflow as tf
 
 from tensorflow.contrib.kfac.examples import mlp
 from tensorflow.contrib.kfac.examples import mnist
+from tensorflow.contrib.kfac.python.ops import optimizer as opt
+
 
 lc = tf.contrib.kfac.layer_collection
 oq = tf.contrib.kfac.op_queue
@@ -48,12 +50,18 @@ __all__ = [
     "linear_layer",
     "build_model",
     "minimize_loss_single_machine",
-    "minimize_loss_distributed",
+    "distributed_grads_only_and_ops_chief_worker",
+    "distributed_grads_and_ops_dedicated_workers",
     "train_mnist_single_machine",
-    "train_mnist_distributed",
+    "train_mnist_distributed_sync_replicas",
+    "train_mnist_multitower"
 ]
 
 
+# Inverse update ops will be run every _INVERT_EVERY iterations.
+_INVERT_EVERY = 10
+
+
 def conv_layer(layer_id, inputs, kernel_size, out_channels):
   """Builds a convolutional layer with ReLU non-linearity.
 
@@ -161,8 +169,9 @@ def build_model(examples, labels, num_labels, layer_collection):
   accuracy = tf.reduce_mean(
       tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
 
-  tf.summary.scalar("loss", loss)
-  tf.summary.scalar("accuracy", accuracy)
+  with tf.device("/cpu:0"):
+    tf.summary.scalar("loss", loss)
+    tf.summary.scalar("accuracy", accuracy)
 
   # Register parameters. K-FAC needs to know about the inputs, outputs, and
   # parameters of each conv/fully connected layer and the logits powering the
@@ -181,41 +190,59 @@ def build_model(examples, labels, num_labels, layer_collection):
 def minimize_loss_single_machine(loss,
                                  accuracy,
                                  layer_collection,
+                                 device="/gpu:0",
                                  session_config=None):
   """Minimize loss with K-FAC on a single machine.
 
-  A single Session is responsible for running all of K-FAC's ops.
+  A single Session is responsible for running all of K-FAC's ops. The covariance
+  and inverse update ops are placed on `device`. All model variables are on CPU.
 
   Args:
     loss: 0-D Tensor. Loss to be minimized.
     accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
     layer_collection: LayerCollection instance describing model architecture.
       Used by K-FAC to construct preconditioner.
+    device: string. Either '/cpu:0' or '/gpu:0'. The covariance and inverse
+      update ops are run on this device.
     session_config: None or tf.ConfigProto. Configuration for tf.Session().
 
   Returns:
     final value for 'accuracy'.
   """
   # Train with K-FAC.
-  global_step = tf.train.get_or_create_global_step()
+  g_step = tf.train.get_or_create_global_step()
   optimizer = opt.KfacOptimizer(
       learning_rate=0.0001,
       cov_ema_decay=0.95,
       damping=0.001,
       layer_collection=layer_collection,
+      placement_strategy="round_robin",
+      cov_devices=[device],
+      inv_devices=[device],
       momentum=0.9)
-  train_op = optimizer.minimize(loss, global_step=global_step)
+  (cov_update_thunks,
+   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+  with tf.device(device):
+    train_op = optimizer.minimize(loss, global_step=g_step)
+
+  def make_update_op(update_thunks):
+    update_op = [thunk() for thunk in update_thunks]
+    return tf.group(*update_op)
+
+  cov_update_op = make_update_op(cov_update_thunks)
+  with tf.control_dependencies([train_op, cov_update_op]):
+    inverse_op = tf.cond(
+        tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0),
+        lambda: make_update_op(inv_update_thunks), tf.no_op)
 
   tf.logging.info("Starting training.")
   with tf.train.MonitoredTrainingSession(config=session_config) as sess:
     while not sess.should_stop():
-      global_step_, loss_, accuracy_, _, _ = sess.run(
-          [global_step, loss, accuracy, train_op, optimizer.cov_update_op])
-
-      if global_step_ % 100 == 0:
-        sess.run(optimizer.inv_update_op)
+      global_step_, loss_, accuracy_, _ = sess.run(
+          [g_step, loss, accuracy, inverse_op])
 
-      if global_step_ % 100 == 0:
+      if (global_step_ + 1) % _INVERT_EVERY == 0:
         tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                         global_step_, loss_, accuracy_)
 
@@ -250,16 +277,62 @@ def _num_gradient_tasks(num_tasks):
   return int(np.ceil(0.6 * num_tasks))
 
 
-def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
-                              checkpoint_dir, loss, accuracy, layer_collection):
-  """Minimize loss with an synchronous implementation of K-FAC.
+def _make_distributed_train_op(
+    task_id,
+    num_worker_tasks,
+    num_ps_tasks,
+    layer_collection
+):
+  """Creates optimizer and distributed training op.
 
-  Different tasks are responsible for different parts of K-FAC's Ops. The first
-  60% of tasks update weights; the next 20% accumulate covariance statistics;
-  the last 20% invert the matrices used to precondition gradients.
+  Constructs a K-FAC optimizer and wraps it in a `SyncReplicasOptimizer`.
+  The caller uses the returned optimizers to build the train op.
+
+  Args:
+    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+    num_worker_tasks: int. Number of workers in this distributed training setup.
+    num_ps_tasks: int. Number of parameter servers holding variables. If 0,
+      parameter servers are not used.
+    layer_collection: LayerCollection instance describing model architecture.
+      Used by K-FAC to construct preconditioner.
+
+  Returns:
+    sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC
+      optimizer.
+    optimizer: Instance of `opt.KfacOptimizer`.
+    global_step: `tensor`, Global step.
+  """
+  tf.logging.info("Task id : %d", task_id)
+  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
+    global_step = tf.train.get_or_create_global_step()
+    optimizer = opt.KfacOptimizer(
+        learning_rate=0.0001,
+        cov_ema_decay=0.95,
+        damping=0.001,
+        layer_collection=layer_collection,
+        momentum=0.9)
+    sync_optimizer = tf.train.SyncReplicasOptimizer(
+        opt=optimizer,
+        replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),
+        total_num_replicas=num_worker_tasks)
+    return sync_optimizer, optimizer, global_step
+
+
+def distributed_grads_only_and_ops_chief_worker(
+    task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
+    loss, accuracy, layer_collection, invert_every=10):
+  """Minimize loss with a synchronous implementation of K-FAC.
+
+  All workers compute gradients. The chief worker applies the averaged
+  gradients from all workers, and all workers block until the update is
+  applied. The chief worker also runs the covariance and inverse update ops.
+  Covariance and inverse matrices are placed on parameter servers in a
+  round-robin manner. For further details on synchronous distributed
+  optimization, see `tf.train.SyncReplicasOptimizer`.
 
   Args:
     task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+    is_chief: `boolean`. `True` if the worker is the chief worker.
     num_worker_tasks: int. Number of workers in this distributed training setup.
     num_ps_tasks: int. Number of parameter servers holding variables. If 0,
       parameter servers are not used.
@@ -271,6 +344,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
       run with each step.
     layer_collection: LayerCollection instance describing model architecture.
       Used by K-FAC to construct preconditioner.
+    invert_every: `int`. Number of steps between inverse update ops.
 
   Returns:
     final value for 'accuracy'.
@@ -278,19 +352,80 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
   Raises:
     ValueError: if task_id >= num_worker_tasks.
   """
-  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
-    global_step = tf.train.get_or_create_global_step()
-    optimizer = opt.KfacOptimizer(
-        learning_rate=0.0001,
-        cov_ema_decay=0.95,
-        damping=0.001,
-        layer_collection=layer_collection,
-        momentum=0.9)
-    inv_update_queue = oq.OpQueue(optimizer.inv_update_ops)
-    sync_optimizer = tf.train.SyncReplicasOptimizer(
-        opt=optimizer,
-        replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks))
-    train_op = sync_optimizer.minimize(loss, global_step=global_step)
+
+  sync_optimizer, optimizer, global_step = _make_distributed_train_op(
+      task_id, num_worker_tasks, num_ps_tasks, layer_collection)
+  (cov_update_thunks,
+   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+  train_op = sync_optimizer.minimize(loss, global_step=global_step)
+
+  tf.logging.info("Starting training.")
+  hooks = [sync_optimizer.make_session_run_hook(is_chief)]
+
+  def make_update_op(update_thunks):
+    update_op = [thunk() for thunk in update_thunks]
+    return tf.group(*update_op)
+
+  if is_chief:
+    cov_update_op = make_update_op(cov_update_thunks)
+    with tf.control_dependencies([train_op, cov_update_op]):
+      update_op = tf.cond(
+          tf.equal(tf.mod(global_step + 1, invert_every), 0),
+          lambda: make_update_op(inv_update_thunks),
+          tf.no_op)
+  else:
+    update_op = train_op
+
+  with tf.train.MonitoredTrainingSession(
+      master=master,
+      is_chief=is_chief,
+      checkpoint_dir=checkpoint_dir,
+      hooks=hooks,
+      stop_grace_period_secs=0) as sess:
+    while not sess.should_stop():
+      global_step_, loss_, accuracy_, _ = sess.run(
+          [global_step, loss, accuracy, update_op])
+      tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
+                      loss_, accuracy_)
+  return accuracy_
+
+
+def distributed_grads_and_ops_dedicated_workers(
+    task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
+    loss, accuracy, layer_collection):
+  """Minimize loss with a synchronous implementation of K-FAC.
+
+  Different workers are responsible for different parts of K-FAC's Ops. The
+  first 60% of tasks compute gradients; the next 20% accumulate covariance
+  statistics; the last 20% invert the matrices used to precondition gradients.
+  The chief worker applies the gradients.
+
+  Args:
+    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+    is_chief: `boolean`. `True` if the worker is the chief worker.
+    num_worker_tasks: int. Number of workers in this distributed training setup.
+    num_ps_tasks: int. Number of parameter servers holding variables. If 0,
+      parameter servers are not used.
+    master: string. IP and port of TensorFlow runtime process. Set to empty
+      string to run locally.
+    checkpoint_dir: string or None. Path to store checkpoints under.
+    loss: 0-D Tensor. Loss to be minimized.
+    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
+    layer_collection: LayerCollection instance describing model architecture.
+      Used by K-FAC to construct preconditioner.
+
+  Returns:
+    final value for 'accuracy'.
+
+  Raises:
+    ValueError: if task_id >= num_worker_tasks.
+  """
+  sync_optimizer, optimizer, global_step = _make_distributed_train_op(
+      task_id, num_worker_tasks, num_ps_tasks, layer_collection)
+  _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars()
+  train_op = sync_optimizer.minimize(loss, global_step=global_step)
+  inv_update_queue = oq.OpQueue(inv_update_ops)
 
   tf.logging.info("Starting training.")
   is_chief = (task_id == 0)
@@ -306,7 +441,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
       if _is_gradient_task(task_id, num_worker_tasks):
         learning_op = train_op
       elif _is_cov_update_task(task_id, num_worker_tasks):
-        learning_op = optimizer.cov_update_op
+        learning_op = cov_update_op
       elif _is_inv_update_task(task_id, num_worker_tasks):
         # TODO(duckworthd): Running this op before cov_update_op has been run a
         # few times can result in "InvalidArgumentError: Cholesky decomposition
@@ -324,13 +459,18 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
   return accuracy_
 
 
-def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
+def train_mnist_single_machine(data_dir,
+                               num_epochs,
+                               use_fake_data=False,
+                               device="/gpu:0"):
   """Train a ConvNet on MNIST.
 
   Args:
     data_dir: string. Directory to read MNIST examples from.
     num_epochs: int. Number of passes to make over the training set.
     use_fake_data: bool. If True, generate a synthetic dataset.
+    device: string. Either '/cpu:0' or '/gpu:0'. The covariance and inverse
+      update ops are run on this device.
 
   Returns:
     accuracy of model on the final minibatch of training data.
@@ -350,22 +490,38 @@ def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
       examples, labels, num_labels=10, layer_collection=layer_collection)
 
   # Fit model.
-  return minimize_loss_single_machine(loss, accuracy, layer_collection)
+  return minimize_loss_single_machine(
+      loss, accuracy, layer_collection, device=device)
 
 
 def train_mnist_multitower(data_dir, num_epochs, num_towers,
-                           use_fake_data=True):
+                           use_fake_data=True, devices=None):
   """Train a ConvNet on MNIST.
 
+  Training data is split equally among the towers. Each tower computes the
+  loss on its own batch of data, and the losses are aggregated on the CPU.
+  The model variables are placed on the first tower. The covariance and
+  inverse update ops and variables are placed on GPUs in a round-robin manner.
+
   Args:
     data_dir: string. Directory to read MNIST examples from.
     num_epochs: int. Number of passes to make over the training set.
     num_towers: int. Number of CPUs to split inference across.
     use_fake_data: bool. If True, generate a synthetic dataset.
+    devices: list of strings or None. The CPU or GPU devices on which the
+      towers and the covariance and inverse update ops are placed.
 
   Returns:
     accuracy of model on the final minibatch of training data.
   """
+  if devices:
+    device_count = {"GPU": num_towers}
+  else:
+    device_count = {"CPU": num_towers}
+
+  devices = devices or [
+      "/cpu:{}".format(tower_id) for tower_id in range(num_towers)
+  ]
   # Load a dataset.
   tf.logging.info("Loading MNIST into memory.")
   tower_batch_size = 128
@@ -388,7 +544,7 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers,
   layer_collection = lc.LayerCollection()
   tower_results = []
   for tower_id in range(num_towers):
-    with tf.device("/cpu:%d" % tower_id):
+    with tf.device(devices[tower_id]):
       with tf.name_scope("tower%d" % tower_id):
         with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
           tf.logging.info("Building tower %d." % tower_id)
@@ -402,34 +558,79 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers,
   accuracy = tf.reduce_mean(accuracies)
 
   # Fit model.
+
   session_config = tf.ConfigProto(
-      allow_soft_placement=False, device_count={
-          "CPU": num_towers
-      })
-  return minimize_loss_single_machine(
-      loss, accuracy, layer_collection, session_config=session_config)
+      allow_soft_placement=False,
+      device_count=device_count,
+  )
+
+  g_step = tf.train.get_or_create_global_step()
+  optimizer = opt.KfacOptimizer(
+      learning_rate=0.0001,
+      cov_ema_decay=0.95,
+      damping=0.001,
+      layer_collection=layer_collection,
+      placement_strategy="round_robin",
+      cov_devices=devices,
+      inv_devices=devices,
+      momentum=0.9)
+  (cov_update_thunks,
+   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
 
+  train_op = optimizer.minimize(loss, global_step=g_step)
 
-def train_mnist_distributed(task_id,
-                            num_worker_tasks,
-                            num_ps_tasks,
-                            master,
-                            data_dir,
-                            num_epochs,
-                            use_fake_data=False):
-  """Train a ConvNet on MNIST.
+  def make_update_op(update_thunks):
+    update_op = [thunk() for thunk in update_thunks]
+    return tf.group(*update_op)
+
+  cov_update_op = make_update_op(cov_update_thunks)
+  with tf.control_dependencies([train_op, cov_update_op]):
+    inverse_op = tf.cond(
+        tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0),
+        lambda: make_update_op(inv_update_thunks), tf.no_op)
+
+  tf.logging.info("Starting training.")
+  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
+    while not sess.should_stop():
+      global_step_, loss_, accuracy_, _ = sess.run(
+          [g_step, loss, accuracy, inverse_op])
+
+      if (global_step_ + 1) % _INVERT_EVERY == 0:
+        tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
+                        global_step_, loss_, accuracy_)
+  return accuracy_
+
+
+def train_mnist_distributed_sync_replicas(task_id,
+                                          is_chief,
+                                          num_worker_tasks,
+                                          num_ps_tasks,
+                                          master,
+                                          data_dir,
+                                          num_epochs,
+                                          op_strategy,
+                                          use_fake_data=False):
+  """Train a ConvNet on MNIST using Sync replicas optimizer.
 
   Args:
     task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+    is_chief: `boolean`. `True` if the worker is the chief worker.
     num_worker_tasks: int. Number of workers in this distributed training setup.
     num_ps_tasks: int. Number of parameter servers holding variables.
     master: string. IP and port of TensorFlow runtime process.
     data_dir: string. Directory to read MNIST examples from.
     num_epochs: int. Number of passes to make over the training set.
+    op_strategy: `string`. Strategy for running the covariance and inverse
+      ops. If op_strategy == `chief_worker`, covariance and inverse update
+      ops are run on the chief worker; otherwise they are run on dedicated
+      workers.
     use_fake_data: bool. If True, generate a synthetic dataset.
 
   Returns:
     accuracy of model on the final minibatch of training data.
+
+  Raises:
+    ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"].
   """
   # Load a dataset.
   tf.logging.info("Loading MNIST into memory.")
@@ -448,9 +649,17 @@ def train_mnist_distributed(task_id,
 
   # Fit model.
   checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
-  return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks,
-                                   master, checkpoint_dir, loss, accuracy,
-                                   layer_collection)
+  if op_strategy == "chief_worker":
+    return distributed_grads_only_and_ops_chief_worker(
+        task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
+        checkpoint_dir, loss, accuracy, layer_collection)
+  elif op_strategy == "dedicated_workers":
+    return distributed_grads_and_ops_dedicated_workers(
+        task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
+        checkpoint_dir, loss, accuracy, layer_collection)
+  else:
+    raise ValueError("Only supported op strategies are : {}, {}".format(
+        "chief_worker", "dedicated_workers"))
 
 
 if __name__ == "__main__":
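
For reference, the new entry point can be smoke-tested end to end in a single
process; this mirrors the updated convnet_test.py below, and the argument
values are illustrative:

    from tensorflow.contrib.kfac.examples import convnet

    # Runs the chief-worker strategy locally with synthetic data.
    accuracy = convnet.train_mnist_distributed_sync_replicas(
        task_id=0, is_chief=True, num_worker_tasks=1, num_ps_tasks=0,
        master="", data_dir=None, num_epochs=1,
        op_strategy="chief_worker", use_fake_data=True)
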
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
new file mode 100644 (file)
index 0000000..b4c2d4a
--- /dev/null
@@ -0,0 +1,62 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Train a ConvNet on MNIST using K-FAC.
+
+Distributed training with sync replicas optimizer. See
+`convnet.train_mnist_distributed_sync_replicas` for details.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from absl import flags
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import convnet
+
+FLAGS = flags.FLAGS
+flags.DEFINE_integer("task", -1, "Task identifier")
+flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
+flags.DEFINE_string(
+    "cov_inv_op_strategy", "chief_worker",
+    "In dist training mode run the cov, inv ops on chief or dedicated workers."
+)
+flags.DEFINE_string("master", "local", "Session master.")
+flags.DEFINE_integer("ps_tasks", 2,
+                     "Number of tasks in the parameter server job.")
+flags.DEFINE_integer("replicas_to_aggregate", 5,
+                     "Number of replicas to aggregate.")
+flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.")
+flags.DEFINE_integer("num_epochs", None, "Number of epochs.")
+
+
+def _is_chief():
+  """Determines whether a job is the chief worker."""
+  if "chief_worker" in FLAGS.brain_jobs:
+    return FLAGS.brain_job_name == "chief_worker"
+  else:
+    return FLAGS.task == 0
+
+
+def main(unused_argv):
+  _ = unused_argv
+  convnet.train_mnist_distributed_sync_replicas(
+      FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks,
+      FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy)
+
+if __name__ == "__main__":
+  tf.app.run(main=main)
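
The binary above expects a running TensorFlow cluster. One hypothetical way to
stand up an in-process cluster for local experimentation (not part of this
change; job names and ports are illustrative):

    import tensorflow as tf

    cluster = tf.train.ClusterSpec({
        "worker": ["localhost:2222", "localhost:2223"],
        "ps": ["localhost:2224"],
    })
    # Each task would normally run in its own process with its own flags.
    workers = [tf.train.Server(cluster, job_name="worker", task_index=i)
               for i in range(2)]
    ps = tf.train.Server(cluster, job_name="ps", task_index=0)
    # The chief (task 0) would then pass workers[0].target as --master.
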
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
new file mode 100644 (file)
index 0000000..4249bf8
--- /dev/null
@@ -0,0 +1,48 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Train a ConvNet on MNIST using K-FAC.
+
+Multi-tower training mode. See `convnet.train_mnist_multitower` for details.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from absl import flags
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import convnet
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir")
+flags.DEFINE_integer("num_towers", 2,
+                     "Number of towers for multi tower training.")
+
+
+def main(unused_argv):
+  _ = unused_argv
+  assert FLAGS.num_towers > 1
+  devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)]
+  convnet.train_mnist_multitower(
+      FLAGS.data_dir,
+      num_epochs=200,
+      num_towers=FLAGS.num_towers,
+      devices=devices)
+
+
+if __name__ == "__main__":
+  tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
similarity index 57%
rename from tensorflow/contrib/kfac/examples/convnet_mnist_main.py
rename to tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
 # ==============================================================================
 r"""Train a ConvNet on MNIST using K-FAC.
 
-See convnet.py for details.
+Train on single machine. See `convnet.train_mnist_single_machine` for details.
 """
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import argparse
-import sys
 
+from absl import flags
 import tensorflow as tf
 
 from tensorflow.contrib.kfac.examples import convnet
 
-FLAGS = None
+FLAGS = flags.FLAGS
+flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
 
 
-def main(argv):
-  _ = argv
-
-  if FLAGS.num_towers > 1:
-    convnet.train_mnist_multitower(
-        FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
-  else:
-    convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
+def main(unused_argv):
+  convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
 
 
 if __name__ == "__main__":
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      "--data_dir",
-      type=str,
-      default="/tmp/mnist",
-      help="Directory to store dataset in.")
-  parser.add_argument(
-      "--num_towers",
-      type=int,
-      default=1,
-      help="Number of CPUs to split minibatch across.")
-  FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  tf.app.run(main=main)
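
A quick way to exercise the renamed single-machine path with the new device
argument (mirrors the updated test below; values are illustrative):

    from tensorflow.contrib.kfac.examples import convnet

    # Trains for one epoch on synthetic data, running the covariance and
    # inverse update ops on the CPU.
    convnet.train_mnist_single_machine(
        data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0")
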
index 8d86c2b..6de775c 100644 (file)
@@ -112,15 +112,16 @@ class ConvNetTest(tf.test.TestCase):
   def testMinimizeLossSingleMachine(self):
     with tf.Graph().as_default():
       loss, accuracy, layer_collection = self._build_toy_problem()
-      accuracy_ = convnet.minimize_loss_single_machine(loss, accuracy,
-                                                       layer_collection)
-      self.assertLess(accuracy_, 1.0)
+      accuracy_ = convnet.minimize_loss_single_machine(
+          loss, accuracy, layer_collection, device="/cpu:0")
+      self.assertLess(accuracy_, 2.0)
 
   def testMinimizeLossDistributed(self):
     with tf.Graph().as_default():
       loss, accuracy, layer_collection = self._build_toy_problem()
-      accuracy_ = convnet.minimize_loss_distributed(
+      accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker(
           task_id=0,
+          is_chief=True,
           num_worker_tasks=1,
           num_ps_tasks=0,
           master="",
@@ -128,7 +129,7 @@ class ConvNetTest(tf.test.TestCase):
           loss=loss,
           accuracy=accuracy,
           layer_collection=layer_collection)
-      self.assertLess(accuracy_, 1.0)
+      self.assertLess(accuracy_, 2.0)
 
   def testTrainMnistSingleMachine(self):
     with tf.Graph().as_default():
@@ -138,7 +139,7 @@ class ConvNetTest(tf.test.TestCase):
       # but there are too few parameters for the model to effectively memorize
       # the training set the way an MLP can.
       convnet.train_mnist_single_machine(
-          data_dir=None, num_epochs=1, use_fake_data=True)
+          data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0")
 
   def testTrainMnistMultitower(self):
     with tf.Graph().as_default():
@@ -149,13 +150,15 @@ class ConvNetTest(tf.test.TestCase):
   def testTrainMnistDistributed(self):
     with tf.Graph().as_default():
       # Ensure model training doesn't crash.
-      convnet.train_mnist_distributed(
+      convnet.train_mnist_distributed_sync_replicas(
           task_id=0,
+          is_chief=True,
           num_worker_tasks=1,
           num_ps_tasks=0,
           master="",
           data_dir=None,
           num_epochs=1,
+          op_strategy="chief_worker",
           use_fake_data=True)
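
As a footnote on strategy 2, the dedicated-workers mode assigns roles by task
ID in fixed proportions, mirroring _num_gradient_tasks in convnet.py; the
helper below is illustrative, not part of this change:

    import numpy as np

    def split_tasks(num_tasks):
      """Partitions task IDs: ~60% gradients, ~20% covariance, rest inverses."""
      num_grad = int(np.ceil(0.6 * num_tasks))
      num_cov = int(np.ceil(0.2 * num_tasks))
      gradient_tasks = list(range(num_grad))
      cov_tasks = list(range(num_grad, num_grad + num_cov))
      inv_tasks = list(range(num_grad + num_cov, num_tasks))
      return gradient_tasks, cov_tasks, inv_tasks

    print(split_tasks(10))  # ([0, 1, 2, 3, 4, 5], [6, 7], [8, 9])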