Do not force default layout when there is no need to.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 18 May 2018 00:11:47 +0000 (17:11 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 18 May 2018 00:14:53 +0000 (17:14 -0700)
Allow the inner computations to negotiate a root and parameter layouts different from default.
END_PUBLIC

RELNOTES: n/a

---------------------
BEGIN_PUBLIC
Automated g4 rollback of changelist 194293187

PiperOrigin-RevId: 197076025

16 files changed:
tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/computation_layout.cc
tensorflow/compiler/xla/service/computation_layout.h
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
tensorflow/compiler/xla/service/hlo_instruction.h
tensorflow/compiler/xla/service/interpreter/compiler.cc
tensorflow/compiler/xla/service/layout_assignment.cc
tensorflow/compiler/xla/service/layout_assignment.h
tensorflow/compiler/xla/service/layout_assignment_test.cc
tensorflow/compiler/xla/service/service.cc
tensorflow/compiler/xla/service/tuple_simplifier.cc

index 83ecea0..0a50f00 100644 (file)
@@ -2052,10 +2052,12 @@ cc_library(
     deps = [
         ":computation_layout",
         ":hlo",
+        ":hlo_dce",
         ":hlo_graph_dumper",
         ":hlo_pass",
         ":logical_buffer",
         ":tuple_points_to_analysis",
+        ":tuple_simplifier",
         "//tensorflow/compiler/xla:shape_layout",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -2594,6 +2596,7 @@ cc_library(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
     ],
 )
index d2d4f14..cb61f3d 100644 (file)
@@ -23,12 +23,15 @@ limitations under the License.
 
 namespace xla {
 
-ComputationLayout::ComputationLayout(const ProgramShape& program_shape)
+ComputationLayout::ComputationLayout(const ProgramShape& program_shape,
+                                     bool ignore_layouts)
     : result_layout_(program_shape.result()) {
   for (auto& shape : program_shape.parameters()) {
     parameter_layouts_.emplace_back(shape);
   }
-  SetToDefaultLayout();
+  if (ignore_layouts) {
+    SetToDefaultLayout();
+  }
 }
 
 void ComputationLayout::SetToDefaultLayout() {
index 80e1024..53c3a3f 100644 (file)
@@ -34,8 +34,9 @@ class ComputationLayout {
  public:
   // Constructs a ComputationLayout from a ProgramShape. The layouts of the
   // parameters and results are set to the default layout. Layouts in the
-  // ProgramShape are ignored.
-  explicit ComputationLayout(const ProgramShape& program_shape);
+  // ProgramShape are ignored if ignore_layouts is true.
+  explicit ComputationLayout(const ProgramShape& program_shape,
+                             bool ignore_layouts = true);
 
   // Returns the layout of a particular parameter.
   const ShapeLayout& parameter_layout(int64 param_no) const {
index 7ae04e8..25b18ef 100644 (file)
@@ -304,7 +304,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
       ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
 
   pipeline.AddPass<CpuLayoutAssignment>(
-      module->device_entry_computation_layout(), &target_machine_features);
+      module->mutable_device_entry_computation_layout(),
+      &target_machine_features);
   // The LayoutAssignment pass may leave behind kCopy instructions which are
   // duplicate or NOPs, so remove them with algebraic simplification and CSE.
   pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
index 53536a2..3c4fe68 100644 (file)
@@ -29,7 +29,7 @@ namespace cpu {
 class CpuLayoutAssignment : public LayoutAssignment {
  public:
   explicit CpuLayoutAssignment(
-      const ComputationLayout& entry_computation_layout,
+      ComputationLayout* entry_computation_layout,
       const TargetMachineFeatures* target_machine_features)
       : LayoutAssignment(entry_computation_layout),
         target_machine_features_(*target_machine_features) {}
index f6c93d3..429fc7b 100644 (file)
@@ -54,7 +54,7 @@ class CpuLayoutAssignmentTest : public HloTestBase {
         [](int64 shape_size) {
           return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
         });
-    cpu::CpuLayoutAssignment layout_assignment(*entry_computation_layout,
+    cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout,
                                                &target_machine_features);
     EXPECT_IS_OK(layout_assignment.Run(module).status());
   }
@@ -321,7 +321,7 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
       [](int64 shape_size) {
         return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
       });
-  cpu::CpuLayoutAssignment layout_assignment(computation_layout,
+  cpu::CpuLayoutAssignment layout_assignment(&computation_layout,
                                              &target_machine_features);
   TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
                       layout_assignment.Run(module));
index df494a1..d50153d 100644 (file)
@@ -247,7 +247,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
   {
     HloPassPipeline pipeline("layout_assignment");
     pipeline.AddPass<GpuLayoutAssignment>(
-        hlo_module->device_entry_computation_layout());
+        hlo_module->mutable_device_entry_computation_layout());
 
     // The LayoutAssignment pass may leave behind kCopy instructions which are
     // duplicate or NOPs, so remove them with algebraic simplification and CSE.
index 51aae79..86a3a71 100644 (file)
@@ -27,8 +27,7 @@ namespace gpu {
 // layout constraints for operands and results of library calls.
 class GpuLayoutAssignment : public LayoutAssignment {
  public:
-  explicit GpuLayoutAssignment(
-      const ComputationLayout& entry_computation_layout)
+  explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout)
       : LayoutAssignment(entry_computation_layout) {}
   ~GpuLayoutAssignment() override {}
 
index 7c80195..4c45d2e 100644 (file)
@@ -69,7 +69,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
         *computation_layout.mutable_result_layout() =
             ShapeLayout(result_shape_with_layout);
 
-        GpuLayoutAssignment layout_assignment(computation_layout);
+        GpuLayoutAssignment layout_assignment(&computation_layout);
         EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
         for (const HloInstruction* operand : add->operands()) {
@@ -156,7 +156,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
         *computation_layout.mutable_result_layout() = ShapeLayout(result_shape);
       }
 
-      GpuLayoutAssignment layout_assignment(computation_layout);
+      GpuLayoutAssignment layout_assignment(&computation_layout);
       EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
       // The first operand to batchnorm should have the same layout as the
@@ -225,7 +225,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
                 {result_shape, offset_scale_shape, offset_scale_shape}));
       }
 
-      GpuLayoutAssignment layout_assignment(computation_layout);
+      GpuLayoutAssignment layout_assignment(&computation_layout);
       EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
       // The first operand to batchnorm should have the same layout as the
@@ -305,7 +305,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
                   {result_shape, scale_shape, scale_shape}));
         }
 
