[XLA] Do not recompute flattened sets inside layout assignment.
author     Michael Kuperstein <mkuper@google.com>
Mon, 26 Feb 2018 18:52:05 +0000 (10:52 -0800)
committer  Gunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
Cache the flattened sets instead of recomputing them. This matters for large graphs, since we may request the flattened set thousands of times on the same instruction, and it may be fairly expensive to construct for large tuples.

PiperOrigin-RevId: 187046642
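
The fix is plain memoization: a mutable cache keyed by instruction, populated lazily the first time a flattened set is requested. Below is a minimal, self-contained sketch of the same get-or-create idea; the names are illustrative only (std::set stands in for PointsToSet::BufferSet, an int key stands in for the HloInstruction pointer, and std::unordered_map stands in for tensorflow::gtl::FlatMap), not the XLA API.

// Illustrative sketch of the caching pattern used by this change:
// an expensive-to-build flattened set is constructed once per key and
// memoized in a map of unique_ptrs, so repeated queries return the
// cached pointer instead of rebuilding the set.
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

using Buffer = std::string;
using BufferSet = std::set<Buffer>;

class FlattenedSetCache {
 public:
  // Returns the flattened set for `key`, computing it only on first request.
  const BufferSet* GetBufferSet(int key) const {
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second.get();
    }
    auto& entry =
        cache_.emplace(key, std::make_unique<BufferSet>()).first->second;
    // Stand-in for the expensive flattening step (e.g. walking a points-to
    // tree and unioning the buffers found at every shape index).
    for (const Buffer& b : ComputeBuffers(key)) {
      entry->insert(b);
    }
    return entry.get();
  }

 private:
  // Hypothetical producer of the per-key buffers; in the real code this role
  // is played by the points-to analysis.
  static std::vector<Buffer> ComputeBuffers(int key) {
    return {"buf" + std::to_string(key), "buf" + std::to_string(key + 1)};
  }

  // `mutable` so a logically-const query can still populate the cache.
  mutable std::unordered_map<int, std::unique_ptr<BufferSet>> cache_;
};

int main() {
  FlattenedSetCache cache;
  const BufferSet* a = cache.GetBufferSet(0);
  const BufferSet* b = cache.GetBufferSet(0);  // second lookup hits the cache
  std::cout << (a == b) << " " << a->size() << std::endl;  // prints "1 2"
  return 0;
}

Storing unique_ptr<BufferSet> rather than BufferSet by value keeps the returned pointers stable even if the map rehashes, which is one reason a pointer-returning accessor pairs naturally with a map of unique_ptrs.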

tensorflow/compiler/xla/service/layout_assignment.cc
tensorflow/compiler/xla/service/layout_assignment.h

index 0668f66..4929300 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -192,17 +192,34 @@ LayoutConstraints::LayoutConstraints(
   }
 }
 
+PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
+    const HloInstruction* instruction) const {
+  auto it = buffer_sets_cache_.find(instruction);
+  if (it != buffer_sets_cache_.end()) {
+    return it->second.get();
+  }
+  auto& buffer_set =
+      buffer_sets_cache_
+          .emplace(instruction, MakeUnique<PointsToSet::BufferSet>())
+          .first->second;
+  const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
+  points_to_set.ForEachElement(
+      [&buffer_set](const ShapeIndex& /*index*/,
+                    const PointsToSet::BufferList& buffers) {
+        buffer_set->insert(buffers.begin(), buffers.end());
+      });
+  return buffer_set.get();
+}
+
 bool LayoutConstraints::OperandBufferForwarded(
     const HloInstruction* instruction, int64 operand_no) const {
   // The operand is potentially forwarded if the intersection of points-to sets
   // of the operand and the instruction is non-empty.
-  auto output_buffers =
-      points_to_analysis_.GetPointsToSet(instruction).CreateFlattenedSet();
-  auto operand_buffers =
-      points_to_analysis_.GetPointsToSet(instruction->operand(operand_no))
-          .CreateFlattenedSet();
-  for (const LogicalBuffer* output_buffer : output_buffers) {
-    if (operand_buffers.count(output_buffer) > 0) {
+  PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
+  PointsToSet::BufferSet* operand_buffers =
+      GetBufferSet(instruction->operand(operand_no));
+  for (const LogicalBuffer* output_buffer : *output_buffers) {
+    if (operand_buffers->count(output_buffer) > 0) {
       return true;
     }
   }
index 2901858..7126cb5 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -38,6 +38,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
@@ -199,6 +200,11 @@ class LayoutConstraints {
   string ToString() const;
 
  private:
+  // Find the flattened buffer set for the given instruction in the cache,
+  // creating it on first request. Without the cache, the flattened set may
+  // be recomputed many times for the same instruction, which is often slow.
+  PointsToSet::BufferSet* GetBufferSet(const HloInstruction* instruction) const;
+
   // The set of BufferLayoutConstraints applied to the computation.
   std::unordered_map<const LogicalBuffer*, BufferLayoutConstraint>
       buffer_constraints_;
@@ -221,6 +227,10 @@ class LayoutConstraints {
   // Array-shaped buffers which have not yet been constrained.
   std::set<LogicalBuffer::Id> unconstrained_buffer_ids_;
 
+  mutable tensorflow::gtl::FlatMap<const HloInstruction*,
+                                   std::unique_ptr<PointsToSet::BufferSet>>
+      buffer_sets_cache_;
+
   HloComputation* computation_;
 };