Check that the module group metadata builder correctly detects whether there are...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 12 May 2018 14:13:06 +0000 (07:13 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 12 May 2018 14:15:41 +0000 (07:15 -0700)
PiperOrigin-RevId: 196369766

tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
tensorflow/compiler/xla/service/hlo_module_group_metadata.h

index 67f4c37..a41cfa7 100644 (file)
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
 
+#include <sstream>
 #include <string>
 #include <utility>
 
@@ -110,6 +111,31 @@ Status HloModuleGroupMetadata::Build() {
       TF_RETURN_IF_ERROR(computation->Accept(visitor));
     }
   }
+  TF_RETURN_IF_ERROR(VerifyCompanionSets());
+  return Status::OK();
+}
+
+Status HloModuleGroupMetadata::VerifyCompanionSets() const {
+  // TODO(dlibenzi): Migrate this to use the device instead of module ID, once
+  // the kDomain CL goes in.
+  for (const auto& companions : companion_sets_) {
+    // A companion set must be composed at most of an instruction per
+    // device/module.
+    std::unordered_set<int64> devices;
+    for (HloInstruction* instruction : *companions) {
+      int64 device = GetModuleId(instruction->parent()->parent());
+      if (!devices.insert(device).second) {
+        std::stringstream ss;
+        ss << "Companion set:" << std::endl;
+        for (HloInstruction* hlo : *companions) {
+          ss << "  " << hlo->name() << " ("
+             << GetModuleId(hlo->parent()->parent()) << ")" << std::endl;
+        }
+        ss << "has multiple instructions on the same device";
+        return FailedPrecondition("%s", ss.str().c_str());
+      }
+    }
+  }
   return Status::OK();
 }
 
index 88ed9a2..3ef4542 100644 (file)
@@ -207,6 +207,11 @@ class HloModuleGroupMetadata {
   // within the graph.
   Status CheckCommunicatingInstruction(HloInstruction* instruction) const;
 
+  // Performs a consistency check on the companion sets built for the input
+  // modules. Check that a companion set does not include instructions from the
+  // same module/device.
+  Status VerifyCompanionSets() const;
+
   // Retrieves a pointer to the stored TrackedInstruction associated with a
   // tracked computation, or nullptr in case such computation is not tracked.
   const TrackedInstruction* GetTrackedInstruction(