)
cc_library(
+ name = "hlo_module_group_metadata",
+ srcs = ["hlo_module_group_metadata.cc"],
+ hdrs = ["hlo_module_group_metadata.h"],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "hlo_module_group_util",
+ srcs = ["hlo_module_group_util.cc"],
+ hdrs = ["hlo_module_group_util.h"],
+ deps = [
+ ":hlo",
+ ":hlo_module_group_metadata",
+ ":hlo_reachability",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "hlo_scheduling",
srcs = ["hlo_scheduling.cc"],
hdrs = ["hlo_scheduling.h"],
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
+
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
+ string repr =
+ (instruction_ != nullptr) ? instruction_->ToShortString() : "NULL";
+ switch (kind_) {
+ case ComputationKind::kInvalid:
+ repr += ":INVALID";
+ break;
+ case ComputationKind::kWhileCondition:
+ repr += ":WHILE_CONDITION";
+ break;
+ case ComputationKind::kWhileBody:
+ repr += ":WHILE_BODY";
+ break;
+ case ComputationKind::kConditionalTrue:
+ repr += ":CONDITIONAL_TRUE";
+ break;
+ case ComputationKind::kConditionalFalse:
+ repr += ":CONDITIONAL_FALSE";
+ break;
+ }
+ return repr;
+}
+
+/* static */ StatusOr<std::unique_ptr<HloModuleGroupMetadata>>
+HloModuleGroupMetadata::Build(const std::vector<HloModule*>& modules) {
+ auto metadata = MakeUnique<HloModuleGroupMetadata>(modules);
+ TF_RETURN_IF_ERROR(metadata->Build());
+ return std::move(metadata);
+}
+
+Status HloModuleGroupMetadata::Build() {
+ TF_RETURN_IF_ERROR(RecordInstructions());
+ TF_RETURN_IF_ERROR(VerifyChannelInstructions());
+
+ // Record all companion instructions.
+ const auto visitor = [this](HloInstruction* hlo) -> Status {
+ // We only need to process the instruction if it is within the computation
+ // of a tracked instruction, e.g. the condition or body computation of a
+ // While.
+ const TrackedInstruction* tracked = GetTrackedInstruction(hlo->parent());
+ if (tracked == nullptr) {
+ return Status::OK();
+ }
+ // Add the instructions tracking the parent computation of this channel
+ // instruction and its peer computation (both must be tracked, e.g. While
+ // bodies) as companions.
+ if (IsChannelInstruction(hlo)) {
+ HloComputation* peer_computation = PeerComputation(hlo);
+ const TrackedInstruction* peer_tracked =
+ GetTrackedInstruction(peer_computation);
+ TF_RET_CHECK(peer_tracked != nullptr)
+ << "Peer instruction is not a possible companion";
+ TF_RET_CHECK(*tracked == *peer_tracked)
+ << "Peer instruction does not match the computation kind";
+ TF_RETURN_IF_ERROR(
+ AddCompanion(tracked->instruction(), peer_tracked->instruction()));
+ }
+
+ // Add the tracked parents of companion instructions (they must all be
+ // instructions of the same opcode and kind) as companions.
+ if (IsCompanionInstruction(hlo)) {
+ for (HloInstruction* companion : Companions(hlo)) {
+ const TrackedInstruction* companion_tracked =
+ GetTrackedInstruction(companion->parent());
+ TF_RET_CHECK(companion_tracked != nullptr);
+ TF_RET_CHECK(*tracked == *companion_tracked);
+ TF_RETURN_IF_ERROR(AddCompanion(tracked->instruction(),
+ companion_tracked->instruction()));
+ }
+ }
+ return Status::OK();
+ };
+
+ // Visit the computations in postorder so that the companion information grows
+ // from inner computations to outer ones.
+ for (HloModule* module : modules_) {
+ for (HloComputation* computation : module->MakeComputationPostOrder()) {
+ TF_RETURN_IF_ERROR(computation->Accept(visitor));
+ }
+ }
+ return Status::OK();
+}
+
+bool HloModuleGroupMetadata::IsChannelInstruction(
+ const HloInstruction* instruction) const {
+ switch (instruction->opcode()) {
+ case HloOpcode::kSend:
+ case HloOpcode::kRecv:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kRecvDone:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const {
+ return companion_set_index_.count(hlo) > 0;
+}
+
+bool HloModuleGroupMetadata::InstructionCommunicates(
+ HloInstruction* hlo) const {
+ return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo);
+}
+
+const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
+ int64 channel_id) const {
+ CHECK(channel_id_map_.find(channel_id) != channel_id_map_.end());
+ return channels_[channel_id_map_.at(channel_id)];
+}
+
+HloComputation* HloModuleGroupMetadata::PeerComputation(
+ const HloInstruction* instruction) const {
+ CHECK(IsChannelInstruction(instruction));
+ const Channel& channel = GetChannel(instruction->channel_id());
+ switch (instruction->opcode()) {
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ return channel.recv->parent();
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ return channel.send->parent();
+ default:
+ LOG(FATAL) << "opcode not supported";
+ }
+}
+
+std::vector<HloModuleGroupMetadata::TrackedInstruction>
+HloModuleGroupMetadata::GetCompanionsPath(const HloInstruction* hlo) const {
+ std::vector<TrackedInstruction> path;
+ const HloComputation* parent = hlo->parent();
+ const TrackedInstruction* companion;
+ while ((companion = GetTrackedInstruction(parent)) != nullptr) {
+ parent = companion->instruction()->parent();
+ path.push_back(*companion);
+ }
+ return path;
+}
+
+bool HloModuleGroupMetadata::CheckCompanionPathsCompatibility(
+ const std::vector<TrackedInstruction>& path0,
+ const std::vector<TrackedInstruction>& path1) const {
+ if (path0.size() != path1.size()) {
+ VLOG(5) << "Companion path size do not match: " << path0.size()
+ << " != " << path1.size();
+ return false;
+ }
+ for (int64 i = 0; i < path0.size(); ++i) {
+ if (path0[i] != path1[i]) {
+ VLOG(5) << "Companion instructions at path index " << i
+ << " do not have the same opcode: " << path0[i].ToString()
+ << " vs " << path1[i].ToString();
+ return false;
+ }
+ }
+ return true;
+}
+
+int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const {
+ for (int64 i = 0; i < modules_.size(); ++i) {
+ if (modules_[i] == module) {
+ return i;
+ }
+ }
+ LOG(FATAL) << "unknown module";
+}
+
+Status HloModuleGroupMetadata::RecordInstructions() {
+ const auto visitor = [this](HloInstruction* hlo) -> Status {
+ if (hlo->opcode() == HloOpcode::kWhile) {
+ tracked_instructions_[hlo->while_condition()] =
+ TrackedInstruction(hlo, ComputationKind::kWhileCondition);
+ tracked_instructions_[hlo->while_body()] =
+ TrackedInstruction(hlo, ComputationKind::kWhileBody);
+ } else if (hlo->opcode() == HloOpcode::kConditional) {
+ tracked_instructions_[hlo->true_computation()] =
+ TrackedInstruction(hlo, ComputationKind::kConditionalTrue);
+ tracked_instructions_[hlo->false_computation()] =
+ TrackedInstruction(hlo, ComputationKind::kConditionalFalse);
+ }
+ if (!IsChannelInstruction(hlo)) {
+ return Status::OK();
+ }
+
+ // Add a new channel if needed.
+ if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) {
+ channels_.emplace_back();
+ channels_.back().id = hlo->channel_id();
+ channel_id_map_[hlo->channel_id()] = channels_.size() - 1;
+ }
+ Channel& channel = channels_[channel_id_map_[hlo->channel_id()]];
+
+ if (hlo->opcode() == HloOpcode::kSend) {
+ TF_RET_CHECK(channel.send == nullptr)
+ << "channel id " << hlo->channel_id()
+ << " is used by multiple send instructions";
+ channel.send = hlo;
+ }
+ if (hlo->opcode() == HloOpcode::kRecv) {
+ TF_RET_CHECK(channel.recv == nullptr)
+ << "channel id " << hlo->channel_id()
+ << " is used by multiple recv instructions";
+ channel.recv = hlo;
+ }
+ if (hlo->opcode() == HloOpcode::kSendDone) {
+ TF_RET_CHECK(channel.send_done == nullptr)
+ << "channel id " << hlo->channel_id()
+ << " is used by multiple send-done instructions";
+ channel.send_done = hlo;
+ }
+ if (hlo->opcode() == HloOpcode::kRecvDone) {
+ TF_RET_CHECK(channel.recv_done == nullptr)
+ << "channel id " << hlo->channel_id()
+ << " is used by multiple recv-done instructions";
+ channel.recv_done = hlo;
+ }
+ return Status::OK();
+ };
+
+ for (HloModule* module : modules_) {
+ for (auto* computation : module->computations()) {
+ TF_RETURN_IF_ERROR(computation->Accept(visitor));
+ }
+ }
+ return Status::OK();
+}
+
+Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
+ HloInstruction* instruction2) {
+ TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile ||
+ instruction1->opcode() == HloOpcode::kConditional);
+ VLOG(2) << "adding as companions:" << instruction1->ToString() << " and "
+ << instruction2->ToString();
+
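+ // Instructions are grouped into disjoint companion sets, stored in
+ // companion_sets_ and indexed through companion_set_index_. Depending on
+ // whether either instruction already belongs to a set, we create a new set,
+ // extend an existing one, or merge two existing ones.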
+ if (!ContainsKey(companion_set_index_, instruction1) &&
+ !ContainsKey(companion_set_index_, instruction2)) {
+ companion_sets_.push_back(
+ MakeUnique<std::unordered_set<HloInstruction*>>());
+ auto companion_set = companion_sets_.back().get();
+ companion_set->insert(instruction1);
+ companion_set->insert(instruction2);
+ companion_set_index_[instruction1] = companion_sets_.size() - 1;
+ companion_set_index_[instruction2] = companion_sets_.size() - 1;
+ } else if (!ContainsKey(companion_set_index_, instruction1)) {
+ companion_sets_[companion_set_index_[instruction2]]->insert(instruction1);
+ companion_set_index_[instruction1] = companion_set_index_[instruction2];
+ } else if (!ContainsKey(companion_set_index_, instruction2)) {
+ companion_sets_[companion_set_index_[instruction1]]->insert(instruction2);
+ companion_set_index_[instruction2] = companion_set_index_[instruction1];
+ } else if (companion_set_index_[instruction1] !=
+ companion_set_index_[instruction2]) {
+ // Merge the companion set of instruction2 into that of instruction1 and
+ // erase the vacated set. Indices stored in companion_set_index_ that
+ // point past the erased slot are shifted down so they remain valid.
+ int64 target_index = companion_set_index_[instruction1];
+ int64 index_to_remove = companion_set_index_[instruction2];
+ companion_sets_[target_index]->insert(Companions(instruction2).begin(),
+ Companions(instruction2).end());
+ for (HloInstruction* hlo : Companions(instruction2)) {
+ companion_set_index_[hlo] = target_index;
+ }
+ companion_sets_.erase(companion_sets_.begin() + index_to_remove);
+ for (auto& index : companion_set_index_) {
+ if (index.second > index_to_remove) {
+ --index.second;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status HloModuleGroupMetadata::VerifyChannelInstructions() {
+ for (const Channel& channel : channels_) {
+ if (channel.send == nullptr) {
+ return FailedPrecondition("missing send for id : %lld", channel.id);
+ }
+ if (channel.recv == nullptr) {
+ return FailedPrecondition("missing recv for id : %lld", channel.id);
+ }
+ if (channel.send_done == nullptr) {
+ return FailedPrecondition("missing send-done for id : %lld", channel.id);
+ }
+ if (channel.recv_done == nullptr) {
+ return FailedPrecondition("missing recv-done for id : %lld", channel.id);
+ }
+ }
+
+ // Check if the shapes match for each channel.
+ for (const Channel& channel : channels_) {
+ const Shape& send_shape = channel.send->operand(0)->shape();
+ const Shape& recv_shape = channel.recv_done->shape();
+ if (!ShapeUtil::Compatible(send_shape, recv_shape)) {
+ return FailedPrecondition("send/recv shapes do not match");
+ }
+ }
+
+ // Check if channel instructions are used only in allowed computations.
+ const auto allowed = [this](HloInstruction* hlo) {
+ HloComputation* computation = hlo->parent();
+ const HloModule* module = computation->parent();
+ if (module->entry_computation() == computation ||
+ tracked_instructions_.count(computation) > 0) {
+ return true;
+ }
+ return false;
+ };
+ for (const Channel& channel : channels_) {
+ if (!allowed(channel.send) || !allowed(channel.send_done) ||
+ !allowed(channel.recv) || !allowed(channel.recv_done)) {
+ return FailedPrecondition("channel is used in disallowed computation");
+ }
+ }
+ // Check if the nest levels match for each channel.
+ for (const Channel& channel : channels_) {
+ std::vector<TrackedInstruction> path = GetCompanionsPath(channel.send);
+ if (!CheckCompanionPathsCompatibility(
+ path, GetCompanionsPath(channel.send_done)) ||
+ !CheckCompanionPathsCompatibility(path,
+ GetCompanionsPath(channel.recv)) ||
+ !CheckCompanionPathsCompatibility(
+ path, GetCompanionsPath(channel.recv_done))) {
+ return FailedPrecondition(
+ "Nest companion paths do not match for channel %lld", channel.id);
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_
+
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Class for bookkeeping information about the given modules, in particular
+// about the interaction between their computations.
+//
+// Companion instructions are one piece of information collected as we build
+// the metadata. For example, for each While instruction, its companion
+// instructions are the While instructions in other modules that communicate
+// with it.
+// In the example below with 3 modules, {While_0, While_2, While_5}, {While_1,
+// While_4}, {While_3, While_6} are companion sets.
+//
+// <Module 0> <Module 1> <Module 2>
+// While_0() { While_2() { While_5() {
+// While_1() { Send(0) } While_3() { Send(1) } While_6() { Recv(1) }
+// } While_4() { Recv(0) }
+// }
+//
+// Companion instructions are used to detect cycles in the graph and also for
+// global scheduling.
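+//
+// A minimal usage sketch (illustrative only; `modules` is assumed to be a
+// std::vector<HloModule*> owned by the caller and `hlo` an instruction from
+// one of those modules):
+//
+//   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleGroupMetadata> metadata,
+//                       HloModuleGroupMetadata::Build(modules));
+//   if (metadata->IsChannelInstruction(hlo)) {
+//     HloComputation* peer = metadata->PeerComputation(hlo);
+//     ...
+//   }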
+class HloModuleGroupMetadata {
+ public:
+ // The kind of companion computation a given instruction can be within.
+ enum class ComputationKind {
+ kInvalid,
+ kWhileCondition,
+ kWhileBody,
+ kConditionalTrue,
+ kConditionalFalse,
+ };
+
+ // Tracks the instruction mapped to a given computation, and the computation
+ // kind.
+ // For example, the body computation of a While instruction generates a
+ // TrackedInstruction whose instruction is the While instruction and whose
+ // kind is ComputationKind::kWhileBody.
+ class TrackedInstruction {
+ public:
+ TrackedInstruction() = default;
+ TrackedInstruction(HloInstruction* instruction, ComputationKind kind)
+ : instruction_(instruction), kind_(kind) {}
+
+ bool operator==(const TrackedInstruction& rhs) const {
+ return instruction_->opcode() == rhs.instruction_->opcode() &&
+ kind_ == rhs.kind_;
+ }
+ bool operator!=(const TrackedInstruction& rhs) const {
+ return !operator==(rhs);
+ }
+
+ HloInstruction* instruction() const { return instruction_; }
+
+ string ToString() const;
+
+ private:
+ HloInstruction* instruction_ = nullptr;
+ ComputationKind kind_ = ComputationKind::kInvalid;
+ };
+
+ // Represents a channel and the 4 instructions that form the channel.
+ struct Channel {
+ int64 id = -1;
+ HloInstruction* send = nullptr;
+ HloInstruction* recv = nullptr;
+ HloInstruction* send_done = nullptr;
+ HloInstruction* recv_done = nullptr;
+ };
+
+ explicit HloModuleGroupMetadata(const std::vector<HloModule*>& modules)
+ : modules_(modules) {}
+
+ ~HloModuleGroupMetadata() = default;
+
+ // Build and return the metadata for the given modules.
+ static StatusOr<std::unique_ptr<HloModuleGroupMetadata>> Build(
+ const std::vector<HloModule*>& modules);
+
+ // Returns true if the instruction is one of the 4 channel instructions (Send,
+ // Recv, SendDone, RecvDone).
+ bool IsChannelInstruction(const HloInstruction* instruction) const;
+
+ // Returns true if the instruction is a companion instruction. See the class
+ // comment above on companion instructions.
+ bool IsCompanionInstruction(HloInstruction* hlo) const;
+
+ // Returns true if the instruction is either a channel instruction or a
+ // companion instruction.
+ bool InstructionCommunicates(HloInstruction* hlo) const;
+
+ // Returns the Channel instance for the given channel id.
+ const Channel& GetChannel(int64 channel_id) const;
+
+ // Returns the computation that contains the peer channel instructions for
+ // the given instruction.
+ //
+ // Precondition: IsChannelInstruction(instruction) is true.
+ HloComputation* PeerComputation(const HloInstruction* instruction) const;
+
+ // Returns the path of nested tracked instructions (i.e., the HLO
+ // instructions whose computations enclose the given instruction), ordered
+ // from inner to outer. The returned path does not include the input hlo
+ // instruction itself, even if it is a companion instruction.
+ std::vector<TrackedInstruction> GetCompanionsPath(
+ const HloInstruction* hlo) const;
+
+ // Checks whether two companion paths (as returned by the GetCompanionsPath()
+ // API) are compatible. Two paths are compatible if they have the same length
+ // and their sequences of opcodes and companion kinds match.
+ bool CheckCompanionPathsCompatibility(
+ const std::vector<TrackedInstruction>& path0,
+ const std::vector<TrackedInstruction>& path1) const;
+
+ // Returns the unique id assigned to the given module, i.e. its index in the
+ // module vector.
+ int64 GetModuleId(const HloModule* module) const;
+
+ // Returns the companion instructions for the given instruction.
+ //
+ // Precondition: IsCompanionInstruction(instruction) is true.
+ const std::unordered_set<HloInstruction*>& Companions(
+ HloInstruction* instruction) const {
+ CHECK_EQ(companion_set_index_.count(instruction), 1);
+ return companion_set(companion_set_index_.at(instruction));
+ }
+
+ // Returns the companion set at the given index.
+ const std::unordered_set<HloInstruction*>& companion_set(int64 index) const {
+ CHECK_LT(index, companion_sets_.size());
+ return *companion_sets_[index];
+ }
+
+ // Returns the companion set index of the given instruction.
+ int64 companion_set_index(HloInstruction* instruction) const {
+ return companion_set_index_.at(instruction);
+ }
+
+ // Returns the list of all companion sets in the HLO module group.
+ const std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>&
+ companion_sets() const {
+ return companion_sets_;
+ }
+
+ private:
+ Status Build();
+
+ // Record all channel instructions, as well as the While and Conditional
+ // instructions whose computations must be tracked.
+ Status RecordInstructions();
+
+ // Verifies that the given HloModules are well-formed and follow the
+ // specification, in particular with respect to the use of channel
+ // instructions:
+ //
+ // * Each channel has all 4 instructions (Send, Recv, SendDone, RecvDone).
+ // * The shapes of a channel's send and receive instructions are compatible.
+ // * The nest levels of a channel's instructions match.
+ // * Channel instructions are used in allowed computations; i.e., in the
+ // entry computation of the module or condition/body of While computations.
+ //
+ // TODO(b/62064342): Currently, HloModuleGroupScheduler checks if there is a
+ // cycle in the graph, but it would be good to verify here.
+ Status VerifyChannelInstructions();
+
+ // Adds metadata that the given two instructions are companions.
+ Status AddCompanion(HloInstruction* instruction1,
+ HloInstruction* instruction2);
+
+ // Retrieves a pointer to the stored TrackedInstruction associated with a
+ // tracked computation, or nullptr in case such computation is not tracked.
+ const TrackedInstruction* GetTrackedInstruction(
+ const HloComputation* computation) const {
+ auto it = tracked_instructions_.find(computation);
+ return it != tracked_instructions_.end() ? &it->second : nullptr;
+ }
+
+ // List of all companion instructions sets in the module.
+ std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>
+ companion_sets_;
+
+ // Map from each companion instruction to its index into companion_sets_.
+ tensorflow::gtl::FlatMap<HloInstruction*, int64> companion_set_index_;
+
+ // Map from computation to the instruction using it (a kWhile or
+ // kConditional).
+ tensorflow::gtl::FlatMap<const HloComputation*, TrackedInstruction>
+ tracked_instructions_;
+
+ // All channels in the module.
+ std::vector<Channel> channels_;
+
+ // Map from channel ids to the index in channels_.
+ tensorflow::gtl::FlatMap<int64, int64> channel_id_map_;
+
+ // The modules that this metadata was built from.
+ const std::vector<HloModule*>& modules_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group_util.h"
+
+#include <algorithm>
+#include <list>
+#include <queue>
+#include <stack>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_reachability.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
+ HloInstruction* instruction) {
+ std::vector<HloInstruction*> predecessors;
+
+ // Adds the given predecessor to the unique predecessors list; if the
+ // predecessor is a companion instruction, adds all of its companions
+ // instead.
+ auto add_unique_predecessor = [&](HloInstruction* predecessor) {
+ if (std::find(predecessors.begin(), predecessors.end(), predecessor) !=
+ predecessors.end()) {
+ return;
+ }
+ if (!metadata_.IsCompanionInstruction(predecessor)) {
+ predecessors.push_back(predecessor);
+ return;
+ }
+ for (HloInstruction* companion : metadata_.Companions(predecessor)) {
+ predecessors.push_back(companion);
+ }
+ };
+
+ // If the given instruction is a companion instruction, we need to find the
+ // predecessors of all of its companion instructions.
+ std::vector<HloInstruction*> instruction_group;
+ if (metadata_.IsCompanionInstruction(instruction)) {
+ for (HloInstruction* companion : metadata_.Companions(instruction)) {
+ instruction_group.push_back(companion);
+ }
+ } else {
+ instruction_group.push_back(instruction);
+ }
+
+ for (HloInstruction* hlo : instruction_group) {
+ for (HloInstruction* operand : hlo->operands()) {
+ add_unique_predecessor(operand);
+ }
+ for (HloInstruction* control_predecessor : hlo->control_predecessors()) {
+ add_unique_predecessor(control_predecessor);
+ }
+ }
+ if (instruction->opcode() == HloOpcode::kRecvDone) {
+ // Send is a remote predecessor of RecvDone.
+ HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
+ add_unique_predecessor(send);
+ }
+ if (instruction->opcode() == HloOpcode::kSend) {
+ // Recv is a remote predecessor of Send.
+ HloInstruction* recv_done =
+ metadata_.GetChannel(instruction->channel_id()).recv_done;
+ CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
+ CHECK_EQ(recv_done->operand_count(), 1);
+ HloInstruction* recv = recv_done->mutable_operand(0);
+ add_unique_predecessor(recv);
+ }
+ return predecessors;
+}
+
+std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
+ HloInstruction* instruction) {
+ std::vector<HloInstruction*> successors;
+
+ // Adds the given successor to the unique successors list; if the successor
+ // is a companion instruction, adds all of its companions instead.
+ auto add_unique_successor = [&](HloInstruction* successor) {
+ if (std::find(successors.begin(), successors.end(), successor) !=
+ successors.end()) {
+ return;
+ }
+ if (!metadata_.IsCompanionInstruction(successor)) {
+ successors.push_back(successor);
+ return;
+ }
+ for (HloInstruction* companion : metadata_.Companions(successor)) {
+ successors.push_back(companion);
+ }
+ };
+
+ // If the given instruction is a companion instruction, we need to find the
+ // successors of all of its companion instructions.
+ std::vector<HloInstruction*> instruction_group;
+ if (metadata_.IsCompanionInstruction(instruction)) {
+ for (HloInstruction* companion : metadata_.Companions(instruction)) {
+ instruction_group.push_back(companion);
+ }
+ } else {
+ instruction_group.push_back(instruction);
+ }
+
+ for (HloInstruction* hlo : instruction_group) {
+ for (HloInstruction* user : hlo->users()) {
+ add_unique_successor(user);
+ }
+ for (HloInstruction* control_successor : hlo->control_successors()) {
+ add_unique_successor(control_successor);
+ }
+ }
+ if (instruction->opcode() == HloOpcode::kRecv) {
+ // Send is a remote successor of Recv.
+ const HloInstruction* recv_done = instruction->users().front();
+ CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
+ HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
+ add_unique_successor(send);
+ }
+ if (instruction->opcode() == HloOpcode::kSend) {
+ // RecvDone is a remote successor of Send.
+ HloInstruction* recv_done =
+ metadata_.GetChannel(instruction->channel_id()).recv_done;
+ add_unique_successor(recv_done);
+ }
+ return successors;
+}
+
+std::vector<HloInstruction*> HloModuleGroupUtil::RootInstructions(
+ tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ std::vector<HloInstruction*> roots;
+ for (HloComputation* computation : computations) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (GlobalSuccessors(instruction).empty()) {
+ roots.push_back(instruction);
+ }
+ }
+ }
+ return roots;
+}
+
+Status HloModuleGroupUtil::VisitTopologicalOrder(
+ VisitStates* visit_state, const VisitFunction& visit_function,
+ HloInstruction* root) {
+ // Stack of HLO instructions visited in DFS order.
+ std::stack<HloInstruction*> stack;
+ stack.push(root);
+
+ while (!stack.empty()) {
+ HloInstruction* hlo = stack.top();
+
+ // Find the instruction group of the currently visited instruction. The
+ // instruction group contains all companion instructions of the current
+ // instruction and is treated as a single entity for the purpose of the
+ // traversal (i.e., its members must always be in the same visit state).
+ std::vector<HloInstruction*> instruction_group;
+ if (metadata_.IsCompanionInstruction(hlo)) {
+ for (HloInstruction* companion : metadata_.Companions(hlo)) {
+ instruction_group.push_back(companion);
+ }
+ } else {
+ instruction_group.push_back(hlo);
+ }
+
+ if ((*visit_state)[hlo] == VisitState::kVisited) {
+ // All instructions in the group must be in the same state.
+ for (HloInstruction* instruction : instruction_group) {
+ TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisited);
+ }
+ stack.pop();
+ continue;
+ }
+
+ if ((*visit_state)[hlo] == VisitState::kVisiting) {
+ TF_RETURN_IF_ERROR(visit_function(hlo, instruction_group));
+
+ // Set the visit state of all instructions in the group to kVisited.
+ for (HloInstruction* instruction : instruction_group) {
+ TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisiting);
+ (*visit_state)[instruction] = VisitState::kVisited;
+ }
+ stack.pop();
+ continue;
+ }
+
+ // Set the visit state of all instructions in the group to kVisiting.
+ for (HloInstruction* instruction : instruction_group) {
+ TF_RET_CHECK((*visit_state)[instruction] == VisitState::kNotVisited)
+ << instruction->ToString();
+ (*visit_state)[instruction] = VisitState::kVisiting;
+ }
+
+ // For each instruction in the group, visit its predecessors (operands,
+ // control predecessors and remote predecessors).
+ for (HloInstruction* instruction : instruction_group) {
+ for (HloInstruction* predecessor : GlobalPredecessors(instruction)) {
+ // Visiting a node that is already being visited implies that there is
+ // a cycle. Generate an error listing the instructions currently being
+ // visited (a superset of the cycle).
+ if ((*visit_state)[predecessor] == VisitState::kVisiting) {
+ string cyclic_instructions;
+ for (const auto& state : *visit_state) {
+ if (state.second == VisitState::kVisiting) {
+ tensorflow::strings::StrAppend(&cyclic_instructions,
+ state.first->ToString(), "\n");
+ }
+ }
+ // TODO(b/64305524): Improve the error message to print out the
+ // instructions in a deterministic order that forms the cycle.
+ return FailedPrecondition(
+ "Cross-computation cycle detected via communicating nodes. The "
+ "cycle contains the node %s. The cycle is found among the "
+ "following nodes. Note that the order of the nodes is arbitrary "
+ "and that the list may include nodes that are not part of the "
+ "cycle.\n%s",
+ predecessor->ToString().c_str(), cyclic_instructions.c_str());
+ }
+ stack.push(predecessor);
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+Status HloModuleGroupUtil::VerifyComputations(
+ tensorflow::gtl::ArraySlice<HloComputation*> computations) {
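+ // The visit function is a no-op: the traversal in VisitTopologicalOrder()
+ // performs the actual verification (in particular, cycle detection).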
+ auto visit_function =
+ [&](HloInstruction* instruction,
+ const std::vector<HloInstruction*>& instruction_group) {
+ return Status::OK();
+ };
+ int64 instructions_count = 0;
+ VisitStates visit_states;
+ for (HloComputation* computation : computations) {
+ // Visit all instructions, not just those reachable from the root
+ // instruction of the computation. This lets us detect dead cycles (i.e.,
+ // cycles not reachable from the root) and enforce an ordering for
+ // communication instructions that are not reachable from any root.
+ for (HloInstruction* instruction : computation->instructions()) {
+ TF_RETURN_IF_ERROR(
+ VisitTopologicalOrder(&visit_states, visit_function, instruction));
+ }
+ instructions_count += computation->instruction_count();
+ }
+
+ // Check that every instruction has been visited and is in the kVisited
+ // state.
+ TF_RET_CHECK(visit_states.size() == instructions_count);
+ for (auto& state : visit_states) {
+ TF_RET_CHECK(state.second == VisitState::kVisited);
+ }
+
+ return Status::OK();
+}
+
+StatusOr<std::unique_ptr<HloReachabilityMap>>
+HloModuleGroupUtil::ComputeReachability(
+ tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ std::list<HloInstruction*> post_order;
+ auto visit_function =
+ [&](HloInstruction* instruction,
+ const std::vector<HloInstruction*>& instruction_group) {
+ post_order.insert(post_order.end(), instruction_group.begin(),
+ instruction_group.end());
+ return Status::OK();
+ };
+ HloModuleGroupUtil::VisitStates visit_states;
+ for (HloInstruction* root : RootInstructions(computations)) {
+ TF_RETURN_IF_ERROR(
+ VisitTopologicalOrder(&visit_states, visit_function, root));
+ }
+ auto reachability = MakeUnique<HloReachabilityMap>(post_order);
+ for (HloInstruction* hlo : post_order) {
+ reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo);
+ }
+ return std::move(reachability);
+}
+
+void HloModuleGroupUtil::UpdateReachabilityThroughInstruction(
+ HloInstruction* instruction, HloReachabilityMap* reachability_map) {
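+ // Propagate reachability changes forward: recompute the reachability of the
+ // given instruction and, whenever an instruction's reachability set changes,
+ // reprocess its global successors.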
+ std::queue<HloInstruction*> worklist;
+ worklist.push(instruction);
+
+ while (!worklist.empty()) {
+ HloInstruction* item = worklist.front();
+ worklist.pop();
+ if (reachability_map->SetReachabilityToUnion(GlobalPredecessors(item),
+ item)) {
+ for (HloInstruction* successor : GlobalSuccessors(item)) {
+ worklist.push(successor);
+ }
+ }
+ }
+}
+
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
+#include "tensorflow/compiler/xla/service/hlo_reachability.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace xla {
+
+// Collection of utilities for handling HloModuleGroups.
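+//
+// A minimal usage sketch (illustrative only; `metadata` is assumed to be a
+// previously built HloModuleGroupMetadata and `computations` the computations
+// of the module group):
+//
+//   HloModuleGroupUtil util(metadata);
+//   TF_RETURN_IF_ERROR(util.VerifyComputations(computations));
+//   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloReachabilityMap> reachability,
+//                       util.ComputeReachability(computations));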
+class HloModuleGroupUtil {
+ public:
+ explicit HloModuleGroupUtil(const HloModuleGroupMetadata& metadata)
+ : metadata_(metadata) {}
+
+ // Returns all unique predecessors of the instruction. This includes:
+ // * predecessors in the same computation: operands and control predecessors
+ // * Recv is a predecessor of Send
+ // * Send is a predecessor of RecvDone
+ // * predecessors of companions (if the instruction is a companion while)
+ // * predecessors' companions (for any predecessor that is a companion while)
+ std::vector<HloInstruction*> GlobalPredecessors(HloInstruction* instruction);
+
+ // Returns all unique successors of the instruction. This includes:
+ // * successors in the same computation: users and control successors
+ // * Send is a successor of Recv
+ // * RecvDone is a successor of Send
+ // * successors of companions (if the instruction is a companion while)
+ // * successors' companions (for any successor that is a companion while)
+ std::vector<HloInstruction*> GlobalSuccessors(HloInstruction* instruction);
+
+ // Returns the instructions in the given computations that have no global
+ // successors, i.e., the roots of the cross-computation graph.
+ std::vector<HloInstruction*> RootInstructions(
+ tensorflow::gtl::ArraySlice<HloComputation*> computations);
+
+ // Visit state of each instruction during DFS traversal.
+ enum VisitState {
+ kNotVisited = 0,
+ kVisiting,
+ kVisited,
+ };
+
+ // Function called on each instruction group during the DFS traversal. See
+ // the comment on VisitTopologicalOrder().
+ using VisitFunction = std::function<Status(
+ HloInstruction* hlo,
+ const std::vector<HloInstruction*>& instruction_group)>;
+
+ // Given an HLO instruction as the root, recursively visits its predecessor
+ // instructions in DFS order, so that nodes are visited in topological order.
+ //
+ // Note that the DFS traversal does not only visit nodes in the same
+ // computation (the parent of the root instruction), but also nodes in
+ // different computations connected via communication instructions. During
+ // the traversal, companion instructions (see the class comment in
+ // HloModuleGroupMetadata) are treated as a single unit, called an
+ // instruction group, which contains only a single instruction if the
+ // visited node is not a companion instruction. Visiting one instruction in
+ // the group effectively visits all other instructions in the group, after
+ // which all predecessor instructions of the group are visited.
+ //
+ // * visit_state: map from each instruction to its visit state.
+ // * visit_function: function called on each instruction group.
+ // * root: the root instruction of the traversal.
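+ //
+ // A sketch of a typical call (illustrative only; `util` and `root` are
+ // assumed to exist in the surrounding code):
+ //
+ //   HloModuleGroupUtil::VisitStates visit_states;
+ //   auto visit_function = [](HloInstruction* hlo,
+ //                            const std::vector<HloInstruction*>& group) {
+ //     VLOG(3) << "visiting " << hlo->ToString();
+ //     return Status::OK();
+ //   };
+ //   TF_RETURN_IF_ERROR(
+ //       util.VisitTopologicalOrder(&visit_states, visit_function, root));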
+ using VisitStates = tensorflow::gtl::FlatMap<HloInstruction*, VisitState>;
+ Status VisitTopologicalOrder(VisitStates* visit_state,
+ const VisitFunction& visit_function,
+ HloInstruction* root);
+
+ // Verifies that the computations are well-formed (e.g., no cycles).
+ Status VerifyComputations(
+ tensorflow::gtl::ArraySlice<HloComputation*> computations);
+
+ // The reachability utilities below resemble those in HloComputation, except
+ // that they can handle instructions across multiple computations.
+ //
+ // Creates the reachability map for the instructions in the computations.
+ StatusOr<std::unique_ptr<HloReachabilityMap>> ComputeReachability(
+ tensorflow::gtl::ArraySlice<HloComputation*> computations);
+
+ // Updates the reachability of the given instruction, taking its global
+ // predecessors and successors into account.
+ void UpdateReachabilityThroughInstruction(
+ HloInstruction* instruction, HloReachabilityMap* reachability_map);
+
+ private:
+ const HloModuleGroupMetadata& metadata_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_