From a3a5e5cad0bfdd28f43223980f64ce367c732aad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 22 May 2018 14:52:36 -0700 Subject: [PATCH] [TF:XLA] Add a helper to update HLO reachability. This can be used if the user does not care if reachability changed after an update. PiperOrigin-RevId: 197628007 --- tensorflow/compiler/xla/service/hlo_reachability.cc | 20 +++++++++++++++----- tensorflow/compiler/xla/service/hlo_reachability.h | 10 ++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 8e16763..4738e46 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -33,17 +33,27 @@ bool HloReachabilityMap::SetReachabilityToUnion( const HloInstruction* instruction) { BitVector& bit_vector = GetBitVector(instruction); tmp_bit_vector_ = bit_vector; + SetReachabilityToUnionHelper(inputs, instruction, &bit_vector); + return bit_vector != tmp_bit_vector_; +} +void HloReachabilityMap::FastSetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction) { + SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction)); +} + +void HloReachabilityMap::SetReachabilityToUnionHelper( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { - bit_vector.SetToZero(); + bit_vector->SetToZero(); } - bit_vector.Set(GetIndex(instruction)); + bit_vector->Set(GetIndex(instruction)); for (const HloInstruction* input : inputs) { - bit_vector.OrWith(GetBitVector(input)); + bit_vector->OrWith(GetBitVector(input)); } - - return bit_vector != tmp_bit_vector_; } void HloReachabilityMap::SetReachable(const HloInstruction* a, diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 553ec11..69bb2b3 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -57,6 +57,11 @@ class HloReachabilityMap { tensorflow::gtl::ArraySlice inputs, const HloInstruction* instruction); + // As above, but faster because it does not check if the reachability changed. + void FastSetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction); + // Sets entry so that IsReachable(a, b) will return true // // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency @@ -133,6 +138,11 @@ class HloReachabilityMap { return bit_vectors_[GetIndex(instruction)]; } + // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. + void SetReachabilityToUnionHelper( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction, BitVector* bit_vector); + // Return the index of the given instruction. The value is used to index into // the vector of BitVectors and the BitVectors themselves. int GetIndex(const HloInstruction* instruction) const { -- 2.7.4