[XLA] Assign mandatory constraints in a DFS order and non-manatory constraints in...
authorBlake Hechtman <blakehechtman@google.com>
Sun, 4 Feb 2018 07:45:00 +0000 (23:45 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 4 Feb 2018 07:48:56 +0000 (23:48 -0800)
PiperOrigin-RevId: 184429818

tensorflow/compiler/xla/service/layout_assignment.cc
tensorflow/compiler/xla/service/layout_assignment.h

index 5413b95..fce135e 100644 (file)
@@ -61,8 +61,8 @@ std::ostream& operator<<(std::ostream& out,
 
 BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
                                                const LogicalBuffer& buffer,
-                                               bool mandatory)
-    : LayoutConstraint(mandatory), layout_(layout), buffer_(&buffer) {
+                                               bool mandatory, bool dfs)
+    : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
   CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
 }
 
@@ -74,8 +74,8 @@ string BufferLayoutConstraint::ToString() const {
 
 OperandLayoutConstraint::OperandLayoutConstraint(
     const ShapeLayout& shape_layout, const HloInstruction* instruction,
-    int64 operand_no, bool mandatory)
-    : LayoutConstraint(mandatory),
+    int64 operand_no, bool mandatory, bool dfs)
+    : LayoutConstraint(mandatory, dfs),
       shape_layout_(shape_layout),
       instruction_(instruction),
       operand_no_(operand_no) {
@@ -134,7 +134,7 @@ bool LayoutConstraints::OperandBufferForwarded(
 
 Status LayoutConstraints::SetBufferLayout(const Layout& layout,
                                           const LogicalBuffer& buffer,
-                                          bool mandatory) {
+                                          bool mandatory, bool dfs) {
   VLOG(3) << "SetBufferLayout : " << buffer << " : "
           << LayoutUtil::HumanString(layout);
 
@@ -171,10 +171,11 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
   if (!overwrite) {
     iter = buffer_constraints_
                .insert(std::make_pair(
-                   &buffer, BufferLayoutConstraint(layout, buffer, mandatory)))
+                   &buffer,
+                   BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
                .first;
   } else {
-    iter->second = BufferLayoutConstraint(layout, buffer, /*mandatory=*/true);
+    iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
   }
   added_constraints_.push_back(&iter->second);
 
@@ -188,7 +189,8 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
 
 Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
                                            const HloInstruction* instruction,
-                                           int64 operand_no, bool mandatory) {
+                                           int64 operand_no, bool mandatory,
+                                           bool dfs) {
   VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
           << operand_no << " : "
           << ShapeUtil::HumanStringWithLayout(shape_with_layout);
@@ -226,12 +228,12 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
   if (iter == operand_constraints_.end()) {
     auto pair = std::make_pair(
         key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
-                                     instruction, operand_no, mandatory));
+                                     instruction, operand_no, mandatory, dfs));
     iter = operand_constraints_.insert(pair).first;
   } else {
     iter->second =
         OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
-                                operand_no, /*mandatory=*/true);
+                                operand_no, mandatory, dfs);
   }
   added_constraints_.push_back(&iter->second);
 
@@ -240,16 +242,17 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
 
 Status LayoutConstraints::SetArrayOperandLayout(
     const Layout& layout, const HloInstruction* instruction, int64 operand_no,
-    bool mandatory) {
+    bool mandatory, bool dfs) {
   const HloInstruction* operand = instruction->operand(operand_no);
   TF_RET_CHECK(ShapeUtil::IsArray(operand->shape()));
   Shape shape(operand->shape());
   *shape.mutable_layout() = layout;
   TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
-  return SetOperandLayout(shape, instruction, operand_no, mandatory);
+  return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
 }
 
-Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) {
+Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
+                                          bool dfs) {
   VLOG(3) << "SetResultLayout : "
           << ShapeUtil::HumanStringWithLayout(shape_with_layout);
 
@@ -267,14 +270,15 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) {
   }
 
   result_constraint_.reset(
-      new ResultLayoutConstraint(ShapeLayout(shape_with_layout)));
+      new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
   added_constraints_.push_back(result_constraint_.get());
 
   return Status::OK();
 }
 
 Status LayoutConstraints::SetInstructionLayout(
-    const Shape& shape_with_layout, const HloInstruction* instruction) {
+    const Shape& shape_with_layout, const HloInstruction* instruction,
+    bool mandatory, bool dfs) {
   VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
           << ShapeUtil::HumanStringWithLayout(shape_with_layout);
 
@@ -290,8 +294,8 @@ Status LayoutConstraints::SetInstructionLayout(
   // instruction.
   return ShapeUtil::ForEachSubshapeWithStatus(
       shape_with_layout,
-      [this, instruction](const Shape& subshape,
-                          const ShapeIndex& index) -> Status {
+      [this, instruction, mandatory](const Shape& subshape,
+                                     const ShapeIndex& index) -> Status {
         // The precondition for this method is that the instruction defines all
         // buffers in its output.
         auto buffers =
@@ -300,7 +304,7 @@ Status LayoutConstraints::SetInstructionLayout(
         CHECK_EQ(buffers[0]->instruction(), instruction);
 
         if (ShapeUtil::IsArray(subshape)) {
-          return SetBufferLayout(subshape.layout(), *buffers[0]);
+          return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
         } else {
           return Status::OK();
         }
@@ -394,8 +398,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
       // Constrain the input to the Outfeed instruction to be the expected
       // layout of the Outfeed.
       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
-          instruction->outfeed_shape(), instruction, 0,
-          /*mandatory=*/true));
+          instruction->outfeed_shape(), instruction, 0));
     } else if (instruction->opcode() == HloOpcode::kParameter) {
       // Parameter layouts must match the respective layout in
       // ComputationLayout.
@@ -434,8 +437,8 @@ Status LayoutAssignment::AddMandatoryConstraints(
                                                                  {0}));
         Shape new_shape = channel_constraints->LayoutShapeForChannel(
             recv_buffer_shape, instruction->channel_id());
-        TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
-            new_shape.layout(), *buffer, /*mandatory=*/true));
+        TF_RETURN_IF_ERROR(
+            constraints->SetBufferLayout(new_shape.layout(), *buffer));
       }
     }
   }
@@ -457,7 +460,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
       for (int64 i = 0; i < instruction->operand_count(); ++i) {
         TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
             called_computation_layout.parameter_layout(i).shape(), instruction,
-            i, /*mandatory=*/true));
+            i));
       }
     } else if (instruction->opcode() == HloOpcode::kWhile) {
       // Layout of input and output of kWhile instruction must be equal and must
@@ -508,8 +511,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
       TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
           body_layout.result_shape(), instruction));
       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
-          body_layout.result_shape(), instruction, 0,
-          /*mandatory=*/true));
+          body_layout.result_shape(), instruction, 0));
     } else if (instruction->opcode() == HloOpcode::kCustomCall) {
       if (!CustomCallRequiresMajorFirstLayout(instruction)) {
         continue;
@@ -533,7 +535,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
                 operand_shape.element_type(),
                 AsInt64Slice(operand_shape.dimensions()));
         TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
-            row_major_operand_shape, instruction, i, /*mandatory=*/true));
+            row_major_operand_shape, instruction, i));
       }
     }
   }
@@ -907,7 +909,11 @@ Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
   auto add_new_constraints_to_worklist = [constraints, &worklist]() {
     // Add constraints to the front of the deque for DFS ordering.
     for (auto* constraint : constraints->ConsumeAddedConstraints()) {
-      worklist.push_front(constraint);
+      if (constraint->dfs()) {
+        worklist.push_front(constraint);
+      } else {
+        worklist.push_back(constraint);
+      }
     }
   };
   add_new_constraints_to_worklist();
@@ -1390,7 +1396,7 @@ Status LayoutAssignment::RunOnComputation(
   // Add any backend-specific constraints.
   TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));
 
-  // Propagates layouts from an HLO to its neighbors.
+  // Propagates layouts from mandatory and backend constraints.
   TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
 
   // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
@@ -1455,7 +1461,12 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
   // Assign layouts to computations in an order such that a callee computation
   // is handled before its caller computation. This ensures that the layout of
   // all callers of a computation will agree.
+  std::list<HloComputation*> computation_post_order =
+      module->MakeComputationPostOrder();
   for (auto* computation : module->MakeComputationPostOrder()) {
+    if (computation->IsFusionComputation()) {
+      continue;
+    }
     // Clear existing layouts of the instructions.  All layouts must be assigned
     // by the LayoutAssignment pass, except for those on infeeds, parameters,
     // and the computation result. The latter two are specified in
@@ -1467,13 +1478,10 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
         LayoutUtil::ClearLayout(instruction->mutable_shape());
       }
     }