-        GpuLayoutAssignment layout_assignment(computation_layout);
+        GpuLayoutAssignment layout_assignment(&computation_layout);
         EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
         // The first and fourth operands to the batchnorm call should have the
index db78539..234dbc8 100644 (file)
@@ -1108,6 +1108,14 @@ class HloInstruction {
   void clear_sharding() { sharding_ = nullptr; }
   // Return true if this operator has a sharding assigned.
   bool has_sharding() const { return sharding_ != nullptr; }
+  // Checks whether the instruction has compatible sharding with the other
+  // instruction.
+  bool has_compatible_sharding(const HloInstruction* other) const {
+    if (!has_sharding()) {
+      return !other->has_sharding();
+    }
+    return other->has_sharding() ? sharding() == other->sharding() : false;
+  }
 
   // When creating a new instruction which either replaces, or shifts up (kCopy
   // insertion case), another instruction, we need to make sure the certain
index 3ff1551..c59189d 100644 (file)
@@ -44,8 +44,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
   HloPassPipeline pipeline("Interpreter");
 
   pipeline.AddPass<LayoutAssignment>(
-      hlo_module->device_entry_computation_layout());
-
+      hlo_module->mutable_device_entry_computation_layout());
   return pipeline.Run(hlo_module).status();
 }
 
