const HloInstruction* instruction) {
BitVector& bit_vector = GetBitVector(instruction);
tmp_bit_vector_ = bit_vector;
+ SetReachabilityToUnionHelper(inputs, instruction, &bit_vector);
+ return bit_vector != tmp_bit_vector_;
+}
+// Variant of the union update that skips change detection: it applies
+// SetReachabilityToUnionHelper directly to `instruction`'s own bit vector and
+// returns nothing, so no copy into tmp_bit_vector_ and no before/after
+// comparison is made here.
+void HloReachabilityMap::FastSetReachabilityToUnion(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ const HloInstruction* instruction) {
+ SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction));
+}
+
+// Shared body of the union update. Rewrites `*bit_vector` in three steps:
+//   1. clears it, unless `instruction` itself is one of `inputs`;
+//   2. sets the bit at GetIndex(instruction);
+//   3. ORs in the bit vector of every element of `inputs`.
+// Reports nothing about whether the contents changed; a caller that needs
+// that must compare against a copy taken before the call.
+void HloReachabilityMap::SetReachabilityToUnionHelper(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ const HloInstruction* instruction, BitVector* bit_vector) {
// If instruction is part of inputs, don't reset the bit_vector.
if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) {
- bit_vector.SetToZero();
+ bit_vector->SetToZero();
}
- bit_vector.Set(GetIndex(instruction));
+ bit_vector->Set(GetIndex(instruction));
+ // Accumulate the bits of each input's vector into *bit_vector.
for (const HloInstruction* input : inputs) {
- bit_vector.OrWith(GetBitVector(input));
+ bit_vector->OrWith(GetBitVector(input));
}
-
- return bit_vector != tmp_bit_vector_;
}
void HloReachabilityMap::SetReachable(const HloInstruction* a,
tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
const HloInstruction* instruction);
+ // As above, but faster because it does not check if the reachability changed.
+ // Returns nothing; intended for callers that do not need to know whether
+ // `instruction`'s reachability was modified by the update.
+ void FastSetReachabilityToUnion(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ const HloInstruction* instruction);
+
// Sets entry so that IsReachable(a, b) will return true
//
// !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
return bit_vectors_[GetIndex(instruction)];
}
+ // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion.
+ // Writes into `*bit_vector` the union of the bit vectors of `inputs` plus the
+ // bit for `instruction` itself. The previous contents of `*bit_vector` are
+ // cleared first unless `instruction` is one of `inputs`.
+ void SetReachabilityToUnionHelper(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ const HloInstruction* instruction, BitVector* bit_vector);
+
// Return the index of the given instruction. The value is used to index into
// the vector of BitVectors and the BitVectors themselves.
int GetIndex(const HloInstruction* instruction) const {