Propagate outfeed sharding, if specified from TensorFlow.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 3 Feb 2018 02:03:11 +0000 (18:03 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 3 Feb 2018 02:07:05 +0000 (18:07 -0800)
PiperOrigin-RevId: 184361221

tensorflow/compiler/xla/service/service.cc
tensorflow/compiler/xla/service/user_computation.cc
tensorflow/compiler/xla/service/user_computation.h
tensorflow/compiler/xla/service/user_computation_test.cc

index a57b7e5..98dfc89 100644 (file)
@@ -1453,9 +1453,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
       handle_status = computation->AddInfeedInstruction(arg->infeed_request());
       break;
     case OpRequest::kOutfeedRequest:
-      TF_RETURN_IF_ERROR(
-          computation->AddOutfeedInstruction(arg->outfeed_request()));
-      return tensorflow::Status::OK();
+      handle_status =
+          computation->AddOutfeedInstruction(arg->outfeed_request());
+      break;
     case OpRequest::kMapRequest: {
       TF_ASSIGN_OR_RETURN(
           UserComputation * to_apply,
index 2ea6507..ef9c80b 100644 (file)
@@ -1185,7 +1185,7 @@ StatusOr<ComputationDataHandle> UserComputation::AddInfeedInstruction(
   return handle;
 }
 
-Status UserComputation::AddOutfeedInstruction(
+StatusOr<ComputationDataHandle> UserComputation::AddOutfeedInstruction(
     const OutfeedRequest& outfeed_request) {
   tensorflow::mutex_lock lock(mutex_);
 
@@ -1197,8 +1197,6 @@ Status UserComputation::AddOutfeedInstruction(
   // Verify that operand is valid.
   TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status());
 
-  // No handle is returned, but a handle must be assigned to this instruction
-  // for computation versioning.
   ComputationDataHandle handle = CreateComputationDataHandle();
   OperationRequest& request =
       (*session_computation_.mutable_requests())[handle.handle()];
@@ -1209,7 +1207,7 @@ Status UserComputation::AddOutfeedInstruction(
   VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal()
           << "), data handle " << handle.handle() << ": "
           << outfeed_request.ShortDebugString();
-  return Status::OK();
+  return handle;
 }
 
 StatusOr<ComputationDataHandle> UserComputation::AddCallInstruction(
index 4f92e58..54bb24d 100644 (file)
@@ -146,7 +146,8 @@ class UserComputation {
       const InfeedRequest& infeed_request);
 
   // Enqueues an outfeed instruction onto this user computation.
-  Status AddOutfeedInstruction(const OutfeedRequest& outfeed_request);
+  StatusOr<ComputationDataHandle> AddOutfeedInstruction(
+      const OutfeedRequest& outfeed_request);
 
   // Enqueues a call instruction onto this user computation.
   StatusOr<ComputationDataHandle> AddCallInstruction(
index ca02115..2fa1639 100644 (file)
@@ -67,7 +67,8 @@ TEST_F(UserComputationTest, SimpleComputation) {
   *outfeed_request.mutable_operand() = constant_handle;
   *outfeed_request.mutable_shape() = kVectorShape;
   outfeed_request.set_outfeed_config("abc");
-  TF_ASSERT_OK(computation.AddOutfeedInstruction(outfeed_request));
+  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle,
+                          computation.AddOutfeedInstruction(outfeed_request));
 
   auto hlo_resolver = [](const VersionedComputationHandle& handle) {
     return nullptr;