index cfa7ba5..7067b6f 100644 (file)
@@ -31,10 +31,12 @@ limitations under the License.
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/computation_layout.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
 #include "tensorflow/compiler/xla/shape_layout.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -400,9 +402,9 @@ string LayoutConstraints::ToString() const {
 }
 
 Status LayoutAssignment::AddMandatoryConstraints(
-    const ComputationLayout& computation_layout,
-    const ChannelLayoutConstraints* channel_constraints,
-    HloComputation* computation, LayoutConstraints* constraints) {
+    const ComputationLayout* computation_layout,
+    ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
+    LayoutConstraints* constraints) {
   VLOG(3) << "Adding mandatory layout constraints to computation "
           << computation->name();
 
@@ -424,11 +426,16 @@ Status LayoutAssignment::AddMandatoryConstraints(
       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
           instruction->outfeed_shape(), instruction, 0));
     } else if (instruction->opcode() == HloOpcode::kParameter) {
-      // Parameter layouts must match the respective layout in
-      // ComputationLayout.
-      shape_with_layout =
-          &computation_layout.parameter_layout(instruction->parameter_number())
-               .shape();
+      if (computation_layout != nullptr) {
+        const ShapeLayout& parameter_layout =
+            computation_layout->parameter_layout(
+                instruction->parameter_number());
+        if (parameter_layout.LayoutIsSet()) {
+          // Parameter layouts must match the respective layout in
+          // ComputationLayout, if there is one.
+          shape_with_layout = &parameter_layout.shape();
+        }
+      }
     }
     if (shape_with_layout != nullptr) {
       TF_RETURN_IF_ERROR(
@@ -493,9 +500,8 @@ Status LayoutAssignment::AddMandatoryConstraints(
       HloComputation* body = instruction->while_body();
       HloComputation* condition = instruction->while_condition();
       const HloInstruction* init = instruction->operand(0);
-      const ComputationLayout& body_layout =
-          FindOrDie(computation_layouts_, body);
-      const ComputationLayout& condition_layout =
+      ComputationLayout& body_layout = FindOrDie(computation_layouts_, body);
+      ComputationLayout& condition_layout =
           FindOrDie(computation_layouts_, condition);
 
       // Check a few invariants irrespective of layout.
@@ -508,26 +514,19 @@ Status LayoutAssignment::AddMandatoryConstraints(
                                    condition_layout.parameter_shape(0)));
       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));
 
-      // Return error if earlier layout assignment of the embedded computations
-      // has produced conflicting layouts.
-      if (!ShapeUtil::Equal(body_layout.result_shape(),
-                            body_layout.parameter_shape(0))) {
-        return InternalError(
-            "Parameter and result of body computation %s of while instruction "
-            "%s have different layouts: %s vs %s",
-            body->name().c_str(), instruction->name().c_str(),
-            ShapeUtil::HumanString(body_layout.result_shape()).c_str(),
-            ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str());
+      if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
+        VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
+                << " while=" << instruction->name()
+                << " shape=" << body_layout.result_layout().ToString();
+        *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
       }
-      if (!ShapeUtil::Equal(body->root_instruction()->shape(),
-                            condition->parameter_instruction(0)->shape())) {
-        return InternalError(
-            "Parameter of condition computation %s of while instruction "
-            "%s does not match body computation %s result: %s vs %s",
-            condition->name().c_str(), instruction->name().c_str(),
-            body->name().c_str(),
-            ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(),
-            ShapeUtil::HumanString(body_layout.result_shape()).c_str());
+      if (condition_layout.parameter_layout(0) !=
+          body_layout.parameter_layout(0)) {
+        VLOG(2) << "Reset %while condition parameter layout: cond="
+                << condition->name() << " while=" << instruction->name()
+                << " shape=" << body_layout.parameter_layout(0).ToString();
+        *condition_layout.mutable_parameter_layout(0) =
+            body_layout.parameter_layout(0);
       }
 
       // Constrain the output and the operand of the while instruction to match
@@ -557,7 +556,20 @@ Status LayoutAssignment::AddMandatoryConstraints(
                                    true_computation_layout.parameter_shape(0)));
       DCHECK(ShapeUtil::Compatible(
           false_operand->shape(), false_computation_layout.parameter_shape(0)));
-
+      if (true_computation_layout.result_layout() !=
+          false_computation_layout.result_layout()) {
+        // We assign layouts in DFS fashion, so the true and false computations
+        // might have negotiated a different layout. But for the conditional
+        // instruction POV the layout must match, so we run again on the false
+        // computation, this time with proper computation layout.
+        VLOG(2) << "Reset %conditional false computation result layout: "
+                   "false_computation="
+                << false_computation->name()
+                << " conditional=" << instruction->name() << " shape="
+                << true_computation_layout.result_layout().ToString();
+        *false_computation_layout.mutable_result_layout() =
+            true_computation_layout.result_layout();
+      }
       TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
           true_computation_layout.result_shape(), instruction));
       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
@@ -593,10 +605,14 @@ Status LayoutAssignment::AddMandatoryConstraints(
       }
     }
   }
-
-  // Finally set the result layout to match ComputationLayout.
-  return constraints->SetResultLayout(
-      computation_layout.result_layout().shape());
+  // Finally set the result layout to match ComputationLayout, if there is one.
+  if (computation_layout != nullptr) {
+    const ShapeLayout& result_layout = computation_layout->result_layout();
+    if (result_layout.LayoutIsSet()) {
+      TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape()));
+    }
+  }
+  return Status::OK();
 }
 
 namespace {
@@ -760,6 +776,7 @@ StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
     HloInstruction* copy =
         instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
             instruction->shape(), HloOpcode::kCopy, instruction));
