[TF:XLA] Add a helper to update HLO reachability.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 22 May 2018 21:52:36 +0000 (14:52 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 21:54:52 +0000 (14:54 -0700)
This can be used if the user does not care if reachability changed after an
update.

PiperOrigin-RevId: 197628007

tensorflow/compiler/xla/service/hlo_reachability.cc
tensorflow/compiler/xla/service/hlo_reachability.h

index 8e16763..4738e46 100644 (file)
@@ -33,17 +33,27 @@ bool HloReachabilityMap::SetReachabilityToUnion(
     const HloInstruction* instruction) {
   BitVector& bit_vector = GetBitVector(instruction);
   tmp_bit_vector_ = bit_vector;
+  SetReachabilityToUnionHelper(inputs, instruction, &bit_vector);
+  return bit_vector != tmp_bit_vector_;
+}
 
+void HloReachabilityMap::FastSetReachabilityToUnion(
+    tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+    const HloInstruction* instruction) {
+  SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction));
+}
+
+void HloReachabilityMap::SetReachabilityToUnionHelper(
+    tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+    const HloInstruction* instruction, BitVector* bit_vector) {
   // If instruction is part of inputs, don't reset the bit_vector.
   if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) {
-    bit_vector.SetToZero();
+    bit_vector->SetToZero();
   }
-  bit_vector.Set(GetIndex(instruction));
+  bit_vector->Set(GetIndex(instruction));
   for (const HloInstruction* input : inputs) {
-    bit_vector.OrWith(GetBitVector(input));
+    bit_vector->OrWith(GetBitVector(input));
   }
-
-  return bit_vector != tmp_bit_vector_;
 }
 
 void HloReachabilityMap::SetReachable(const HloInstruction* a,
index 553ec11..69bb2b3 100644 (file)
@@ -57,6 +57,11 @@ class HloReachabilityMap {
       tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
       const HloInstruction* instruction);
 
+  // As above, but faster because it does not check if the reachability changed.
+  void FastSetReachabilityToUnion(
+      tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+      const HloInstruction* instruction);
+
   // Sets entry so that IsReachable(a, b) will return true
   //
   // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
@@ -133,6 +138,11 @@ class HloReachabilityMap {
     return bit_vectors_[GetIndex(instruction)];
   }
 
+  // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion.
+  void SetReachabilityToUnionHelper(
+      tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+      const HloInstruction* instruction, BitVector* bit_vector);
+
   // Return the index of the given instruction. The value is used to index into
   // the vector of BitVectors and the BitVectors themselves.
   int GetIndex(const HloInstruction* instruction) const {