from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu
+from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.estimator import model_fn as model_fn_lib
return re.sub('[^a-zA-Z0-9_-]+', '', tensor_name)
+def _replicated_optimizer(opt, num_replicas):
+ """Wrap the optimizer `opt` with CrossShardOptimizer if applicable."""
+ if num_replicas == 1:
+ return opt
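+ # CrossShardOptimizer wraps the plain TF optimizer so that gradients are
+ # aggregated across all replicas (via a cross-replica sum) before the
+ # weight update is applied.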
+ return keras_optimizers.TFOptimizer(
+ optimizer=tpu_optimizer.CrossShardOptimizer(opt.optimizer)
+ )
+
+
class TPUFunction(object):
"""K.function compatible interface for invoking a TPU compiled function.
instead of being injected as `feed_dict` items or fetches.
"""
- def __init__(self, model, execution_mode):
+ def __init__(self, model, execution_mode, num_replicas=1):
self.model = model
self.execution_mode = execution_mode
self._compilation_cache = {}
+ self.num_replicas = num_replicas
def _specialize_model(self, input_specs):
"""Specialize `self.model` (a Keras model) for the given input shapes."""
# Call our model with our infeed inputs (re-using the weights).
model_outputs = self.model(tpu_inputs)
child_model = models.Model(inputs=tpu_inputs, outputs=model_outputs)
+
if is_training or is_test:
child_model.compile(
- optimizer=self.model.optimizer,
+ optimizer=_replicated_optimizer(self.model.optimizer,
+ self.num_replicas),
loss=self.model.loss,
loss_weights=self.model.loss_weights,
metrics=self.model.metrics,
return [
child_model.train_function.updates_op,
tpu_ops.outfeed_enqueue_tuple(
- child_model.train_function.outputs, name='oufeed-enqueue-train')
+ child_model.train_function.outputs,
+ name='outfeed-enqueue-train')
]
elif is_test:
child_model._make_test_function()
return [
tpu_ops.outfeed_enqueue_tuple(
- child_model.test_function.outputs, name='outfeed-enqueue-test')
+ child_model.test_function.outputs,
+ name='outfeed-enqueue-test')
]
elif is_predict:
child_model._make_predict_function()
# Capture outfeed metadata computed during the rewrite.
self._outfeed_spec = None
+ # Generate our TPU operations using `tpu.split_compile_and_replicate`.
+ # `compile_op` can be used to verify that the TPU model compiles before
+ # execution. `execute_op` replicates `_model_fn` `num_replicas` times, with
+ # each shard running on a different logical core.
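+ # Note: `inputs=[[]] * num_replicas` supplies an empty argument list to each
+ # replica; the replicated `_model_fn` reads its data from the infeed queue
+ # rather than from function arguments.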
compile_op, execute_op = tpu.split_compile_and_replicate(
- _model_fn, inputs=[[]])
+ _model_fn, inputs=[[]] * self.num_replicas)
# Generate CPU side operations to enqueue features/labels and dequeue
# outputs from the model call.
- with ops.device('/device:TPU:0'):
- infeed_tensors = []
- for spec in input_specs:
- infeed_tensors.append(
- array_ops.placeholder(
- dtype=spec.dtype,
- shape=spec.shape,
- name='infeed-enqueue-%s' % spec.name))
-
- infeed_op = tpu_ops.infeed_enqueue_tuple(
- infeed_tensors, [spec.shape for spec in input_specs],
- name='infeed-enqueue-%s' % self.execution_mode)
-
- outfeed_op = tpu_ops.outfeed_dequeue_tuple(
- dtypes=[spec.dtype for spec in self._outfeed_spec],
- shapes=[spec.shape for spec in self._outfeed_spec],
- name='outfeed-dequeue-%s' % self.execution_mode)
+ infeed_op = []
+ outfeed_op = []
+ shard_infeed_tensors = []
+
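+ # Build per-replica infeed placeholders plus enqueue/dequeue ops, each
+ # placed on that replica's TPU device.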
+ for shard_id in range(self.num_replicas):
+ with ops.device('/device:TPU:%d' % shard_id):
+ infeed_tensors = []
+ for spec in input_specs:
+ infeed_tensors.append(
+ array_ops.placeholder(
+ dtype=spec.dtype,
+ shape=spec.shape,
+ name='infeed-enqueue-%s-%d' % (spec.name, shard_id)))
+ shard_infeed_tensors.append(infeed_tensors)
+
+ infeed_op.append(tpu_ops.infeed_enqueue_tuple(
+ infeed_tensors, [spec.shape for spec in input_specs],
+ name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id)))
+
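+ # `outfeed_dequeue_tuple` returns one tensor per output spec; extending the
+ # flat `outfeed_op` list keeps shard outputs in replica order.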
+ outfeed_op.extend(tpu_ops.outfeed_dequeue_tuple(
+ dtypes=[spec.dtype for spec in self._outfeed_spec],
+ shapes=[spec.shape for spec in self._outfeed_spec],
+ name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id)))
return TPUModelOp(
- compile_op, execute_op, infeed_tensors, infeed_op, outfeed_op)
+ compile_op, execute_op, infeed_tensors=shard_infeed_tensors,
+ infeed_op=infeed_op, outfeed_op=outfeed_op)
def _test_model_compiles(self, tpu_model_ops):
"""Verifies that the given TPUModelOp can be compiled via XLA."""
logging.info('Finished compiling. Time elapsed: %s secs',
end_time - start_time)
+ def _split_tensors(self, inputs):
+ """Split input data across shards.
+
+ Each input is sliced along the batch axis.
+
+ Args:
+ inputs: List of Numpy arrays to run on the TPU.
+
+ Returns:
+ List of lists containing the input to feed to each TPU shard.
+ """
+ if self.num_replicas == 1:
+ return [inputs]
+
+ batch_size = inputs[0].shape[0]
+ assert batch_size % self.num_replicas == 0, (
+ 'batch_size must be divisible by num_replicas')
+ shard_size = batch_size // self.num_replicas
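+ # For example, a global batch of 128 split across 2 replicas yields two
+ # shards of 64 examples each.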
+ input_list = []
+ for index in range(self.num_replicas):
+ shard_inputs = [x[index * shard_size:(index + 1) * shard_size]
+ for x in inputs]
+ input_list.append(shard_inputs)
+ return input_list
+
def __call__(self, inputs):
assert isinstance(inputs, list)
else:
input_tensors = self.model._feed_inputs
+ shard_inputs = self._split_tensors(inputs)
+ del inputs # To avoid accidental usage.
+
# Compute an input specification (used to generate infeed enqueue and
# dequeue operations). We use the shape from our input array and the
# dtype from our model. A user may pass in a float64 for a float32
# input: for model compatibility we still must generate a float32 infeed.
input_specs = []
- for tensor, ary in zip(input_tensors, inputs):
+
+ # We use the shape and dtype from the first shard to compute the input
+ # metadata (`input_specs`); all replicas have the same type and shape.
+ for tensor, ary in zip(input_tensors, shard_inputs[0]):
input_specs.append(
tensor_spec.TensorSpec(ary.shape, tensor.dtype,
_valid_name(tensor.name)))
tpu_model_ops = self._compilation_cache[shape_key]
infeed_dict = {}
- for tensor, value in zip(tpu_model_ops.infeed_tensors, inputs):
- infeed_dict[tensor] = value
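+ # Feed each shard's slice of the batch to that shard's infeed placeholders.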
+ for infeed_tensors, inputs in zip(tpu_model_ops.infeed_tensors,
+ shard_inputs):
+ for tensor, value in zip(infeed_tensors, inputs):
+ infeed_dict[tensor] = value
session = K.get_session()
_, _, outfeed_outputs = session.run([
tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
tpu_model_ops.outfeed_op
], infeed_dict)
- return outfeed_outputs
+ # TODO(xiejw): Decide how to reduce outputs, or just discard all but first.
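+ # For now, keep only the first replica's outputs; every replica produces an
+ # identically structured set of outputs.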
+ return outfeed_outputs[:len(outfeed_outputs) // self.num_replicas]
@experimental
class KerasTPUModel(models.Model):
"""TPU compatible Keras model wrapper."""
- def __init__(self, inputs, outputs, name=None):
+ def __init__(self, inputs, outputs, name, replicas=1):
super(models.Model, self).__init__(
inputs=inputs,
outputs=outputs,
self.predict_function = None
self.test_function = None
self.train_function = None
+ self.replicas = replicas
def compile(self,
optimizer,
def _make_train_function(self):
if not self.train_function:
- self.train_function = TPUFunction(self, model_fn_lib.ModeKeys.TRAIN)
+ self.train_function = TPUFunction(self, model_fn_lib.ModeKeys.TRAIN,
+ num_replicas=self.replicas)
return self.train_function
@experimental
-def tpu_model(model):
+def tpu_model(model, replicas=None):
+ """Runs a model on TPU(s).
+
+ Usage:
+ ```
+ a = Input(shape=(32,))
+ b = Dense(32)(a)
+ model = Model(inputs=a, outputs=b)
+
+ model = keras_support.tpu_model(model)
+ model.compile(
+ optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
+ ...)
+ ```
+
+ If `replicas` is set, the model computation is replicated `replicas` times;
+ each replica (shard) runs on a different TPU core.
+
+ Limitation: Currently, replication is only supported for training.
+
+ Usage:
+ ```
+ a = Input(shape=(32,))
+ b = Dense(32)(a)
+ model = Model(inputs=a, outputs=b)
+
+ model = keras_support.tpu_model(model, replicas=2)
+ model.compile(
+ optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
+ ...)
+ ```
+
+ Args:
+ model: A `KerasTPUModel`.
+ replicas: (Optional) Int, number of TPU cores on which to replicate the
+ model. If `None`, the model runs on a single core only, i.e., no
+ replication.
+
+ Returns:
+ A new `KerasTPUModel` instance.
+ """
_validate_shapes(model)
+ # TODO(xiejw): Validate TPU model. TPUModel only?
+ # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
+ # TODO(xiejw): Adds reduction option.
+ replicas = 1 if replicas is None else replicas
return KerasTPUModel(
- inputs=model.inputs, outputs=model.outputs, name=model.name)
+ inputs=model.inputs, outputs=model.outputs, name=model.name,
+ replicas=replicas)