Allow communicating instructions within a kCall computation.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 11 May 2018 18:04:22 +0000 (11:04 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 18:07:41 +0000 (11:07 -0700)
PiperOrigin-RevId: 196278635

tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
tensorflow/compiler/xla/service/hlo_module_group_metadata.h

index 54c34ce..67f4c37 100644 (file)
@@ -47,6 +47,9 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
     case ComputationKind::kConditionalFalse:
       repr += ":CONDITIONAL_FALSE";
       break;
+    case ComputationKind::kCallFunction:
+      repr += ":CALL";
+      break;
   }
   return repr;
 }
@@ -206,6 +209,9 @@ Status HloModuleGroupMetadata::RecordInstructions() {
           TrackedInstruction(hlo, ComputationKind::kConditionalTrue);
       tracked_instructions_[hlo->false_computation()] =
           TrackedInstruction(hlo, ComputationKind::kConditionalFalse);
+    } else if (hlo->opcode() == HloOpcode::kCall) {
+      tracked_instructions_[hlo->to_apply()] =
+          TrackedInstruction(hlo, ComputationKind::kCallFunction);
     }
     if (!IsChannelInstruction(hlo)) {
       return Status::OK();
@@ -258,7 +264,8 @@ Status HloModuleGroupMetadata::RecordInstructions() {
 Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
                                             HloInstruction* instruction2) {
   TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile ||
-               instruction1->opcode() == HloOpcode::kConditional);
+               instruction1->opcode() == HloOpcode::kConditional ||
+               instruction1->opcode() == HloOpcode::kCall);
   VLOG(2) << "adding as companions:" << instruction1->ToString() << " and "
           << instruction2->ToString();
 
@@ -336,21 +343,11 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
     }
   }
 
-  // Check if channel instructions are used only in allowed computations.
-  const auto allowed = [this](HloInstruction* hlo) {
-    HloComputation* computation = hlo->parent();
-    const HloModule* module = computation->parent();
-    if (module->entry_computation() == computation ||
-        tracked_instructions_.count(computation) > 0) {
-      return true;
-    }
-    return false;
-  };
   for (const Channel& channel : channels_) {
-    if (!allowed(channel.send) || !allowed(channel.send_done) ||
-        !allowed(channel.recv) || !allowed(channel.recv_done)) {
-      return FailedPrecondition("channel is used in disallowed computation");
-    }
+    TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send));
+    TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send_done));
+    TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv));
+    TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv_done));
   }
   // Check if the nest levels match for each channel.
   for (const Channel& channel : channels_) {
@@ -368,4 +365,15 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
   return Status::OK();
 }
 
+Status HloModuleGroupMetadata::CheckCommunicatingInstruction(
+    HloInstruction* instruction) const {
+  HloComputation* computation = instruction->parent();
+  const HloModule* module = computation->parent();
+  if (module->entry_computation() == computation ||
+      tracked_instructions_.count(computation) > 0) {
+    return Status::OK();
+  }
+  return FailedPrecondition("channel is used in disallowed computation");
+}
+
 }  // namespace xla
index c48a7ab..88ed9a2 100644 (file)
@@ -60,6 +60,7 @@ class HloModuleGroupMetadata {
     kWhileBody,
     kConditionalTrue,
     kConditionalFalse,
+    kCallFunction,
   };
 
   // Tracks the instruction mapped to a given computation, and the computation
@@ -202,6 +203,10 @@ class HloModuleGroupMetadata {
   Status AddCompanion(HloInstruction* instruction1,
                       HloInstruction* instruction2);
 
+  // Checks whether a communicating instruction is placed in a valid position
+  // within the graph.
+  Status CheckCommunicatingInstruction(HloInstruction* instruction) const;
+
   // Retrieves a pointer to the stored TrackedInstruction associated with a
   // tracked computation, or nullptr in case such computation is not tracked.
   const TrackedInstruction* GetTrackedInstruction(