+    RegisterAddedCopy(copy);
     SetupCopiedInstruction(*instruction, copy, {});
     LayoutUtil::ClearLayout(copy->mutable_shape());
     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
@@ -783,13 +800,19 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
   TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
 
   if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
+    VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
+            << instruction->ToString();
     // Operand layout already matches our constraint. Nothing to do.
     return Status::OK();
   }
+  VLOG(4) << "Operand " << operand->ToString() << " layout does not match "
+          << operand_layout.ToString() << " in " << instruction->ToString();
 
   TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
                       CreateCopyWithNewLayout(operand_layout.shape(), operand));
 
+  VLOG(4) << "New copy of " << operand->ToString() << " is "
+          << operand_copy->ToString();
   return instruction->ReplaceOperandWith(operand_no, operand_copy);
 }
 
@@ -896,32 +919,31 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
       }
     }
   }
-
-  // Finally verify the result layout matches the layout of the entry
+  // Finally verify the result layout, if set, matches the layout of the entry
   // computation root.
-  TF_RET_CHECK(ShapeUtil::Equal(
-      module->entry_computation()->root_instruction()->shape(),
+  const ShapeLayout& result_layout =
       FindOrDie(computation_layouts_, module->entry_computation())
-          .result_layout()
-          .shape()));
-
+          .result_layout();
+  if (result_layout.LayoutIsSet()) {
+    TF_RET_CHECK(ShapeUtil::Equal(
+        module->entry_computation()->root_instruction()->shape(),
+        result_layout.shape()));
+  }
   return Status::OK();
 }
 
 LayoutAssignment::LayoutAssignment(
-    const ComputationLayout& entry_computation_layout,
+    ComputationLayout* entry_computation_layout,
     ChannelLayoutConstraints* channel_constraints)
     : entry_computation_layout_(entry_computation_layout),
       channel_layout_constraints_(channel_constraints) {
-  VLOG(1) << "entry computation layout given to layout assignment: "
-          << entry_computation_layout_.ToString();
+  VLOG(1) << "Entry computation layout given to layout assignment: "
+          << entry_computation_layout_->ToString();
   // Layouts of all parameter instructions must be set.
   for (const ShapeLayout& parameter_layout :
-       entry_computation_layout_.parameter_layouts()) {
+       entry_computation_layout_->parameter_layouts()) {
     CHECK(parameter_layout.LayoutIsSet());
   }
-  // TODO(b/29118294): Choose a better layout if the result layout is not set.
-  CHECK(entry_computation_layout_.result_layout().LayoutIsSet());
 }
 
 std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
@@ -1481,16 +1503,60 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
   return Status::OK();
 }
 
+Status LayoutAssignment::CalculateComputationLayout(
+    HloComputation* computation) {
+  ComputationLayout computation_layout(computation->ComputeProgramShape(),
+                                       /*ignore_layouts=*/false);
+  InsertOrDie(&computation_layouts_, computation, computation_layout);
+  VLOG(2) << "  Calculated ComputationLayout = "
+          << computation_layout.ToString();
+  return Status::OK();
+}
+
+Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
+  // Clear existing layouts of the instructions.  All layouts must be assigned
+  // by the LayoutAssignment pass, except for those on infeeds, parameters,
+  // and the computation result. The latter two are specified in
+  // computation_layout, so we only need to keep the existing layouts for
+  // infeeds.  Clearing the layouts here avoids hiding potential bugs in the
+  // layout assignment pass that may accidently use the existing layout.
+  for (HloInstruction* instruction : computation->instructions()) {
+    if (instruction->opcode() == HloOpcode::kBitcast) {
+      // bitcasts are inherently layout sensitive and so a bitcast instruction
+      // present in the IR before layout assignment is a bug.
+      return InternalError(
+          "Unexpected bitcast operation seen during layout assignment: %s.",
+          instruction->ToString().c_str());
+    }
+    if (instruction->opcode() != HloOpcode::kInfeed) {
+      LayoutUtil::ClearLayout(instruction->mutable_shape());
+    }
+  }
+  return Status::OK();
+}
+
 Status LayoutAssignment::RunOnComputation(
-    const ComputationLayout& computation_layout,
+    ComputationLayout* computation_layout,
     const TuplePointsToAnalysis& points_to_analysis,
     HloComputation* computation,
     ChannelLayoutConstraints* channel_constraints) {
-  DCHECK(computation_layout.LayoutIsSet());
-  InsertOrDie(&computation_layouts_, computation, computation_layout);
   VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
           << ")";
-  VLOG(2) << "  ComputationLayout = " << computation_layout.ToString();
+  TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
+  if (computation_layout != nullptr) {
+    auto it = computation_layouts_.find(computation);
+    if (it == computation_layouts_.end()) {
+      VLOG(2) << "  New ComputationLayout = " << computation_layout->ToString();
+      computation_layouts_.emplace(computation, *computation_layout);
+    } else {
+      TF_RET_CHECK(computation_layout == &it->second ||
+                   computation_layout == entry_computation_layout_);
+      VLOG(2) << "  Existing ComputationLayout = "
+              << computation_layout->ToString();
+    }
+  } else {
+    VLOG(2) << "  No ComputationLayout specified (will be calculated)";
+  }
 
   // Construct LayoutConstraints with all layout constraints of the computation.
   LayoutConstraints constraints(points_to_analysis, computation);
@@ -1533,12 +1599,19 @@ Status LayoutAssignment::RunOnComputation(
     CHECK_LT(constraints.unconstrained_buffer_ids().size(),
              unconstrained_count);
   }
-
   // All logical buffers should have constraints at this point. All that
   // remains is assign the constraints to the buffers and infer layouts for
   // aliased buffers.
   TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));
 
