Adds basic replicated TPU training support for Keras.
author Jianwei Xie <xiejw@google.com>
Thu, 17 May 2018 00:23:02 +0000 (17:23 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 00:26:49 +0000 (17:26 -0700)
PiperOrigin-RevId: 196916177

tensorflow/contrib/tpu/python/tpu/keras_support.py
tensorflow/python/keras/_impl/keras/layers/wrappers.py

tensorflow/contrib/tpu/python/tpu/keras_support.py
index 7564c38..9cc841f 100644
@@ -55,6 +55,7 @@ from tensorflow.contrib.framework.python.framework import experimental
 from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
 from tensorflow.contrib.tpu.python.ops import tpu_ops
 from tensorflow.contrib.tpu.python.tpu import tpu
+from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session as tf_session
 from tensorflow.python.estimator import model_fn as model_fn_lib
@@ -104,6 +105,15 @@ def _valid_name(tensor_name):
   return re.sub('[^a-zA-Z0-9_-]+', '', tensor_name)
 
 
+def _replicated_optimizer(opt, num_replicas):
+  """Wrap the optimizer `opt` with CrossShardOptimizer if applicable."""
+  if num_replicas == 1:
+    return opt
+  return keras_optimizers.TFOptimizer(
+      optimizer=tpu_optimizer.CrossShardOptimizer(opt.optimizer)
+  )
+
+
 class TPUFunction(object):
   """K.function compatible interface for invoking a TPU compiled function.
 
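For readers unfamiliar with `CrossShardOptimizer`: when `num_replicas > 1`, gradients computed on each shard must be aggregated before they are applied. A minimal sketch of the wrapping above (hypothetical optimizer; assumes `opt` is a `keras_optimizers.TFOptimizer`, so `opt.optimizer` is the underlying `tf.train.Optimizer`):

```
# Sketch only: CrossShardOptimizer averages gradients across TPU shards
# (via a cross-replica sum) before apply_gradients runs.
sgd = tf.train.GradientDescentOptimizer(learning_rate=1.0)
replicated = keras_optimizers.TFOptimizer(
    optimizer=tpu_optimizer.CrossShardOptimizer(sgd))
```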
@@ -116,10 +126,11 @@ class TPUFunction(object):
   instead of being injected as `feed_dict` items or fetches.
   """
 
-  def __init__(self, model, execution_mode):
+  def __init__(self, model, execution_mode, num_replicas=1):
     self.model = model
     self.execution_mode = execution_mode
     self._compilation_cache = {}
+    self.num_replicas = num_replicas
 
   def _specialize_model(self, input_specs):
     """Specialize `self.model` (a Keras model) for the given input shapes."""
@@ -165,9 +176,11 @@ class TPUFunction(object):
       # Call our model with our infeed inputs (re-using the weights).
       model_outputs = self.model(tpu_inputs)
       child_model = models.Model(inputs=tpu_inputs, outputs=model_outputs)
+
       if is_training or is_test:
         child_model.compile(
-            optimizer=self.model.optimizer,
+            optimizer=_replicated_optimizer(self.model.optimizer,
+                                            self.num_replicas),
             loss=self.model.loss,
             loss_weights=self.model.loss_weights,
             metrics=self.model.metrics,
@@ -185,7 +198,8 @@ class TPUFunction(object):
         return [
             child_model.train_function.updates_op,
             tpu_ops.outfeed_enqueue_tuple(
-                child_model.train_function.outputs, name='oufeed-enqueue-train')
+                child_model.train_function.outputs,
+                name='outfeed-enqueue-train')
         ]
       elif is_test:
         child_model._make_test_function()
@@ -195,7 +209,8 @@ class TPUFunction(object):
         ]
         return [
             tpu_ops.outfeed_enqueue_tuple(
-                child_model.test_function.outputs, name='outfeed-enqueue-test')
+                child_model.test_function.outputs,
+                name='outfeed-enqueue-test')
         ]
       elif is_predict:
         child_model._make_predict_function()
@@ -215,31 +230,42 @@ class TPUFunction(object):
     # Capture outfeed metadata computed during the rewrite.
     self._outfeed_spec = None
 
+    # Generate our TPU operations using `tpu.split_compile_and_replicate`.
+    # `compile_op` can be used to verify that the TPU model compiles before
+    # execution. `execute_op` replicates `_model_fn` `num_replicas` times,
+    # with each shard running on a different logical core.
     compile_op, execute_op = tpu.split_compile_and_replicate(
-        _model_fn, inputs=[[]])
+        _model_fn, inputs=[[]] * self.num_replicas)
 
     # Generate CPU side operations to enqueue features/labels and dequeue
     # outputs from the model call.
-    with ops.device('/device:TPU:0'):
-      infeed_tensors = []
-      for spec in input_specs:
-        infeed_tensors.append(
-            array_ops.placeholder(
-                dtype=spec.dtype,
-                shape=spec.shape,
-                name='infeed-enqueue-%s' % spec.name))
-
-      infeed_op = tpu_ops.infeed_enqueue_tuple(
-          infeed_tensors, [spec.shape for spec in input_specs],
-          name='infeed-enqueue-%s' % self.execution_mode)
-
-      outfeed_op = tpu_ops.outfeed_dequeue_tuple(
-          dtypes=[spec.dtype for spec in self._outfeed_spec],
-          shapes=[spec.shape for spec in self._outfeed_spec],
-          name='outfeed-dequeue-%s' % self.execution_mode)
+    infeed_op = []
+    outfeed_op = []
+    shard_infeed_tensors = []
+
+    for shard_id in range(self.num_replicas):
+      with ops.device('/device:TPU:%d' % shard_id):
+        infeed_tensors = []
+        for spec in input_specs:
+          infeed_tensors.append(
+              array_ops.placeholder(
+                  dtype=spec.dtype,
+                  shape=spec.shape,
+                  name='infeed-enqueue-%s-%d' % (spec.name, shard_id)))
+        shard_infeed_tensors.append(infeed_tensors)
+
+        infeed_op.append(tpu_ops.infeed_enqueue_tuple(
+            infeed_tensors, [spec.shape for spec in input_specs],
+            name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id)))
+
+        outfeed_op.extend(tpu_ops.outfeed_dequeue_tuple(
+            dtypes=[spec.dtype for spec in self._outfeed_spec],
+            shapes=[spec.shape for spec in self._outfeed_spec],
+            name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id)))
 
     return TPUModelOp(
-        compile_op, execute_op, infeed_tensors, infeed_op, outfeed_op)
+        compile_op, execute_op, infeed_tensors=shard_infeed_tensors,
+        infeed_op=infeed_op, outfeed_op=outfeed_op)
 
   def _test_model_compiles(self, tpu_model_ops):
     """Verifies that the given TPUModelOp can be compiled via XLA."""
@@ -259,6 +285,31 @@ class TPUFunction(object):
     logging.info('Finished compiling. Time elapsed: %s secs',
                  end_time - start_time)
 
+  def _split_tensors(self, inputs):
+    """Split input data across shards.
+
+    Each input is sliced along the batch axis.
+
+    Args:
+      inputs: List of Numpy arrays to run on the TPU.
+
+    Returns:
+      List of lists containing the input to feed to each TPU shard.
+    """
+    if self.num_replicas == 1:
+      return [inputs]
+
+    batch_size = inputs[0].shape[0]
+    assert batch_size % self.num_replicas == 0, (
+        'batch_size must be divisible by num_replicas')
+    shard_size = batch_size // self.num_replicas
+    input_list = []
+    for index in range(self.num_replicas):
+      shard_inputs = [x[index * shard_size:(index + 1) * shard_size]
+                      for x in inputs]
+      input_list.append(shard_inputs)
+    return input_list
+
   def __call__(self, inputs):
     assert isinstance(inputs, list)
 
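A worked example of the slicing in `_split_tensors` (hypothetical shapes; note that the assertion above requires the global batch to divide evenly by `num_replicas`):

```
import numpy as np

# Sketch: two replicas, global batch of 8 -> two shards of 4 rows each.
features = np.arange(16).reshape(8, 2)
labels = np.arange(8).reshape(8, 1)
inputs = [features, labels]

num_replicas, shard_size = 2, 8 // 2
shards = [[x[i * shard_size:(i + 1) * shard_size] for x in inputs]
          for i in range(num_replicas)]
# shards[0][0].shape == (4, 2); shards[1][1].shape == (4, 1)
```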
@@ -270,12 +321,18 @@ class TPUFunction(object):
     else:
       input_tensors = self.model._feed_inputs
 
+    shard_inputs = self._split_tensors(inputs)
+    del inputs  # To avoid accidental usage.
+
     # Compute an input specification (used to generate infeed enqueue and
     # dequeue operations).  We use the shape from our input array and the
     # dtype from our model.  A user may pass in a float64 for a float32
     # input: for model compatibility we still must generate a float32 infeed.
     input_specs = []
-    for tensor, ary in zip(input_tensors, inputs):
+
+    # We use the shape and dtype from the first shard to compute the input
+    # metadata (`input_specs`); all replicas have the same type and shape.
+    for tensor, ary in zip(input_tensors, shard_inputs[0]):
       input_specs.append(
           tensor_spec.TensorSpec(ary.shape, tensor.dtype,
                                  _valid_name(tensor.name)))
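To illustrate the comment above (hypothetical values): the spec takes its shape from the first shard's array but its dtype from the model's input tensor, so a float64 array fed to a float32 input still produces a float32 infeed:

```
# Sketch: shape from the numpy array, dtype from the model tensor.
spec = tensor_spec.TensorSpec((4, 32), dtypes.float32, 'input_1')
```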
@@ -295,8 +352,10 @@ class TPUFunction(object):
     tpu_model_ops = self._compilation_cache[shape_key]
 
     infeed_dict = {}
-    for tensor, value in zip(tpu_model_ops.infeed_tensors, inputs):
-      infeed_dict[tensor] = value
+    for infeed_tensors, inputs in zip(tpu_model_ops.infeed_tensors,
+                                      shard_inputs):
+      for tensor, value in zip(infeed_tensors, inputs):
+        infeed_dict[tensor] = value
 
     session = K.get_session()
     _, _, outfeed_outputs = session.run([
@@ -304,7 +363,8 @@ class TPUFunction(object):
         tpu_model_ops.outfeed_op
     ], infeed_dict)
 
-    return outfeed_outputs
+    # TODO(xiejw): Decide how to reduce outputs, or just discard all but first.
+    return outfeed_outputs[:len(outfeed_outputs) // self.num_replicas]
 
 
 @experimental
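The slice on `outfeed_outputs` above keeps only the first replica's tensors. A hedged illustration for a hypothetical two-replica run where each replica outfeeds a loss and one metric:

```
# Sketch: dequeued outputs arrive concatenated in shard order.
num_replicas = 2
outfeed_outputs = ['loss_0', 'acc_0', 'loss_1', 'acc_1']
kept = outfeed_outputs[:len(outfeed_outputs) // num_replicas]
# kept == ['loss_0', 'acc_0'] -- replica 0 only; see the TODO on reduction.
```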
@@ -339,7 +399,7 @@ def shutdown_tpu_session(session=None):
 class KerasTPUModel(models.Model):
   """TPU compatible Keras model wrapper."""
 
-  def __init__(self, inputs, outputs, name=None):
+  def __init__(self, inputs, outputs, name, replicas=1):
     super(models.Model, self).__init__(
         inputs=inputs,
         outputs=outputs,
@@ -348,6 +408,7 @@ class KerasTPUModel(models.Model):
     self.predict_function = None
     self.test_function = None
     self.train_function = None
+    self.replicas = replicas
 
   def compile(self,
               optimizer,
@@ -376,7 +437,8 @@ class KerasTPUModel(models.Model):
 
   def _make_train_function(self):
     if not self.train_function:
-      self.train_function = TPUFunction(self, model_fn_lib.ModeKeys.TRAIN)
+      self.train_function = TPUFunction(self, model_fn_lib.ModeKeys.TRAIN,
+                                        num_replicas=self.replicas)
 
     return self.train_function
 
@@ -442,7 +504,53 @@ Output shape: %(output_shape)s
 
 
 @experimental
-def tpu_model(model):
+def tpu_model(model, replicas=None):
+  """Runs a model on TPU(s).
+
+  Usage:
+  ```
+  a = Input(shape=(32,))
+  b = Dense(32)(a)
+  model = Model(inputs=a, outputs=b)
+
+  model = keras_support.tpu_model(model)
+  model.compile(
+      optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
+      ...)
+  ```
+
+  If `replicas` is set, the model computation is replicated `replicas` times;
+  each shard runs on a different TPU core.
+
+  Limitation: Currently, replication is only supported for training.
+
+  Usage:
+  ```
+  a = Input(shape=(32,))
+  b = Dense(32)(a)
+  model = Model(inputs=a, outputs=b)
+
+  model = keras_support.tpu_model(model, replicas=2)
+  model.compile(
+      optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
+      ...)
+  ```
+
+  Args:
+    model: A `KerasTPUModel`.
+    replicas: (Optional) Int, the number of TPU cores on which to create
+        model replicas. If `None`, the model runs on a single core only,
+        i.e., no replication.
+
+  Returns:
+    A new `KerasTPUModel` instance.
+  """
   _validate_shapes(model)
+  # TODO(xiejw): Validate TPU model. TPUModel only?
+  # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
+  # TODO(xiejw): Adds reduction option.
+  replicas = 1 if replicas is None else replicas
   return KerasTPUModel(
-      inputs=model.inputs, outputs=model.outputs, name=model.name)
+      inputs=model.inputs, outputs=model.outputs, name=model.name,
+      replicas=replicas)
tensorflow/python/keras/_impl/keras/layers/wrappers.py
index d1d09bb..7fe5745 100644
@@ -202,7 +202,7 @@ class TimeDistributed(Wrapper):
           step,
           inputs,
           initial_states=[],
-          input_length=input_shape[0],
+          input_length=input_shape[1],
           unroll=False)
       y = outputs
     else:
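The `wrappers.py` fix corrects which dimension is handed to `K.rnn` as `input_length`: for a `TimeDistributed` input of shape `(batch, timesteps, ...)`, index 0 is the batch size while index 1 is the timestep count the backend loop actually needs. A small illustration (hypothetical shape):

```
# Sketch: TimeDistributed over an input of shape (32, 10, 16).
input_shape = (32, 10, 16)
batch_size = input_shape[0]    # 32 -- wrong value to pass as input_length
input_length = input_shape[1]  # 10 -- the number of timesteps
```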