Clear existing layouts before running the layout assignment.
authorHyoukJoong Lee <hyouklee@google.com>
Sat, 9 Dec 2017 01:28:52 +0000 (17:28 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 9 Dec 2017 01:34:28 +0000 (17:34 -0800)
PiperOrigin-RevId: 178449701

tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
tensorflow/compiler/xla/service/layout_assignment.cc

index 7975eba399fbc8404536291fb262d7bf5cccccce..78732c31f99540d090fcb0d324697c505f050934 100644 (file)
@@ -139,6 +139,10 @@ Status CpuLayoutAssignment::AddBackendConstraints(
         if (constraints->OperandBufferForwarded(instruction, operand_no)) {
           continue;
         }
+        // Skip operands with non-array shapes.
+        if (!ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) {
+          continue;
+        }
         Shape operand_shape(
             row_major_shape(instruction->operand(operand_no)->shape()));
         TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
index 7eda7c2284c2457703fcfcd4226172e41dd4ae01..af726271aeb484ac6bb90e40aff7010b092967df 100644 (file)
@@ -1328,6 +1328,20 @@ Status LayoutAssignment::RunOnComputation(
           << ")";
   VLOG(2) << "  ComputationLayout = " << computation_layout.ToString();
 
+  // Clear existing layouts of the instructions. All layouts must be assigned by
+  // the LayoutAssignment pass, except for Infeed, Outfeed, Parameters and the
+  // computation result. The latter two are specified in computation_layout, so
+  // we only need to keep the existing layouts for Infeed and Outfeed. Clearing
+  // the layouts here avoids hiding potential bugs in the layout assignment pass
+  // that may accidently use the existing layout.
+  for (HloInstruction* instruction : computation->instructions()) {
+    if (instruction->opcode() == HloOpcode::kInfeed ||
+        instruction->opcode() == HloOpcode::kOutfeed) {
+      continue;
+    }
+    LayoutUtil::ClearLayout(instruction->mutable_shape());
+  }
+
   // Construct LayoutConstraints with all layout constraints of the computation.
   LayoutConstraints constraints(points_to_analysis, computation);