+  // If the computation layout wasn't specified, now it is the time to compute
+  // it according to the parameters and root instruction layouts.
+  // This allows the first pass through this API to record the best flowing
+  // layout to parameters and root instruction.
+  if (computation_layout == nullptr) {
+    TF_RETURN_IF_ERROR(CalculateComputationLayout(computation));
+  }
+
   // Record the layouts assigned for any communication ops in
   // channel_constraints so that they are constrained for future modules.
   for (HloInstruction* instruction : computation->instructions()) {
@@ -1553,6 +1626,34 @@ Status LayoutAssignment::RunOnComputation(
   return Status::OK();
 }
 
+Status LayoutAssignment::PropagateComputationLayouts(
+    HloComputation* computation, ComputationLayout* computation_layout) {
+  ComputationLayout computed_computation_layout(
+      computation->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) {
+    ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
+    if (!param_layout->LayoutIsSet()) {
+      VLOG(4) << "Assigning layout to parameter " << i << " of computation "
+              << computation->name() << ": "
+              << computed_computation_layout.parameter_layout(i).ToString();
+      *param_layout = computed_computation_layout.parameter_layout(i);
+    } else {
+      TF_RET_CHECK(computed_computation_layout.parameter_layout(i) ==
+                   *param_layout);
+    }
+  }
+  ShapeLayout* result_layout = computation_layout->mutable_result_layout();
+  if (!result_layout->LayoutIsSet()) {
+    VLOG(4) << "Assigning result layout of computation " << computation->name()
+            << ": " << computed_computation_layout.result_layout().ToString();
+    *result_layout = computed_computation_layout.result_layout();
+  } else {
+    TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout);
+  }
+  return Status::OK();
+}
+
 StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
   VLOG(2) << "Running layout assignment on module " << module->name();
   XLA_VLOG_LINES(3, module->ToString());
@@ -1561,52 +1662,45 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
                                 "before layout assignment",
                                 module->config().debug_options());
   }
