case ComputationKind::kConditionalFalse:
repr += ":CONDITIONAL_FALSE";
break;
+ case ComputationKind::kCallFunction:
+ repr += ":CALL";
+ break;
}
return repr;
}
TrackedInstruction(hlo, ComputationKind::kConditionalTrue);
tracked_instructions_[hlo->false_computation()] =
TrackedInstruction(hlo, ComputationKind::kConditionalFalse);
+ } else if (hlo->opcode() == HloOpcode::kCall) {
+ tracked_instructions_[hlo->to_apply()] =
+ TrackedInstruction(hlo, ComputationKind::kCallFunction);
}
if (!IsChannelInstruction(hlo)) {
return Status::OK();
Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
HloInstruction* instruction2) {
TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile ||
- instruction1->opcode() == HloOpcode::kConditional);
+ instruction1->opcode() == HloOpcode::kConditional ||
+ instruction1->opcode() == HloOpcode::kCall);
VLOG(2) << "adding as companions:" << instruction1->ToString() << " and "
<< instruction2->ToString();
}
}
- // Check if channel instructions are used only in allowed computations.
- const auto allowed = [this](HloInstruction* hlo) {
- HloComputation* computation = hlo->parent();
- const HloModule* module = computation->parent();
- if (module->entry_computation() == computation ||
- tracked_instructions_.count(computation) > 0) {
- return true;
- }
- return false;
- };
for (const Channel& channel : channels_) {
- if (!allowed(channel.send) || !allowed(channel.send_done) ||
- !allowed(channel.recv) || !allowed(channel.recv_done)) {
- return FailedPrecondition("channel is used in disallowed computation");
- }
+ TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send));
+ TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send_done));
+ TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv));
+ TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv_done));
}
// Check if the nest levels match for each channel.
for (const Channel& channel : channels_) {
return Status::OK();
}
+Status HloModuleGroupMetadata::CheckCommunicatingInstruction(
+ HloInstruction* instruction) const {
+ HloComputation* computation = instruction->parent();
+ const HloModule* module = computation->parent();
+ if (module->entry_computation() == computation ||
+ tracked_instructions_.count(computation) > 0) {
+ return Status::OK();
+ }
+ return FailedPrecondition("channel is used in disallowed computation");
+}
+
} // namespace xla
kWhileBody,
kConditionalTrue,
kConditionalFalse,
+ kCallFunction,
};
// Tracks the instruction mapped to a given computation, and the computation
Status AddCompanion(HloInstruction* instruction1,
HloInstruction* instruction2);
+ // Checks whether a communicating instruction is placed in a valid position
+ // within the graph.
+ Status CheckCommunicatingInstruction(HloInstruction* instruction) const;
+
// Retrieves a pointer to the stored TrackedInstruction associated with a
// tracked computation, or nullptr in case such computation is not tracked.
const TrackedInstruction* GetTrackedInstruction(