From 8e0848160a7d135f728dde2519a32876b8a7e3ce Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 23 Mar 2018 14:52:44 -0700
Subject: [PATCH] Prepare the XLA and TF code to behave correctly once
 automatic device placement is enabled by default on computation shapes > 1.

PiperOrigin-RevId: 190278826
---
 tensorflow/compiler/xla/service/service.cc         | 35 +++++++++++++++++-----
 tensorflow/compiler/xla/service/service.h          |  6 ++++
 tensorflow/contrib/tpu/python/tpu/tpu_estimator.py |  3 +-
 tensorflow/contrib/tpu/python/tpu/tpu_feed.py      | 19 ++++++++++--
 4 files changed, 51 insertions(+), 12 deletions(-)

diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 04487a4..4f6a823 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -861,6 +861,33 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
   return tensorflow::Status::OK();
 }
 
+tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg,
+                                          ExecuteResponse* result) {
+  ExecuteParallelRequest parallel_arg;
+  *parallel_arg.add_requests() = *arg;
+  ExecuteParallelResponse parallel_result;
+  TF_RETURN_IF_ERROR(ExecuteParallel(&parallel_arg, &parallel_result));
+  // The "result device" selection is a bit hacky, but better than assuming it
+  // is device 0. We have b/76035356 for restructuring the client API to clean
+  // up the current asymmetries and support more functionalities.
+  for (int64 i = 0; i < parallel_result.responses_size(); ++i) {
+    TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
+                        allocation_tracker_.ResolveForReplica(
+                            parallel_result.responses(i).output(), 0));
+    const Shape& shape = buffer->on_host_shape();
+    if (!ShapeUtil::IsEmptyTuple(shape)) {
+      *result = parallel_result.responses(i);
+      VLOG(3) << "Fetching result from device " << i << ": "
+              << ShapeUtil::HumanString(shape);
+      return Status::OK();
+    }
+  }
+  TF_RET_CHECK(parallel_result.responses_size() > 0);
+  *result = parallel_result.responses(0);
+  VLOG(1) << "Defaulting to device 0 result";
+  return Status::OK();
+}
+
 tensorflow::Status Service::Execute(const ExecuteRequest* arg,
                                     ExecuteResponse* result) {
   VLOG(1) << "running execute request: " << arg->ShortDebugString();
@@ -877,13 +904,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
 
   // If we received multiple device handles, we must partition the module.
   if (arg->execution_options().device_handles_size() > 1) {
-    ExecuteParallelRequest parallel_arg;
-    *parallel_arg.add_requests() = *arg;
-    ExecuteParallelResponse parallel_result;
-    TF_RETURN_IF_ERROR(ExecuteParallel(&parallel_arg, &parallel_result));
-    TF_RET_CHECK(parallel_result.responses_size() > 0);
-    *result = parallel_result.responses(0);
-    return Status::OK();
+    return ExecuteOneToN(arg, result);
   }
 
   TF_ASSIGN_OR_RETURN(
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index a76bdd8..3b79920 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -346,6 +346,12 @@ class Service : public ServiceInterface {
       const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>&
           adder);
 
+  // Executes a single computation which has more than one target device.
+  // All but one of the N devices are expected to return an empty tuple; the
+  // non-empty result becomes the result of this computation.
+  tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg,
+                                   ExecuteResponse* result);
+
   // Convenience function which checks whether the given shape_with_layout
   // (presumably passed by the client to set the result layout) is valid for the
   // given computation result shape.
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index aaa6f3c..152f8c8 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -931,8 +931,7 @@ class _InputPipeline(object):
       # In the model-parallel case, both the host-side and device-side
       # computations must agree on the core on which infeed takes place. We
       # choose to perform infeed on logical core 0 of each replica.
-      with ops.device(tpu.core(0)):
-        values = self._infeed_queue.generate_dequeue_op()
+      values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
       # The unflatten process uses the structure information recorded above.
       return self._inputs_structure_recorder.unflatten_features_and_labels(
           values)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
index 42ac6eb..604e660 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
@@ -23,6 +23,7 @@ from __future__ import print_function
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import tpu
 from tensorflow.contrib.tpu.python.tpu import tpu_sharding
 
 from tensorflow.python.framework import dtypes
@@ -368,13 +369,20 @@ class InfeedQueue(object):
       policy.freeze()
     self._validate()
 
-  def generate_dequeue_op(self):
+  def generate_dequeue_op(self, tpu_device=0):
     """Generates the device-side Op to dequeue a tuple from the queue.
 
     Implicitly freezes the queue configuration if it is not already
     frozen, which will raise errors if the shapes and types have not
     been fully specified.
 
+    Args:
+      tpu_device: The TPU device ordinal where the infeed instruction should be
+        placed. If None, no explicit placement will be performed, and it is up
+        to the user to call this API from within a proper TPU device scope.
+        The XLA code will fail if the TPU dequeue instruction is not bound to
+        any device.
+
     Returns:
       A list of Outputs corresponding to a shard of infeed dequeued
       into XLA, suitable for use within a replicated block.
@@ -392,8 +400,13 @@ class InfeedQueue(object):
       policy.get_sharded_shape(shape)
       for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
     ]
-    return tpu_ops.infeed_dequeue_tuple(
-        dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
+    if tpu_device is not None:
+      with ops.device(tpu.core(tpu_device)):
+        return tpu_ops.infeed_dequeue_tuple(
+            dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
+    else:
+      return tpu_ops.infeed_dequeue_tuple(
+          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
 
   def _generate_enqueue_op(self,
                            inputs,
-- 
2.7.4
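Usage sketch for the new dequeue placement API (illustrative only): the
fragment below shows how InfeedQueue.generate_dequeue_op is meant to be called
after this change, next to the explicit device-scope form that tpu_estimator.py
used before it. The queue configuration (tuple_types, tuple_shapes, single
shard) and the helper function names are assumptions made for the sketch and
are not part of this patch; in real use these bodies would run inside a
replicated TPU computation such as the one built by TPUEstimator.

# Hypothetical queue configuration, for illustration only; the tpu_device
# handling in generate_dequeue_op is the part introduced by this patch.
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops

queue = tpu_feed.InfeedQueue(
    tuple_types=[dtypes.float32, dtypes.int32],  # assumed element types
    tuple_shapes=[[128, 4], [128]])              # assumed element shapes
queue.set_number_of_shards(1)                    # assumed single shard


def dequeue_with_implicit_placement():
  # New behavior: generate_dequeue_op places the dequeue on logical core 0
  # of the replica itself (tpu_device defaults to 0).
  return queue.generate_dequeue_op(tpu_device=0)


def dequeue_with_explicit_placement():
  # Previous pattern, still supported: pass tpu_device=None and supply the
  # device scope yourself, as tpu_estimator.py did before this change.
  with ops.device(tpu.core(0)):
    return queue.generate_dequeue_op(tpu_device=None)

Passing tpu_device=None preserves the old contract for callers that already
manage placement, while the default ordinal of 0 mirrors the logical-core-0
choice used by tpu_estimator.py above.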