-
-  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
-                      TuplePointsToAnalysis::Run(module));
-
-  // Assign layouts to computations in an order such that a callee computation
-  // is handled before its caller computation. This ensures that the layout of
-  // all callers of a computation will agree.
-  std::list<HloComputation*> computation_post_order =
-      module->MakeComputationPostOrder();
-  for (auto* computation : module->MakeComputationPostOrder()) {
-    if (computation->IsFusionComputation()) {
-      continue;
-    }
-    // Clear existing layouts of the instructions.  All layouts must be assigned
-    // by the LayoutAssignment pass, except for those on infeeds, parameters,
-    // and the computation result. The latter two are specified in
-    // computation_layout, so we only need to keep the existing layouts for
-    // infeeds.  Clearing the layouts here avoids hiding potential bugs in the
-    // layout assignment pass that may accidently use the existing layout.
-    for (HloInstruction* instruction : computation->instructions()) {
-      if (instruction->opcode() == HloOpcode::kBitcast) {
-        // bitcasts are inherently layout sensitive and so a bitcast instruction
-        // present in the IR before layout assignment is a bug.
-        return InternalError(
-            "Unexpected bitcast operation seen during layout assignment: %s.",
-            instruction->ToString().c_str());
+  TF_RETURN_IF_ERROR(Init());
+
+  // We do two passes. The first one we pass a nullptr ComputationLayout to
+  // the RunOnComputation() calls (for non entry computations), and we register
+  // the ComputationLayout which are naturally flowing in DFS fashion to the
+  // parameters and root instruction.
+  // Walking in DFS mode though, means that we can end up with incorrect layouts
+  // when seen from an outer instruction, which has across-computation
+  // constraints to impose.
+  // For example, the kWhile instruction needs to enforce the same layouts for
+  // the parameters and root of the bosy, as well as the condition parameters.
+  // Similarly, the kConditional instruction needs to enforce the same layouts
+  // for the root of the true and false computations.
+  // So in the first pass, while allowing the layouts to flow to parameters and
+  // root, we also fix up the eventually inconsistent ComputationLayout, which
+  // will be then made mandatory by the second pass.
+  for (int64 i = 0; i < 2; ++i) {
+    TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
+    TF_ASSIGN_OR_RETURN(auto points_to_analysis,
+                        TuplePointsToAnalysis::Run(module));
+    for (auto* computation : module->MakeComputationPostOrder()) {
+      if (computation->IsFusionComputation()) {
+        continue;
       }
-      if (instruction->opcode() != HloOpcode::kInfeed) {
-        LayoutUtil::ClearLayout(instruction->mutable_shape());
+      if (computation == module->entry_computation()) {
+        TF_RETURN_IF_ERROR(RunOnComputation(
+            entry_computation_layout_, *points_to_analysis,
+            module->entry_computation(), channel_layout_constraints_));
+      } else {
+        ComputationLayout* computation_layout =
+            (i == 0) ? nullptr : &FindOrDie(computation_layouts_, computation);
+        TF_RETURN_IF_ERROR(RunOnComputation(computation_layout,
+                                            *points_to_analysis, computation,
+                                            channel_layout_constraints_));
       }
     }
-    if (computation == module->entry_computation()) {
-      TF_RETURN_IF_ERROR(RunOnComputation(
-          entry_computation_layout_, *points_to_analysis,
-          module->entry_computation(), channel_layout_constraints_));
-    } else {
-      ComputationLayout computation_layout(computation->ComputeProgramShape());
-      // Setting all embedded computations to the default layout is potentially
-      // suboptimal.
-      computation_layout.SetToDefaultLayout();
-      TF_RETURN_IF_ERROR(RunOnComputation(computation_layout,
-                                          *points_to_analysis, computation,
-                                          channel_layout_constraints_));
-    }
   }
-
+  TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
+                                                 entry_computation_layout_));
   TF_RETURN_IF_ERROR(CheckLayouts(module));
 
   VLOG(3) << "After layout assignment:";
@@ -1616,9 +1710,54 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
                                 "after layout assignment",
                                 module->config().debug_options());
   }
-
   // All layouts are reset then reassigned by this pass.
   return true;
 }
 
+Status LayoutAssignment::Init() {
+  computation_layouts_.clear();
+  return Status::OK();
+}
+
+Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
+  // Clear all the copies which have been added, and all the related
+  // instructions (like GTE and tuples).
+  int64 removed_copies = 0;
+  for (HloComputation* computation : module->computations()) {
+    for (HloInstruction* instruction :
+         computation->MakeInstructionPostOrder()) {
+      if (instruction->opcode() == HloOpcode::kCopy &&
+          added_copies_.count(instruction) > 0) {
+        VLOG(5) << "Removing added copy: " << instruction->ToString();
+        TF_RETURN_IF_ERROR(
+            instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
+        TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
+        ++removed_copies;
+      }
+    }
+  }
+  added_copies_.clear();
+  if (removed_copies > 0) {
+    TupleSimplifier tuple_simplifier;
+    HloDCE dce;
+    TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
+    TF_RETURN_IF_ERROR(dce.Run(module).status());
+  }
+  return Status::OK();
+}
+
+Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
+                                           int64 operand_number) {
+  HloInstruction* operand = instruction->mutable_operand(operand_number);
+  if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
+    HloInstruction* copy =
+        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
+            operand->shape(), HloOpcode::kCopy, operand));
+    SetupCopiedInstruction(*operand, copy, {});
+    LayoutUtil::ClearLayout(copy->mutable_shape());
+    TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
+  }
+  return Status::OK();
+}
+
 }  // namespace xla
