Add HloModuleGroupMetadata and HloModuleGroupUtil
authorHyoukJoong Lee <hyouklee@google.com>
Tue, 6 Mar 2018 18:24:45 +0000 (10:24 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 18:29:47 +0000 (10:29 -0800)
PiperOrigin-RevId: 188041608

tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/hlo_module_group_metadata.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/hlo_module_group_metadata.h [new file with mode: 0644]
tensorflow/compiler/xla/service/hlo_module_group_util.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/hlo_module_group_util.h [new file with mode: 0644]

index 3eecc46..611b183 100644 (file)
@@ -1066,6 +1066,38 @@ tf_cc_test(
 )
 
 cc_library(
+    name = "hlo_module_group_metadata",
+    srcs = ["hlo_module_group_metadata.cc"],
+    hdrs = ["hlo_module_group_metadata.h"],
+    deps = [
+        ":hlo",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_library(
+    name = "hlo_module_group_util",
+    srcs = ["hlo_module_group_util.cc"],
+    hdrs = ["hlo_module_group_util.h"],
+    deps = [
+        ":hlo",
+        ":hlo_module_group_metadata",
+        ":hlo_reachability",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_library(
     name = "hlo_scheduling",
     srcs = ["hlo_scheduling.cc"],
     hdrs = ["hlo_scheduling.h"],
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
new file mode 100644 (file)
index 0000000..eed0112
--- /dev/null
@@ -0,0 +1,349 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
+
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
+  string repr =
+      (instruction_ != nullptr) ? instruction_->ToShortString() : "NULL";
+  switch (kind_) {
+    case ComputationKind::kInvalid:
+      repr += ":INVALID";
+      break;
+    case ComputationKind::kWhileCondition:
+      repr += ":WHILE_CONDITION";
+      break;
+    case ComputationKind::kWhileBody:
+      repr += ":WHILE_BODY";
+      break;
+    case ComputationKind::kConditionalTrue:
+      repr += ":CONDITIONAL_TRUE";
+      break;
+    case ComputationKind::kConditionalFalse:
+      repr += ":CONDITIONAL_FALSE";
+      break;
+  }
+  return repr;
+}
+
+// Factory: allocates the metadata object and runs the private Build() pass
+// over the given modules, returning an error if they are not well formed.
+/* static */ StatusOr<std::unique_ptr<HloModuleGroupMetadata>>
+HloModuleGroupMetadata::Build(const std::vector<HloModule*>& modules) {
+  // NOTE(review): uses absl::make_unique while only ptr_util.h is included;
+  // presumably pulled in transitively -- confirm the include set.
+  auto metadata = absl::make_unique<HloModuleGroupMetadata>(modules);
+  TF_RETURN_IF_ERROR(metadata->Build());
+  // std::move is required for the implicit unique_ptr -> StatusOr conversion.
+  return std::move(metadata);
+}
+
+// Builds the metadata in three phases: record channels and tracked
+// computations, verify channel well-formedness, then group the tracked
+// While/Conditional instructions that communicate into companion sets.
+Status HloModuleGroupMetadata::Build() {
+  TF_RETURN_IF_ERROR(RecordInstructions());
+  TF_RETURN_IF_ERROR(VerifyChannelInstructions());
+
+  // Record all companion instructions (the tracked While/Conditional callers
+  // of computations that contain channel or companion instructions).
+  const auto visitor = [this](HloInstruction* hlo) -> Status {
+    // We only need to process if the instruction is within the computation
+    // of a companion instruction, like in the condition or body computation
+    // of a While.
+    const TrackedInstruction* tracked = GetTrackedInstruction(hlo->parent());
+    if (tracked == nullptr) {
+      return Status::OK();
+    }
+    // Add the parent computation of this channel instruction and its peer
+    // computation (both must be while computations) as companions.
+    if (IsChannelInstruction(hlo)) {
+      HloComputation* peer_computation = PeerComputation(hlo);
+      const TrackedInstruction* peer_tracked =
+          GetTrackedInstruction(peer_computation);
+      TF_RET_CHECK(peer_tracked != nullptr)
+          << "Peer instruction is not a possible companion";
+      TF_RET_CHECK(*tracked == *peer_tracked)
+          << "Peer instruction does not match the computation kind";
+      TF_RETURN_IF_ERROR(
+          AddCompanion(tracked->instruction(), peer_tracked->instruction()));
+    }
+
+    // Add the parents of companion instructions (they must be all of the same
+    // kind of instructions, opcode wise) as companions.
+    if (IsCompanionInstruction(hlo)) {
+      for (HloInstruction* companion : Companions(hlo)) {
+        const TrackedInstruction* companion_tracked =
+            GetTrackedInstruction(companion->parent());
+        TF_RET_CHECK(companion_tracked != nullptr);
+        TF_RET_CHECK(*tracked == *companion_tracked);
+        TF_RETURN_IF_ERROR(AddCompanion(tracked->instruction(),
+                                        companion_tracked->instruction()));
+      }
+    }
+    return Status::OK();
+  };
+
+  // Visit the computations in postorder so that the companion information grows
+  // from inner computations to outer ones.
+  for (HloModule* module : modules_) {
+    for (HloComputation* computation : module->MakeComputationPostOrder()) {
+      TF_RETURN_IF_ERROR(computation->Accept(visitor));
+    }
+  }
+  return Status::OK();
+}
+
+// Returns true iff the instruction is one of the four channel ops
+// (Send, Recv, SendDone, RecvDone).
+bool HloModuleGroupMetadata::IsChannelInstruction(
+    const HloInstruction* instruction) const {
+  const HloOpcode opcode = instruction->opcode();
+  return opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv ||
+         opcode == HloOpcode::kSendDone || opcode == HloOpcode::kRecvDone;
+}
+
+// An instruction is a companion iff it has been assigned to a companion set.
+bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const {
+  return ContainsKey(companion_set_index_, hlo);
+}
+
+// An instruction "communicates" if it is a channel instruction or belongs
+// to a companion set.
+bool HloModuleGroupMetadata::InstructionCommunicates(
+    HloInstruction* hlo) const {
+  if (IsChannelInstruction(hlo)) {
+    return true;
+  }
+  return IsCompanionInstruction(hlo);
+}
+
+// Returns the Channel record for the given channel id. CHECK-fails if the id
+// was never recorded by RecordInstructions().
+const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
+    int64 channel_id) const {
+  // Single hash lookup: the original did find() followed by at(), paying for
+  // the same lookup twice; the iterator already holds the channels_ index.
+  auto it = channel_id_map_.find(channel_id);
+  CHECK(it != channel_id_map_.end());
+  return channels_[it->second];
+}
+
+// Returns the computation holding the other endpoint of the channel: for
+// Send/SendDone the Recv's parent, for Recv/RecvDone the Send's parent.
+// Precondition (CHECKed): instruction is a channel instruction.
+HloComputation* HloModuleGroupMetadata::PeerComputation(
+    const HloInstruction* instruction) const {
+  CHECK(IsChannelInstruction(instruction));
+  const Channel& channel = GetChannel(instruction->channel_id());
+  switch (instruction->opcode()) {
+    case HloOpcode::kSend:
+    case HloOpcode::kSendDone:
+      return channel.recv->parent();
+    case HloOpcode::kRecv:
+    case HloOpcode::kRecvDone:
+      return channel.send->parent();
+    default:
+      // Unreachable given the CHECK above; silences missing-return warnings.
+      LOG(FATAL) << "opcode not supported";
+  }
+}
+
+// Walks from hlo's enclosing computation outward, collecting the
+// TrackedInstruction of every tracked ancestor computation. The result is
+// ordered inner-to-outer and excludes hlo itself.
+std::vector<HloModuleGroupMetadata::TrackedInstruction>
+HloModuleGroupMetadata::GetCompanionsPath(const HloInstruction* hlo) const {
+  std::vector<TrackedInstruction> path;
+  const TrackedInstruction* tracked = GetTrackedInstruction(hlo->parent());
+  while (tracked != nullptr) {
+    path.push_back(*tracked);
+    // Continue from the computation that contains the tracked instruction.
+    tracked = GetTrackedInstruction(tracked->instruction()->parent());
+  }
+  return path;
+}
+
+// Two companion paths are compatible when they have the same length and the
+// TrackedInstructions at each level compare equal (same opcode and
+// computation kind, per TrackedInstruction::operator==).
+bool HloModuleGroupMetadata::CheckCompanionPathsCompatibility(
+    const std::vector<TrackedInstruction>& path0,
+    const std::vector<TrackedInstruction>& path1) const {
+  if (path0.size() != path1.size()) {
+    VLOG(5) << "Companion path size do not match: " << path0.size()
+            << " != " << path1.size();
+    return false;
+  }
+  for (int64 i = 0; i < path0.size(); ++i) {
+    if (path0[i] != path1[i]) {
+      VLOG(5) << "Companion instructions at path index " << i
+              << " do not have the same opcode: " << path0[i].ToString()
+              << " vs " << path1[i].ToString();
+      return false;
+    }
+  }
+  return true;
+}
+
+// Returns the index of the module in modules_; dies if the module is not a
+// member of this group. Linear scan: the module list is expected to be small.
+int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const {
+  int64 id = 0;
+  for (const HloModule* candidate : modules_) {
+    if (candidate == module) {
+      return id;
+    }
+    ++id;
+  }
+  LOG(FATAL) << "unknown module";
+}
+
+// Records, across all modules: (a) the computations used as While
+// condition/body and Conditional true/false branches, keyed in
+// tracked_instructions_, and (b) the four endpoint instructions of every
+// channel, rejecting duplicate uses of a channel id per role.
+Status HloModuleGroupMetadata::RecordInstructions() {
+  const auto visitor = [this](HloInstruction* hlo) -> Status {
+    if (hlo->opcode() == HloOpcode::kWhile) {
+      tracked_instructions_[hlo->while_condition()] =
+          TrackedInstruction(hlo, ComputationKind::kWhileCondition);
+      tracked_instructions_[hlo->while_body()] =
+          TrackedInstruction(hlo, ComputationKind::kWhileBody);
+    } else if (hlo->opcode() == HloOpcode::kConditional) {
+      tracked_instructions_[hlo->true_computation()] =
+          TrackedInstruction(hlo, ComputationKind::kConditionalTrue);
+      tracked_instructions_[hlo->false_computation()] =
+          TrackedInstruction(hlo, ComputationKind::kConditionalFalse);
+    }
+    if (!IsChannelInstruction(hlo)) {
+      return Status::OK();
+    }
+
+    // Add a new channel if needed. channel_id_map_ maps the channel id to the
+    // index of its Channel record in channels_.
+    if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) {
+      channels_.emplace_back();
+      channels_.back().id = hlo->channel_id();
+      channel_id_map_[hlo->channel_id()] = channels_.size() - 1;
+    }
+    Channel& channel = channels_[channel_id_map_[hlo->channel_id()]];
+
+    if (hlo->opcode() == HloOpcode::kSend) {
+      TF_RET_CHECK(channel.send == nullptr)
+          << "channel id " << hlo->channel_id()
+          << " is used by multiple send instructions";
+      channel.send = hlo;
+    }
+    if (hlo->opcode() == HloOpcode::kRecv) {
+      TF_RET_CHECK(channel.recv == nullptr)
+          << "channel id " << hlo->channel_id()
+          << " is used by multiple recv instructions";
+      channel.recv = hlo;
+    }
+    if (hlo->opcode() == HloOpcode::kSendDone) {
+      TF_RET_CHECK(channel.send_done == nullptr)
+          << "channel id " << hlo->channel_id()
+          << " is used by multiple send-done instructions";
+      channel.send_done = hlo;
+    }
+    if (hlo->opcode() == HloOpcode::kRecvDone) {
+      TF_RET_CHECK(channel.recv_done == nullptr)
+          << "channel id " << hlo->channel_id()
+          << " is used by multiple recv-done instructions";
+      channel.recv_done = hlo;
+    }
+    return Status::OK();
+  };
+
+  for (HloModule* module : modules_) {
+    for (auto* computation : module->computations()) {
+      TF_RETURN_IF_ERROR(computation->Accept(visitor));
+    }
+  }
+  return Status::OK();
+}
+
+Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
+                                            HloInstruction* instruction2) {
+  TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile ||
+               instruction1->opcode() == HloOpcode::kConditional);
+  VLOG(2) << "adding as companions:" << instruction1->ToString() << " and "
+          << instruction2->ToString();
+
+  if (!ContainsKey(companion_set_index_, instruction1) &&
+      !ContainsKey(companion_set_index_, instruction2)) {
+    companion_sets_.push_back(
+        absl::make_unique<std::unordered_set<HloInstruction*>>());
+    auto companion_set = companion_sets_.back().get();
+    companion_set->insert(instruction1);
+    companion_set->insert(instruction2);
+    companion_set_index_[instruction1] = companion_sets_.size() - 1;
+    companion_set_index_[instruction2] = companion_sets_.size() - 1;
+  } else if (!ContainsKey(companion_set_index_, instruction1)) {
+    companion_sets_[companion_set_index_[instruction2]]->insert(instruction1);
+    companion_set_index_[instruction1] = companion_set_index_[instruction2];
+  } else if (!ContainsKey(companion_set_index_, instruction2)) {
+    companion_sets_[companion_set_index_[instruction1]]->insert(instruction2);
+    companion_set_index_[instruction2] = companion_set_index_[instruction1];
+  } else if (companion_set_index_[instruction1] !=
+             companion_set_index_[instruction2]) {
+    companion_sets_[companion_set_index_[instruction1]]->insert(
+        Companions(instruction2).begin(), Companions(instruction2).end());
+    int64 index_to_remove = companion_set_index_[instruction2];
+    for (HloInstruction* hlo : Companions(instruction2)) {
+      companion_set_index_[hlo] = companion_set_index_[instruction1];
+    }
+    companion_sets_.erase(companion_sets_.begin() + index_to_remove);
+  }
+  return Status::OK();
+}
+
+// Verifies per-channel invariants: all four endpoint instructions exist,
+// send/recv shapes are compatible, channel instructions appear only in
+// allowed computations (module entry or tracked While/Conditional
+// computations), and the companion nesting paths of all four endpoints match.
+Status HloModuleGroupMetadata::VerifyChannelInstructions() {
+  for (const Channel& channel : channels_) {
+    if (channel.send == nullptr) {
+      return FailedPrecondition("missing send for id : %lld", channel.id);
+    }
+    if (channel.recv == nullptr) {
+      return FailedPrecondition("missing recv for id : %lld", channel.id);
+    }
+    if (channel.send_done == nullptr) {
+      return FailedPrecondition("missing send-done for id : %lld", channel.id);
+    }
+    if (channel.recv_done == nullptr) {
+      return FailedPrecondition("missing recv-done for id : %lld", channel.id);
+    }
+  }
+
+  // Check if the shapes match for each channel. The sent value is the send's
+  // operand 0; the received value is the recv-done's result shape.
+  for (const Channel& channel : channels_) {
+    const Shape& send_shape = channel.send->operand(0)->shape();
+    const Shape& recv_shape = channel.recv_done->shape();
+    if (!ShapeUtil::Compatible(send_shape, recv_shape)) {
+      return FailedPrecondition("send/recv shapes do not match");
+    }
+  }
+
+  // Check if channel instructions are used only in allowed computations.
+  const auto allowed = [this](HloInstruction* hlo) {
+    HloComputation* computation = hlo->parent();
+    const HloModule* module = computation->parent();
+    if (module->entry_computation() == computation ||
+        tracked_instructions_.count(computation) > 0) {
+      return true;
+    }
+    return false;
+  };
+  for (const Channel& channel : channels_) {
+    if (!allowed(channel.send) || !allowed(channel.send_done) ||
+        !allowed(channel.recv) || !allowed(channel.recv_done)) {
+      return FailedPrecondition("channel is used in disallowed computation");
+    }
+  }
+  // Check if the nest levels match for each channel.
+  for (const Channel& channel : channels_) {
+    std::vector<TrackedInstruction> path = GetCompanionsPath(channel.send);
+    if (!CheckCompanionPathsCompatibility(
+            path, GetCompanionsPath(channel.send_done)) ||
+        !CheckCompanionPathsCompatibility(path,
+                                          GetCompanionsPath(channel.recv)) ||
+        !CheckCompanionPathsCompatibility(
+            path, GetCompanionsPath(channel.recv_done))) {
+      return FailedPrecondition(
+          "Nest companion paths do not match for channel %lld", channel.id);
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
new file mode 100644 (file)
index 0000000..15cdbda
--- /dev/null
@@ -0,0 +1,230 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_
+
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Class for bookkeeping the information on the given modules, in particular on
+// the interaction between computations.
+//
+// Companion instructions are one of the information collected as we build the
+// metadata. For example, for each While instruction, companion instructions
+// refer to a set of While instructions in other computations that communicate
+// with each other.
+// In the example below with 3 modules, {While_0, While_2, While_5}, {While_1,
+// While_4}, {While_3, While_6} are companion sets.
+//
+// <Module 0>               <Module 1>                 <Module 2>
+// While_0() {              While_2() {                While_5() {
+//   While_1() { Send(0) }    While_3() { Send(1) }      While_6() { Recv(1) }
+// }                          While_4() { Recv(0) }
+//                          }
+//
+// Companion instructions are used to detect cycles in the graph and also for
+// global scheduling.
+class HloModuleGroupMetadata {
+ public:
+  // The kind of companion computation a given instruction can be within.
+  enum class ComputationKind {
+    kInvalid,
+    kWhileCondition,
+    kWhileBody,
+    kConditionalTrue,
+    kConditionalFalse,
+  };
+
+  // Tracks the instruction mapped to a given computation, and the computation
+  // kind.
+  // For example, a body computation of a while instruction, will generate a
+  // TrackedInstruction with instruction being the while instruction, and
+  // kind being ComputationKind::kWhileBody.
+  class TrackedInstruction {
+   public:
+    TrackedInstruction() = default;
+    TrackedInstruction(HloInstruction* instruction, ComputationKind kind)
+        : instruction_(instruction), kind_(kind) {}
+
+    // Two TrackedInstructions compare equal when their instructions share an
+    // opcode and their kinds match -- NOT when they reference the same
+    // instruction object.
+    bool operator==(const TrackedInstruction& rhs) const {
+      return instruction_->opcode() == rhs.instruction_->opcode() &&
+             kind_ == rhs.kind_;
+    }
+    bool operator!=(const TrackedInstruction& rhs) const {
+      return !operator==(rhs);
+    }
+
+    HloInstruction* instruction() const { return instruction_; }
+
+    string ToString() const;
+
+   private:
+    HloInstruction* instruction_ = nullptr;
+    ComputationKind kind_ = ComputationKind::kInvalid;
+  };
+
+  // Represents a channel and the 4 instructions that form the channel.
+  struct Channel {
+    int64 id = -1;
+    HloInstruction* send = nullptr;
+    HloInstruction* recv = nullptr;
+    HloInstruction* send_done = nullptr;
+    HloInstruction* recv_done = nullptr;
+  };
+
+  // Keeps a reference to the vector: the caller must keep it (and the
+  // modules) alive for the lifetime of this object.
+  explicit HloModuleGroupMetadata(const std::vector<HloModule*>& modules)
+      : modules_(modules) {}
+
+  ~HloModuleGroupMetadata() = default;
+
+  // Build and return the metadata for the given modules.
+  static StatusOr<std::unique_ptr<HloModuleGroupMetadata>> Build(
+      const std::vector<HloModule*>& modules);
+
+  // Returns true if the instruction is one of the 4 channel instructions (Send,
+  // Recv, SendDone, RecvDone).
+  bool IsChannelInstruction(const HloInstruction* instruction) const;
+
+  // Returns true if the instruction is a companion instruction. See the class
+  // comment above on companion instructions.
+  bool IsCompanionInstruction(HloInstruction* hlo) const;
+
+  // Returns true if the instruction is either a channel instruction or a
+  // companion instruction.
+  bool InstructionCommunicates(HloInstruction* hlo) const;
+
+  // Returns the Channel instance for the given channel id.
+  const Channel& GetChannel(int64 channel_id) const;
+
+  // Returns the computation that contains the peer channel instructions for
+  // the given instruction.
+  //
+  // Precondition: IsChannelInstruction(instruction) is true.
+  HloComputation* PeerComputation(const HloInstruction* instruction) const;
+
+  // Returns the path of the nested companion instructions, in terms of HLO
+  // instructions. The path goes from inner to outer companions.
+  // The returned path does not include the input hlo instruction, in case it
+  // is a companion instruction.
+  std::vector<TrackedInstruction> GetCompanionsPath(
+      const HloInstruction* hlo) const;
+
+  // Checks whether two companion paths (as returned by the GetCompanionsPath()
+  // API) are compatible. The two paths are compatible if the sequence of
+  // opcodes, and the companion kinds, of the two paths matches.
+  bool CheckCompanionPathsCompatibility(
+      const std::vector<TrackedInstruction>& path0,
+      const std::vector<TrackedInstruction>& path1) const;
+
+  // Returns the unique integer for each module. The returned id is the index of
+  // the module in the module vector.
+  int64 GetModuleId(const HloModule* module) const;
+
+  // Returns the companion instructions for the given instruction.
+  //
+  // Precondition: IsCompanionInstruction(instruction) is true.
+  const std::unordered_set<HloInstruction*>& Companions(
+      HloInstruction* instruction) const {
+    CHECK_EQ(companion_set_index_.count(instruction), 1);
+    return companion_set(companion_set_index_.at(instruction));
+  }
+
+  // Returns the companion set at the given index.
+  const std::unordered_set<HloInstruction*>& companion_set(int64 index) const {
+    CHECK_LT(index, companion_sets_.size());
+    return *companion_sets_[index];
+  }
+
+  // Returns the companion set index of the given instruction.
+  int64 companion_set_index(HloInstruction* instruction) const {
+    return companion_set_index_.at(instruction);
+  }
+
+  // Returns the list of all companion sets in the HLO module group.
+  const std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>&
+  companion_sets() const {
+    return companion_sets_;
+  }
+
+ private:
+  Status Build();
+
+  // Record all channel instructions and While instructions.
+  Status RecordInstructions();
+
+  // Verifies the given HloModules are well-formed and follow the specification,
+  // in particular with respect to using channel instructions.
+  //
+  // * Each channel has all 4 instructions (Send, Recv, SendDone, RecvDone).
+  // * The shape of channel instructions match.
+  // * The nest level of channel instructions match.
+  // * Channel instructions are used in allowed computations; i.e., in the
+  //   entry computation of the module or condition/body of While computations.
+  //
+  // TODO(b/62064342): Currently, HloModuleGroupScheduler checks if there is a
+  // cycle in the graph, but it would be good to verify here.
+  Status VerifyChannelInstructions();
+
+  // Adds metadata that the given two instructions are companions.
+  Status AddCompanion(HloInstruction* instruction1,
+                      HloInstruction* instruction2);
+
+  // Retrieves a pointer to the stored TrackedInstruction associated with a
+  // tracked computation, or nullptr in case such computation is not tracked.
+  const TrackedInstruction* GetTrackedInstruction(
+      const HloComputation* computation) const {
+    auto it = tracked_instructions_.find(computation);
+    return it != tracked_instructions_.end() ? &it->second : nullptr;
+  }
+
+  // List of all companion instructions sets in the module.
+  std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>
+      companion_sets_;
+
+  // Map from each companion instruction to the index into companion_sets_.
+  tensorflow::gtl::FlatMap<HloInstruction*, int64> companion_set_index_;
+
+  // Map from computation to the instruction using it (a kWhile, kConditional).
+  tensorflow::gtl::FlatMap<const HloComputation*, TrackedInstruction>
+      tracked_instructions_;
+
+  // All channels in the module.
+  std::vector<Channel> channels_;
+
+  // Map from channel ids to the index in channels_.
+  tensorflow::gtl::FlatMap<int64, int64> channel_id_map_;
+
+  // The modules that this metadata was built from.
+  const std::vector<HloModule*>& modules_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
new file mode 100644 (file)
index 0000000..289c96b
--- /dev/null
@@ -0,0 +1,316 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group_util.h"
+
+#include <algorithm>
+#include <list>
+#include <queue>
+#include <stack>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_reachability.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Returns the unique global predecessors of the instruction: its operands,
+// control predecessors, and cross-module channel predecessors, with every
+// companion instruction expanded to its whole companion set (on both the
+// query side and the predecessor side).
+std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
+    HloInstruction* instruction) {
+  std::vector<HloInstruction*> predecessors;
+
+  // Adds to the unique predecessors list and also add companion instructions
+  // if the given predecessor has those. Dedup is sound because companions are
+  // only ever added as a whole set: if any member is present, all are.
+  auto add_unique_predecessor = [&](HloInstruction* predecessor) {
+    if (std::find(predecessors.begin(), predecessors.end(), predecessor) !=
+        predecessors.end()) {
+      return;
+    }
+    if (!metadata_.IsCompanionInstruction(predecessor)) {
+      predecessors.push_back(predecessor);
+      return;
+    }
+    for (HloInstruction* companion : metadata_.Companions(predecessor)) {
+      predecessors.push_back(companion);
+    }
+  };
+
+  // If the given instruction is a companion instruction, we need to find the
+  // predecessors of all of its companion instructions.
+  std::vector<HloInstruction*> instruction_group;
+  if (metadata_.IsCompanionInstruction(instruction)) {
+    for (HloInstruction* companion : metadata_.Companions(instruction)) {
+      instruction_group.push_back(companion);
+    }
+  } else {
+    instruction_group.push_back(instruction);
+  }
+
+  for (HloInstruction* hlo : instruction_group) {
+    for (HloInstruction* operand : hlo->operands()) {
+      add_unique_predecessor(operand);
+    }
+    for (HloInstruction* control_predecessor : hlo->control_predecessors()) {
+      add_unique_predecessor(control_predecessor);
+    }
+  }
+  if (instruction->opcode() == HloOpcode::kRecvDone) {
+    // Send is a remote predecessor of RecvDone.
+    HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
+    add_unique_predecessor(send);
+  }
+  if (instruction->opcode() == HloOpcode::kSend) {
+    // Recv is a remote predecessor of Send. Reached via the channel's
+    // recv-done, whose sole operand is the recv (CHECKed below).
+    HloInstruction* recv_done =
+        metadata_.GetChannel(instruction->channel_id()).recv_done;
+    CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
+    CHECK_EQ(recv_done->operand_count(), 1);
+    HloInstruction* recv = recv_done->mutable_operand(0);
+    add_unique_predecessor(recv);
+  }
+  return predecessors;
+}
+
+// Returns the unique global successors of the instruction: its users, control
+// successors, and cross-module channel successors, with every companion
+// instruction expanded to its whole companion set (mirror image of
+// GlobalPredecessors).
+std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
+    HloInstruction* instruction) {
+  std::vector<HloInstruction*> successors;
+
+  // Adds to the unique successors list and also add companion instructions
+  // if the given successor has those. Dedup is sound because companions are
+  // only ever added as a whole set: if any member is present, all are.
+  auto add_unique_successor = [&](HloInstruction* successor) {
+    if (std::find(successors.begin(), successors.end(), successor) !=
+        successors.end()) {
+      return;
+    }
+    if (!metadata_.IsCompanionInstruction(successor)) {
+      successors.push_back(successor);
+      return;
+    }
+    for (HloInstruction* companion : metadata_.Companions(successor)) {
+      successors.push_back(companion);
+    }
+  };
+
+  // If the given instruction is a companion instruction, we need to find the
+  // successors of all of its companion instructions.
+  std::vector<HloInstruction*> instruction_group;
+  if (metadata_.IsCompanionInstruction(instruction)) {
+    for (HloInstruction* companion : metadata_.Companions(instruction)) {
+      instruction_group.push_back(companion);
+    }
+  } else {
+    instruction_group.push_back(instruction);
+  }
+
+  for (HloInstruction* hlo : instruction_group) {
+    for (HloInstruction* user : hlo->users()) {
+      add_unique_successor(user);
+    }
+    for (HloInstruction* control_successor : hlo->control_successors()) {
+      add_unique_successor(control_successor);
+    }
+  }
+  if (instruction->opcode() == HloOpcode::kRecv) {
+    // Send is a remote successor of Recv. Assumes the Recv's first user is
+    // its RecvDone (CHECKed below).
+    const HloInstruction* recv_done = instruction->users().front();
+    CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
+    HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
+    add_unique_successor(send);
+  }
+  if (instruction->opcode() == HloOpcode::kSend) {
+    // RecvDone is a remote successor of Send.
+    HloInstruction* recv_done =
+        metadata_.GetChannel(instruction->channel_id()).recv_done;
+    add_unique_successor(recv_done);
+  }
+  return successors;
+}
+
+// Returns every instruction, across all of the given computations, that has
+// no global successor -- i.e. the roots of the global dependency graph.
+std::vector<HloInstruction*> HloModuleGroupUtil::RootInstructions(
+    tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+  std::vector<HloInstruction*> roots;
+  for (HloComputation* computation : computations) {
+    for (HloInstruction* hlo : computation->instructions()) {
+      // No successors anywhere in the module group means hlo is a root.
+      if (GlobalSuccessors(hlo).empty()) {
+        roots.push_back(hlo);
+      }
+    }
+  }
+  return roots;
+}
+
+Status HloModuleGroupUtil::VisitTopologicalOrder(
+    VisitStates* visit_state, const VisitFunction& visit_function,
+    HloInstruction* root) {
+  // Iterative DFS over global predecessors (operands, control predecessors
+  // and cross-computation communication edges). An instruction stays on the
+  // stack while its predecessors are explored (kVisiting); the visit
+  // function runs when it is seen again on top of the stack, so groups are
+  // visited in post order, i.e. predecessors before successors.
+  // Stack of HLO instructions visited in DFS order.
+  std::stack<HloInstruction*> stack;
+  stack.push(root);
+
+  while (!stack.empty()) {
+    HloInstruction* hlo = stack.top();
+
+    // Find the instruction group of the currently visited instruction. The
+    // instruction group represents all companion instructions of the
+    // current instruction, and are considered to be a single entity for the
+    // purpose of the traversal (i.e., they must always be in the same visit
+    // state).
+    std::vector<HloInstruction*> instruction_group;
+    if (metadata_.IsCompanionInstruction(hlo)) {
+      for (HloInstruction* companion : metadata_.Companions(hlo)) {
+        instruction_group.push_back(companion);
+      }
+    } else {
+      instruction_group.push_back(hlo);
+    }
+
+    // Already fully processed (e.g. reached through an earlier root or
+    // pushed twice as a shared predecessor): just discard the stack entry.
+    if ((*visit_state)[hlo] == VisitState::kVisited) {
+      // All instructions in the group must be in the same state.
+      for (HloInstruction* instruction : instruction_group) {
+        TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisited);
+      }
+      stack.pop();
+      continue;
+    }
+
+    // Second encounter on the stack top: every predecessor has been
+    // finished, so the group itself can now be visited and finalized.
+    if ((*visit_state)[hlo] == VisitState::kVisiting) {
+      TF_RETURN_IF_ERROR(visit_function(hlo, instruction_group));
+
+      // Set the visit state of all instructions in the group to kVisited.
+      for (HloInstruction* instruction : instruction_group) {
+        TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisiting);
+        (*visit_state)[instruction] = VisitState::kVisited;
+      }
+      stack.pop();
+      continue;
+    }
+
+    // Set the visit state of all instructions in the group to kVisiting.
+    for (HloInstruction* instruction : instruction_group) {
+      TF_RET_CHECK((*visit_state)[instruction] == VisitState::kNotVisited)
+          << instruction->ToString();
+      (*visit_state)[instruction] = VisitState::kVisiting;
+    }
+
+    // For each instruction in the group, visit its predecessors (operands,
+    // control predecessors and remote predecessors).
+    for (HloInstruction* instruction : instruction_group) {
+      for (HloInstruction* predecessor : GlobalPredecessors(instruction)) {
+        // Visiting a node that is already being visited implies that there is
+        // a cycle. Generate an error with the list of instructions in the
+        // cycle.
+        if ((*visit_state)[predecessor] == VisitState::kVisiting) {
+          string cyclic_instructions;
+          for (const auto& state : *visit_state) {
+            if (state.second == VisitState::kVisiting) {
+              tensorflow::strings::StrAppend(&cyclic_instructions,
+                                             state.first->ToString(), "\n");
+            }
+          }
+          // TODO(b/64305524): Improve the error message to print out the
+          // instructions in a deterministic order that forms the cycle.
+          return FailedPrecondition(
+              "Cross-computation cycle detected via communicating nodes. The "
+              "cycle contains the node %s. The cycle is found among the "
+              "following nodes. Note that the order of the nodes is arbitrary "
+              "and that the list may include nodes that are not part of the "
+              "cycle.\n%s",
+              predecessor->ToString().c_str(), cyclic_instructions.c_str());
+        }
+        // Note: kVisited predecessors may be pushed again; such entries are
+        // popped without effect by the kVisited check above.
+        stack.push(predecessor);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+Status HloModuleGroupUtil::VerifyComputations(
+    tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+  auto visit_function =
+      [&](HloInstruction* instruction,
+          const std::vector<HloInstruction*>& instruction_group) {
+        return Status::OK();
+      };
+  int64 instructions_count = 0;
+  VisitStates visit_states;
+  for (HloComputation* computation : computations) {
+    // Visit all instructions, and not just from the root instruction of the
+    // computation. This allows us to detect dead cycles (i.e., cycles that
+    // are not reachable from the root) or to enforce an order for the
+    // communication instructions that are not reachable from any roots.
+    for (HloInstruction* instruction : computation->instructions()) {
+      TF_RETURN_IF_ERROR(
+          VisitTopologicalOrder(&visit_states, visit_function, instruction));
+    }
+    instructions_count += computation->instruction_count();
+  }
+
+  // Check if all instructions are visited and are in the visited state.
+  TF_RET_CHECK(visit_states.size() == instructions_count);
+  for (auto& state : visit_states) {
+    TF_RET_CHECK(state.second == VisitState::kVisited);
+  }
+
+  return Status::OK();
+}
+
+StatusOr<std::unique_ptr<HloReachabilityMap>>
+HloModuleGroupUtil::ComputeReachability(
+    tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+  std::list<HloInstruction*> post_order;
+  auto visit_function =
+      [&](HloInstruction* instruction,
+          const std::vector<HloInstruction*>& instruction_group) {
+        post_order.insert(post_order.end(), instruction_group.begin(),
+                          instruction_group.end());
+        return Status::OK();
+      };
+  HloModuleGroupUtil::VisitStates visit_states;
+  for (HloInstruction* root : RootInstructions(computations)) {
+    TF_RETURN_IF_ERROR(
+        VisitTopologicalOrder(&visit_states, visit_function, root));
+  }
+  auto reachability = absl::make_unique<HloReachabilityMap>(post_order);
+  for (HloInstruction* hlo : post_order) {
+    reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo);
+  }
+  return std::move(reachability);
+}
+
+void HloModuleGroupUtil::UpdateReachabilityThroughInstruction(
+    HloInstruction* instruction, HloReachabilityMap* reachability_map) {
+  std::queue<HloInstruction*> worklist;
+  worklist.push(instruction);
+
+  while (!worklist.empty()) {
+    HloInstruction* item = worklist.front();
+    worklist.pop();
+    if (reachability_map->SetReachabilityToUnion(GlobalPredecessors(item),
+                                                 item)) {
+      for (HloInstruction* successor : GlobalSuccessors(item)) {
+        worklist.push(successor);
+      }
+    }
+  }
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h
new file mode 100644 (file)
index 0000000..c25ca1a
--- /dev/null
@@ -0,0 +1,117 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
+#include "tensorflow/compiler/xla/service/hlo_reachability.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace xla {
+
+// Collection of utilities for handling HloModuleGroups.
+class HloModuleGroupUtil {
+ public:
+  explicit HloModuleGroupUtil(const HloModuleGroupMetadata& metadata)
+      : metadata_(metadata) {}
+
+  // Returns all unique predecessors of the instruction. This includes:
+  // * predecessors in the same computation: operands and control predecessors
+  // * Recv is a predecessor of Send
+  // * Send is a predecessor of RecvDone
+  // * predecessors of companions (if the instruction is a companion while)
+  // * predecessors' companions (for any predecessor that is a companion while)
+  std::vector<HloInstruction*> GlobalPredecessors(HloInstruction* instruction);
+
+  // Returns all unique successors of the instruction. This includes:
+  // * successors in the same computation: users and control successors
+  // * Send is a successor of Recv
+  // * RecvDone is a successor of Send
+  // * successors of companions (if the instruction is a companion while)
+  // * successors' companions (for any successor that is a companion while)
+  std::vector<HloInstruction*> GlobalSuccessors(HloInstruction* instruction);
+
+  // Returns the root instructions of the computations, i.e. those with no
+  // global successors (see GlobalSuccessors above).
+  std::vector<HloInstruction*> RootInstructions(
+      tensorflow::gtl::ArraySlice<HloComputation*> computations);
+
+  // Visit state of each instruction during DFS traversal.
+  enum VisitState {
+    kNotVisited = 0,
+    kVisiting,
+    kVisited,
+  };
+
+  // Function called on each instruction group during the DFS traversal. See
+  // the comment for VisitTopologicalOrder().
+  using VisitFunction = std::function<Status(
+      HloInstruction* hlo,
+      const std::vector<HloInstruction*>& instruction_group)>;
+
+  // Given the hlo instruction as the root, recursively visits all its
+  // predecessor instructions in DFS order to visit nodes in topological order.
+  //
+  // Note that the DFS traversal does not only visit nodes in the same
+  // computation (parent of the root instruction), but also visits nodes in
+  // different computations connected via communication instructions. During the
+  // traversal, companion While instructions (see the class comment in
+  // HloModuleGroupMetadata) are treated as a single instruction (called
+  // instruction group, which contains only a single instruction if the visiting
+  // node is not a companion while) -- visiting one of the instructions in the
+  // group effectively visits all other instructions in the group, and then all
+  // predecessor instructions of the group are visited.
+  //
+  // * visit_state: map from each instruction to its visit state.
+  // * visit_function: function called for each instruction group visited.
+  // * root: the root instruction of the traversal.
+  using VisitStates = tensorflow::gtl::FlatMap<HloInstruction*, VisitState>;
+  Status VisitTopologicalOrder(VisitStates* visit_state,
+                               const VisitFunction& visit_function,
+                               HloInstruction* root);
+
+  // Verifies that the computations are well-formed (e.g., no cycles).
+  Status VerifyComputations(
+      tensorflow::gtl::ArraySlice<HloComputation*> computations);
+
+  // Below Reachability utils resemble those in HloComputation, except that
+  // they can handle instructions across multiple computations.
+  //
+  // Creates the reachability map for the instructions in the computations.
+  StatusOr<std::unique_ptr<HloReachabilityMap>> ComputeReachability(
+      tensorflow::gtl::ArraySlice<HloComputation*> computations);
+
+  // Updates the reachability of the given instruction, taking the global
+  // predecessors and successors into account.
+  void UpdateReachabilityThroughInstruction(
+      HloInstruction* instruction, HloReachabilityMap* reachability_map);
+
+ private:
+  const HloModuleGroupMetadata& metadata_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_