From c6312773dd5473fb47f73c88c2f5c8f41e20c0fa Mon Sep 17 00:00:00 2001 From: Michael Kuperstein Date: Mon, 26 Feb 2018 10:52:05 -0800 Subject: [PATCH] [XLA] Do not recompute flattened sets inside layout assignment. Cache the flattened sets instead of recomputing them. This matters for large graphs, since we may request the flattened set thousands of times on the same instruction, and it may be fairly expensive to construct for large tuples. PiperOrigin-RevId: 187046642 --- .../compiler/xla/service/layout_assignment.cc | 31 +++++++++++++++++----- .../compiler/xla/service/layout_assignment.h | 10 +++++++ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 0668f66..4929300 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -192,17 +192,34 @@ LayoutConstraints::LayoutConstraints( } } +PointsToSet::BufferSet* LayoutConstraints::GetBufferSet( + const HloInstruction* instruction) const { + auto it = buffer_sets_cache_.find(instruction); + if (it != buffer_sets_cache_.end()) { + return it->second.get(); + } + auto& buffer_set = + buffer_sets_cache_ + .emplace(instruction, MakeUnique()) + .first->second; + const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction); + points_to_set.ForEachElement( + [&buffer_set](const ShapeIndex& /*index*/, + const PointsToSet::BufferList& buffers) { + buffer_set->insert(buffers.begin(), buffers.end()); + }); + return buffer_set.get(); +} + bool LayoutConstraints::OperandBufferForwarded( const HloInstruction* instruction, int64 operand_no) const { // The operand is potentially forwarded if the intersection of points-to sets // of the operand and the instruction is non-empty. - auto output_buffers = - points_to_analysis_.GetPointsToSet(instruction).CreateFlattenedSet(); - auto operand_buffers = - points_to_analysis_.GetPointsToSet(instruction->operand(operand_no)) - .CreateFlattenedSet(); - for (const LogicalBuffer* output_buffer : output_buffers) { - if (operand_buffers.count(output_buffer) > 0) { + PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction); + PointsToSet::BufferSet* operand_buffers = + GetBufferSet(instruction->operand(operand_no)); + for (const LogicalBuffer* output_buffer : *output_buffers) { + if (operand_buffers->count(output_buffer) > 0) { return true; } } diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 2901858..7126cb5 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -199,6 +200,11 @@ class LayoutConstraints { string ToString() const; private: + // Find a bufferset in the bufferset cache. This is useful since we can + // currently create the flattened buffer set for the same instruction many + // times, which is often slow. + PointsToSet::BufferSet* GetBufferSet(const HloInstruction* instruction) const; + // The set of BufferLayoutConstraints applied to the computation. std::unordered_map buffer_constraints_; @@ -221,6 +227,10 @@ class LayoutConstraints { // Array-shaped buffers which have not yet been constrained. std::set unconstrained_buffer_ids_; + mutable tensorflow::gtl::FlatMap> + buffer_sets_cache_; + HloComputation* computation_; }; -- 2.7.4