index c83ae03..8b4e079 100644 (file)
@@ -39,6 +39,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
@@ -288,7 +289,7 @@ class LayoutAssignment : public HloPassInterface {
   // If channel_constraints is nullptr, no kSend or kRecvs must be contained
   // within any module passed to `Run`.
   explicit LayoutAssignment(
-      const ComputationLayout& entry_computation_layout,
+      ComputationLayout* entry_computation_layout,
       ChannelLayoutConstraints* channel_constraints = nullptr);
   ~LayoutAssignment() override {}
   tensorflow::StringPiece name() const override { return "layout-assignment"; }
@@ -362,12 +363,15 @@ class LayoutAssignment : public HloPassInterface {
       int64 operand_no);
 
  private:
+  // Initializes the layout assignment object for a new Run() call.
+  Status Init();
+
   // Adds constraints which must be satisfied for correctness on all
   // backends. Called once prior to propagating constraints.
-  Status AddMandatoryConstraints(
-      const ComputationLayout& computation_layout,
-      const ChannelLayoutConstraints* channel_constraints,
-      HloComputation* computation, LayoutConstraints* constraints);
+  Status AddMandatoryConstraints(const ComputationLayout* computation_layout,
+                                 ChannelLayoutConstraints* channel_constraints,
+                                 HloComputation* computation,
+                                 LayoutConstraints* constraints);
 
   // This method can be overridden to add backend-specific constraints to the
   // layout of the instructions of a computation. This method is called after
@@ -378,10 +382,12 @@ class LayoutAssignment : public HloPassInterface {
   }
 
   // Construct contraints and assign layouts to all instructions in the
-  // computation satisfying the given ComputationLayout. Layouts constraints are
-  // added, then propagated until all LogicalBuffers in the computation are
-  // constrained.
-  Status RunOnComputation(const ComputationLayout& computation_layout,
+  // computation satisfying the given ComputationLayout, if not nullptr.
+  // Otherwise the ComputationLayout will be calculated by propagating the
+  // computation instruction contraints.
+  // Layouts constraints are added, then propagated until all LogicalBuffers in
+  // the computation are constrained.
+  Status RunOnComputation(ComputationLayout* computation_layout,
                           const TuplePointsToAnalysis& points_to_analysis,
                           HloComputation* computation,
                           ChannelLayoutConstraints* channel_constraints);
@@ -402,7 +408,26 @@ class LayoutAssignment : public HloPassInterface {
   // necessary conditions.
   Status CheckLayouts(HloModule* module);
 
-  const ComputationLayout& entry_computation_layout_;
+  // Computes the ComputationLayout of the given computation based of the
+  // layouts assigned to parameters and root instruction, and inserts it to the
+  // computation_layouts_ map.
+  Status CalculateComputationLayout(HloComputation* computation);
+
+  // Clears all the layouts which can be cleared within a computation.
+  Status ClearComputationLayouts(HloComputation* computation);
+
+  // Clears the side effects of a previous pass, like added copy instructions.
+  Status ClearPreviousPassSideEffects(HloModule* module);
+
+  // Propagates the layouts computed by the layout assignment pass on the given
+  // computation, to the computation layout passed in to this API.
+  // This API propagates missing layout, and also checks that the caller
+  // specified have been respected, by comparing those with the parameters and
+  // root computation instruction.
+  Status PropagateComputationLayouts(HloComputation* computation,
+                                     ComputationLayout* computation_layout);
+
+  ComputationLayout* entry_computation_layout_;
 
  protected:
   // Sets up the copy instruction according to the characteristic (sharding,
@@ -418,21 +443,37 @@ class LayoutAssignment : public HloPassInterface {
   // Creates and returns a copy of the given instruction with a different
   // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
   // instruction producing the copy is returned.
-  static StatusOr<HloInstruction*> CreateCopyWithNewLayout(
+  StatusOr<HloInstruction*> CreateCopyWithNewLayout(
       const Shape& shape_with_layout, HloInstruction* instruction);
 
   // Creates a copy of the given operand if the operand's layout does not match
   // the given layout. This copy replaces the use in the given instruction.
   // Tuple operands will be deep-copied.
-  static Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
-                                           HloInstruction* instruction,
-                                           int64 operand_no);
+  Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
+                                    HloInstruction* instruction,
+                                    int64 operand_no);
+
+  // Registers a copy instruction added by the layout assignment pass.
+  void RegisterAddedCopy(HloInstruction* copy) {
+    CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
+    added_copies_.insert(copy);
+  }
+
+  // Adds a copy for the operand of an instruction, unless such operand is
+  // already a copy, and has a single user (which is forcibly the instruction
+  // itself).
+  Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number);
 
   // Map containing the layouts of all computations assigned so
   // far. Computations are handled in a topological sort where computations are
   // handled before their caller instructions so the layouts of caller
   // instructions can be set to match the computation.
   std::map<HloComputation*, ComputationLayout> computation_layouts_;
+
+  // Every copy added to the module by the layout assignment pass is registered
+  // here.
+  tensorflow::gtl::FlatSet<HloInstruction*> added_copies_;
+
   ChannelLayoutConstraints* channel_layout_constraints_;
 };
 
index 986e177..7508013 100644 (file)
@@ -53,7 +53,7 @@ class LayoutAssignmentTest : public HloTestBase {
  protected:
   void AssignLayouts(HloModule* module,
                      ComputationLayout* entry_computation_layout) {
-    LayoutAssignment layout_assignment(*entry_computation_layout);
+    LayoutAssignment layout_assignment(entry_computation_layout);
     EXPECT_IS_OK(layout_assignment.Run(module).status());
   }
 };
@@ -285,7 +285,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
   TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
       result_shape));
 