-
     if (computation == module->entry_computation()) {
       TF_RETURN_IF_ERROR(RunOnComputation(
           *entry_computation_layout_, *points_to_analysis,
           module->entry_computation(), channel_layout_constraints_));
-    } else if (computation->IsFusionComputation()) {
-      continue;
     } else {
       ComputationLayout computation_layout(computation->ComputeProgramShape());
       // Setting all embedded computations to the default layout is potentially
index 6bfae29..2901858 100644 (file)
@@ -46,7 +46,8 @@ namespace xla {
 // gathered together in LayoutConstraints object.
 class LayoutConstraint {
  public:
-  LayoutConstraint(bool mandatory) : mandatory_(mandatory) {}
+  LayoutConstraint(bool mandatory, bool dfs)
+      : mandatory_(mandatory), dfs_(dfs) {}
   virtual ~LayoutConstraint() = default;
 
   virtual string ToString() const = 0;
@@ -54,8 +55,12 @@ class LayoutConstraint {
   // True if this constraint cannot be overwritten by a different constraint.
   bool mandatory() const { return mandatory_; }
 
+  // When true, propagate in DFS. When false, constraint will propagate in BFS.
+  bool dfs() const { return dfs_; }
+
  private:
   bool mandatory_;
+  bool dfs_;
 };
 
 std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);
@@ -65,7 +70,7 @@ std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);
 class BufferLayoutConstraint : public LayoutConstraint {
  public:
   BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer,
-                         bool mandatory);
+                         bool mandatory, bool dfs);
 
   const LogicalBuffer& buffer() const { return *buffer_; }
   const Layout& layout() const { return layout_; }
@@ -86,7 +91,7 @@ class OperandLayoutConstraint : public LayoutConstraint {
  public:
   OperandLayoutConstraint(const ShapeLayout& shape_layout,
                           const HloInstruction* instruction, int64 operand_no,
-                          bool mandatory);
+                          bool mandatory, bool dfs);
 
   const ShapeLayout& shape_layout() const { return shape_layout_; }
   const HloInstruction* instruction() const { return instruction_; }
@@ -106,8 +111,10 @@ class OperandLayoutConstraint : public LayoutConstraint {
 // Constraint on the layout of the result of the entry computation.
 class ResultLayoutConstraint : public LayoutConstraint {
  public:
-  explicit ResultLayoutConstraint(const ShapeLayout& shape_layout)
-      : LayoutConstraint(/*mandatory=*/true), shape_layout_(shape_layout) {}
+  explicit ResultLayoutConstraint(const ShapeLayout& shape_layout,
+                                  bool dfs = false)
+      : LayoutConstraint(/*mandatory=*/true, dfs),
+        shape_layout_(shape_layout) {}
 
   const ShapeLayout& shape_layout() const { return shape_layout_; }
   string ToString() const override;
@@ -157,23 +164,25 @@ class LayoutConstraints {
   // operand of the instruction, or the layout of the result of the computation,
   // respectively.
   Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer,
-                         bool mandatory = true);
+                         bool mandatory = true, bool dfs = true);
   Status SetOperandLayout(const Shape& shape_with_layout,
                           const HloInstruction* instruction, int64 operand_no,
-                          bool mandatory = true);
-  Status SetResultLayout(const Shape& shape_with_layout);
+                          bool mandatory = true, bool dfs = true);
+  Status SetResultLayout(const Shape& shape_with_layout, bool dfs = true);
 
   // Convenience wrapper around SetOperandLayout for setting the layout of a
   // operand using a Layout object. The operand must be array-shaped.
   Status SetArrayOperandLayout(const Layout& layout,
                                const HloInstruction* instruction,
-                               int64 operand_no, bool mandatory = true);
+                               int64 operand_no, bool mandatory = true,
+                               bool dfs = true);
 
   // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers
   // created by the instruction to the layouts in the given shape. The
   // instruction must define every logical buffer in its output.
   Status SetInstructionLayout(const Shape& shape_with_layout,
-                              const HloInstruction* instruction);
+                              const HloInstruction* instruction,
+                              bool mandatory = true, bool dfs = true);
 
   // Returns true if any buffer in the given operand is forwarded to the output
   // of the given instruction. For example, the Tuple instruction forwards the