Early TPU distribution strategy and the associated testing infrastructure.
author     Igor Saprykin <isaprykin@google.com>
           Mon, 16 Apr 2018 19:21:15 +0000 (12:21 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Mon, 16 Apr 2018 19:23:19 +0000 (12:23 -0700)
PiperOrigin-RevId: 193080098

tensorflow/contrib/distribute/python/BUILD
tensorflow/contrib/distribute/python/combinations.py
tensorflow/contrib/distribute/python/minimize_loss_test.py
tensorflow/contrib/distribute/python/tpu_strategy.py [new file with mode: 0644]
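
For orientation, a rough sketch of how a test opts into the new TPU combination through `combinations.generate`, mirroring the `minimize_loss_test.py` change below. `MyStepTest` and `build_and_run_step` are hypothetical placeholders; per the new skip logic, the TPU combination only runs when the test binary name contains "test_tpu".

    # Sketch only; mirrors the decorator pattern added to minimize_loss_test.py.
    from absl.testing import parameterized

    from tensorflow.contrib.distribute.python import combinations
    from tensorflow.python.eager import test


    class MyStepTest(test.TestCase, parameterized.TestCase):

      @combinations.generate(
          combinations.combine(
              distribution=[combinations.tpu_strategy],
              optimizer_fn=[combinations.adam_optimizer_v1_fn],
              mode=["graph"],
              is_tpu=[True]))
      def testStep(self, distribution, optimizer_fn, is_tpu):
        with distribution.scope():
          # `build_and_run_step` is a hypothetical helper standing in for the
          # model construction and training loop in minimize_loss_test.py.
          build_and_run_step(distribution, optimizer_fn)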

diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 5aad21c..837a1f1 100644
@@ -131,6 +131,7 @@ py_library(
     deps = [
         ":mirrored_strategy",
         ":one_device_strategy",
+        ":tpu_strategy",
         "//tensorflow/contrib/optimizer_v2:training",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:training",
@@ -225,14 +226,30 @@ py_library(
     ],
 )
 
-cuda_py_test(
-    name = "minimize_loss_test",
+py_library(
+    name = "tpu_strategy",
+    srcs = ["tpu_strategy.py"],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        "//tensorflow/contrib/distribute/python:one_device_strategy",
+        "//tensorflow/contrib/eager/python:datasets",
+        "//tensorflow/contrib/optimizer_v2:training",
+        "//tensorflow/contrib/tpu",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python/eager:context",
+        "@six_archive//:six",
+    ],
+)
+
+py_library(
+    name = "minimize_loss_test_lib",
+    testonly = 1,
     srcs = ["minimize_loss_test.py"],
-    additional_deps = [
+    deps = [
         ":combinations",
         ":single_loss_example",
-        "@absl_py//absl/testing:parameterized",
-        "//third_party/py/numpy",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:variables",
@@ -240,6 +257,16 @@ cuda_py_test(
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/ops/losses",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+cuda_py_test(
+    name = "minimize_loss_test",
+    srcs = ["minimize_loss_test.py"],
+    additional_deps = [
+        ":minimize_loss_test_lib",
     ],
     tags = [
         "multi_and_single_gpu",
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 02b1e7e..1f66997 100644
@@ -45,6 +45,7 @@ from absl.testing import parameterized
 
 from tensorflow.contrib.distribute.python import mirrored_strategy
 from tensorflow.contrib.distribute.python import one_device_strategy
+from tensorflow.contrib.distribute.python import tpu_strategy
 from tensorflow.contrib.optimizer_v2 import adam as adam_v2
 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
 from tensorflow.python.eager import context
@@ -55,6 +56,7 @@ from tensorflow.python.util import tf_inspect
 
 
 GPU_TEST = "test_gpu" in sys.argv[0]
+TPU_TEST = "test_tpu" in sys.argv[0]
 
 
 def generate(combinations):
@@ -108,6 +110,11 @@ def generate(combinations):
       if "distribution" in kwargs:
         distribution = kwargs["distribution"]
         kwargs["distribution"] = distribution.strategy
+        if distribution.required_tpu and not TPU_TEST:
+          self.skipTest("Test requires a TPU, but it's not available.")
+        if not distribution.required_tpu and TPU_TEST:
+          self.skipTest("Test that doesn't require a TPU.")
+
         if not distribution.required_gpus:
           if GPU_TEST:
             self.skipTest("Test that doesn't require GPUs.")
@@ -232,10 +239,12 @@ class NamedObject(object):
 class NamedDistribution(object):
   """Translates DistributionStrategy and its data into a good name."""
 
-  def __init__(self, name, distribution, required_gpus):
+  def __init__(self, name, distribution, required_gpus=None,
+               required_tpu=False):
     self._distribution = distribution
     self._name = name
     self._required_gpus = required_gpus
+    self._required_tpu = required_tpu
 
   def __repr__(self):
     return self._name
@@ -248,10 +257,16 @@ class NamedDistribution(object):
   def required_gpus(self):
     return self._required_gpus
 
+  @property
+  def required_tpu(self):
+    return self._required_tpu
+
 
 one_device_strategy = NamedDistribution(
     "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"),
     None)
+tpu_strategy = NamedDistribution(
+    "TPU", tpu_strategy.TpuStrategy(), required_tpu=True)
 mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
     "MirroredCPUAndGPU",
     mirrored_strategy.MirroredStrategy(["/gpu:0", "/cpu:0"]), 1)
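
As a reading aid, the TPU-related skip rules that `generate` now applies can be restated as a standalone helper; `should_skip` is an illustrative name, not part of this commit.

    # Illustrative restatement of the TPU_TEST / required_tpu checks above.
    def should_skip(required_tpu, is_tpu_test_binary):
      """Returns a skip reason for a combination, or None to run it."""
      if required_tpu and not is_tpu_test_binary:
        return "Test requires a TPU, but it's not available."
      if not required_tpu and is_tpu_test_binary:
        return "Test that doesn't require a TPU."
      return None  # Run the combination; the existing GPU checks still apply.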
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index 0fa90df..4219d54 100644
@@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import combinations
 from tensorflow.contrib.distribute.python import mirrored_strategy
 from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example
 from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example
+from tensorflow.contrib.tpu.python.tpu import tpu
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
@@ -42,24 +43,46 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       combinations.times(
           combinations.distributions_and_v1_optimizers(),
           combinations.combine(mode=["graph"], use_callable_loss=[True, False])
-          + combinations.combine(mode=["eager"], use_callable_loss=[True])))
-  def testTrainNetwork(self, distribution, optimizer_fn,
-                       use_callable_loss=True):
+          + combinations.combine(mode=["eager"], use_callable_loss=[True]),
+          combinations.combine(is_tpu=[False])) +
+      combinations.combine(
+          distribution=[combinations.tpu_strategy],
+          optimizer_fn=[combinations.adam_optimizer_v1_fn],
+          mode=["graph"],
+          use_callable_loss=[False],
+          is_tpu=[True]))
+  def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss,
+                       is_tpu):
     with distribution.scope():
       model_fn, dataset, layer = minimize_loss_example(
           optimizer_fn,
           use_bias=True,
           use_callable_loss=use_callable_loss)
 
+      # TODO(isaprykin):  Eliminate `is_tpu`. Probably add a
+      # `DistributionStrategy.create_monitor` so that each DistributionStrategy
+      # could influence its training loop. That method would return an instance
+      # of Monitor.  TPUMonitor would execute tpu.initialize_system() and
+      # tpu.shutdown_system().
+      if is_tpu:
+        dataset = dataset.batch(2)
+
       iterator = distribution.distribute_dataset(dataset)
 
       def run_step():
+        # TODO(isaprykin): Make iterator get_next() return a list of sub-
+        # batches for each iteration. Pass iterator.get_next() and not iterator
+        # to call_for_each_tower.
         return distribution.group(
             distribution.call_for_each_tower(
-                model_fn, iterator.get_next(), run_concurrently=layer.built))
+                model_fn,
+                iterator.get_next() if not is_tpu else iterator,
+                run_concurrently=layer.built))
 
       if not context.executing_eagerly():
         with self.test_session() as sess:
+          if is_tpu:
+            sess.run(tpu.initialize_system())
           run_step = sess.make_callable(run_step())
         self.evaluate(variables_lib.global_variables_initializer())
 
@@ -70,6 +93,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
         weights.append(self.evaluate(distribution.fetch(layer.kernel)))
         biases.append(self.evaluate(distribution.fetch(layer.bias)))
 
+      if is_tpu:
+        with self.test_session() as sess:
+          sess.run(tpu.shutdown_system())
+
       error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
       is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
       self.assertTrue(is_not_increasing)
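
Condensing the graph-mode TPU bookkeeping above into one place: the test brings up the TPU system before building the step callable and shuts it down after the training loop. The sketch below assumes it runs inside a tf.test.TestCase method with `run_step()` already defined, and folds the two `self.test_session()` calls (which return the same cached session) into a single block.

    # Sketch under the assumptions stated above; the loop count is arbitrary.
    from tensorflow.contrib.tpu.python.tpu import tpu
    from tensorflow.python.ops import variables as variables_lib

    with self.test_session() as sess:
      sess.run(tpu.initialize_system())
      step = sess.make_callable(run_step())
      sess.run(variables_lib.global_variables_initializer())
      for _ in range(10):
        step()
      sess.run(tpu.shutdown_system())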
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
new file mode 100644
index 0000000..0ac307d
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -0,0 +1,82 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TPU Distribution Strategy.
+
+This is experimental.  It's not ready for general use.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import tpu
+from tensorflow.contrib.distribute.python import one_device_strategy
+from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+
+
+# TODO(isaprykin):  Consider whether inheriting is really appropriate.
+class TpuStrategy(one_device_strategy.OneDeviceStrategy):
+
+  def __init__(self, master=None, iterations=None, model_dir=None):
+    super(TpuStrategy, self).__init__('/cpu:0')
+
+  def _call_for_each_tower(self, fn, *args, **kwargs):
+    kwargs.pop('run_concurrently', None)
+
+    # TODO(isaprykin): Give an API for many iterations per step.
+    iterations = 1
+
+    # TODO(isaprykin): Do not hard code shapes and input format :)
+    # TODO(isaprykin): Detect the number of TPU cores automatically.
+
+    def dequeueing_fn(*args, **kwargs):
+      del args, kwargs
+      x, = tpu.infeed_dequeue_tuple(dtypes=[dtypes.float32], shapes=[[1, 1, 1]])
+      return fn(x)
+
+    iterator = args[0]
+
+    def infeed_input(i):
+      """Get input, split it and then enqueue."""
+      batches = iterator.get_next()
+      batches = array_ops.split(batches, 2)
+
+      infeeds = [
+          tpu_ops.infeed_enqueue_tuple(
+              inputs=[batches[j]], shapes=[[1, 1, 1]], device_ordinal=j)
+          for j in range(2)
+      ]
+
+      with ops.control_dependencies(infeeds):
+        return i + 1
+
+    with ops.device('/task:0/device:CPU:0'):
+      enqueue_ops = control_flow_ops.while_loop(
+          lambda i: i < iterations,
+          infeed_input, [constant_op.constant(0)],
+          parallel_iterations=1)
+
+    def iterate_on_tpu():
+      return tpu.repeat(iterations, dequeueing_fn, [])
+
+    with one_device_strategy._OneDeviceTowerContext(self):  # pylint: disable=protected-access
+      tpu_result = tpu.batch_parallel(iterate_on_tpu, [], num_shards=2)
+
+    return control_flow_ops.group(tpu_result, enqueue_ops)
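
Finally, a short sketch of how this strategy is driven by `minimize_loss_test.py` above: the dataset is batched by 2 to match the hard-coded two-core split, and the iterator itself (rather than `iterator.get_next()`) is handed to `call_for_each_tower`, because `_call_for_each_tower` builds its own infeed loop from it. `model_fn` and `dataset` stand in for the helpers from `minimize_loss_example()`.

    # Sketch; `model_fn` and `dataset` are placeholders for the test's helpers,
    # and each dataset element is assumed to have the hard-coded [1, 1, 1] shape.
    strategy = TpuStrategy()
    with strategy.scope():
      iterator = strategy.distribute_dataset(dataset.batch(2))
      # Pass the iterator, not iterator.get_next(): the strategy splits each
      # batch of 2 across the two hard-coded TPU cores via infeed.
      train_op = strategy.group(
          strategy.call_for_each_tower(model_fn, iterator,
                                       run_concurrently=False))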