-  LayoutAssignment layout_assignment(computation_layout);
+  LayoutAssignment layout_assignment(&computation_layout);
   AssignLayouts(module.get(), &computation_layout);
 
   // Layout assignment should have deep copied the result of the computation to
@@ -488,7 +488,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
  public:
   explicit OperandsMustBeTheSameLayoutAssignment(
       ComputationLayout* entry_computation_layout)
-      : LayoutAssignment(*entry_computation_layout) {}
+      : LayoutAssignment(entry_computation_layout) {}
 
  protected:
   Status PropagateBufferConstraint(
@@ -807,7 +807,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
 
   ComputationLayout computation_layout(
       module->entry_computation()->ComputeProgramShape());
-  LayoutAssignment layout_assignment(computation_layout);
+  LayoutAssignment layout_assignment(&computation_layout);
   Status error_status = layout_assignment.Run(module.get()).status();
   EXPECT_FALSE(error_status.ok());
   EXPECT_THAT(
index 047cadb..cb0f76e 100644 (file)
@@ -340,6 +340,9 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     // If the result layout is not set, then choose the default.
     // TODO(b/29118294): Allow the compiler to choose a better layout in this
     // case.
+    // TODO(b/78356948): We are forcing the default layout here. We should fix
+    // clients which expect a default layout, to be explicit about it, by
+    // passing the proper ExecutionOptions with shape_with_output_layout set.
     host_computation_layout->mutable_result_layout()->SetToDefaultLayout();
     device_computation_layout->mutable_result_layout()->SetToDefaultLayout();
   }
index 113c2e2..d668855 100644 (file)
@@ -69,6 +69,7 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
       //       Tuple
       //
       HloInstruction* top_tuple = nullptr;
+      HloInstruction* first_gte = nullptr;
       bool can_simplify = true;
       for (int64 operand_number = 0;
            operand_number < instruction->operand_count(); ++operand_number) {
@@ -78,11 +79,17 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
           can_simplify = false;
           break;
         }
-
+        if (first_gte == nullptr) {
+          first_gte = operand;
+        } else if (!first_gte->has_compatible_sharding(operand)) {
+          can_simplify = false;
+          break;
+        }
         if (top_tuple == nullptr) {
           top_tuple = operand->mutable_operand(0);
           if (!ShapeUtil::Compatible(top_tuple->shape(),
-                                     instruction->shape())) {
+                                     instruction->shape()) ||
+              !instruction->has_compatible_sharding(top_tuple)) {
             can_simplify = false;
             break;
           }
@@ -108,15 +115,17 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
       //          |
       //         GTE
       if (instruction->operand(0)->opcode() == HloOpcode::kTuple) {
-        changed = true;
         HloInstruction* element_source =
             instruction->mutable_operand(0)->mutable_operand(
                 instruction->tuple_index());
-        TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source));
-        for (HloInstruction* user : element_source->users()) {
-          if (user->opcode() == HloOpcode::kTuple ||
-              user->opcode() == HloOpcode::kGetTupleElement) {
-            worklist.push(user);
+        if (instruction->has_compatible_sharding(element_source)) {
+          changed = true;
+          TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source));
+          for (HloInstruction* user : element_source->users()) {
+            if (user->opcode() == HloOpcode::kTuple ||
+                user->opcode() == HloOpcode::kGetTupleElement) {
+              worklist.push(user);
+            }
           }
         }
       }