[XLA] Sink layout sensitivity from CSE into HloInstruction::Identical, and make it...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 5 Feb 2018 23:46:50 +0000 (15:46 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Feb 2018 23:51:07 +0000 (15:51 -0800)
PiperOrigin-RevId: 184598903

tensorflow/compiler/xla/service/hlo_cse.cc
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h

index 7feda2b..279edd4 100644 (file)
@@ -119,9 +119,8 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
           equivalent_instructions;
       for (HloInstruction* user : operand->users()) {
         if (user != instruction &&
-            user->Identical(*instruction, eq_instructions, eq_computations) &&
-            (!is_layout_sensitive_ ||
-             ShapeUtil::Equal(user->shape(), instruction->shape()))) {
+            user->Identical(*instruction, eq_instructions, eq_computations,
+                            is_layout_sensitive_)) {
           equivalent_instructions.push_back(user);
         }
       }
index fac6b43..277648f 100644 (file)
@@ -1612,7 +1612,8 @@ bool HloInstruction::HasConstantOperand() const {
 bool HloInstruction::IdenticalSlowPath(
     const HloInstruction& other,
     const std::function<bool(const HloComputation*, const HloComputation*)>&
-        eq_computations) const {
+        eq_computations,
+    const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const {
   // Perform opcode specific checks.
   switch (opcode()) {
     // The result of these instructions only depend upon their opcode and
@@ -1671,7 +1672,7 @@ bool HloInstruction::IdenticalSlowPath(
       return parameter_number() == other.parameter_number() &&
              // Check the shape too because `this` and `other` may be in
              // different HloComputations.
-             ShapeUtil::Compatible(shape(), other.shape());
+             eq_shapes(shape(), other.shape());
 
     case HloOpcode::kBatchNormTraining:
     case HloOpcode::kBatchNormInference:
@@ -1727,18 +1728,18 @@ bool HloInstruction::IdenticalSlowPath(
              protobuf_util::ProtobufEquals(window(), other.window());
 
     case HloOpcode::kReshape:
-      return ShapeUtil::Compatible(shape(), other.shape());
+      return eq_shapes(shape(), other.shape());
 
     // Transpose result is determined by the final shape and the permutation.
     case HloOpcode::kTranspose:
-      return ShapeUtil::Compatible(shape(), other.shape()) &&
+      return eq_shapes(shape(), other.shape()) &&
              dimensions() == other.dimensions();
 
     // Remaining instructions with special values.
     case HloOpcode::kBitcast:
-      return ShapeUtil::Equal(shape(), other.shape());
+      return eq_shapes(shape(), other.shape());
     case HloOpcode::kBroadcast:
-      return ShapeUtil::Compatible(shape(), other.shape()) &&
+      return eq_shapes(shape(), other.shape()) &&
              dimensions() == other.dimensions();
     case HloOpcode::kConcatenate:
       return dimensions() == other.dimensions();
@@ -1752,10 +1753,10 @@ bool HloInstruction::IdenticalSlowPath(
              slice_limits_ == other.slice_limits_ &&
              slice_strides_ == other.slice_strides_;
     case HloOpcode::kDynamicSlice:
-      return ShapeUtil::Compatible(shape(), other.shape()) &&
+      return eq_shapes(shape(), other.shape()) &&
              dynamic_slice_sizes_ == other.dynamic_slice_sizes_;
     case HloOpcode::kDynamicUpdateSlice:
-      return ShapeUtil::Compatible(shape(), other.shape());
+      return eq_shapes(shape(), other.shape());
     case HloOpcode::kCall:
     case HloOpcode::kMap:
       return eq_computations(to_apply(), other.to_apply());
index bce9ebd..50931c5 100644 (file)
@@ -554,27 +554,36 @@ class HloInstruction {
   }
 
   // Returns true if "other" performs the same computation as this instruction.
-  // Layout of the instructions' output array is not considered.
   bool Identical(
       const HloInstruction& other,
       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
           eq_operands = std::equal_to<const HloInstruction*>(),
       const std::function<bool(const HloComputation*, const HloComputation*)>&
-          eq_computations = std::equal_to<const HloComputation*>()) const {
+          eq_computations = std::equal_to<const HloComputation*>(),
+      bool layout_sensitive = true) const {
     // An instruction is always identical to itself.
     if (this == &other) {
       return true;
     }
 
-    // Identical instruction must have the same opcode and identical operands.
-    // In general, there is no need to check shape because shape is inferred
-    // from the shape of the operands.
+    // Identical instruction must have the same opcode, shape, and identical
+    // operands.
     if (opcode() != other.opcode()) {
       return false;
     }
+    auto eq_shapes = layout_sensitive
+                         ? [](const Shape& a,
+                              const Shape& b) { return ShapeUtil::Equal(a, b); }
+                         : [](const Shape& a, const Shape& b) {
+                             return ShapeUtil::Compatible(a, b);
+                           };
+    if (!eq_shapes(shape(), other.shape())) {
+      return false;
+    }
     if (operands().size() != other.operands().size()) {
       return false;
     }
+
     // Use an explicit loop rather than ContainerEquals, because copying around
     // std::functions may be too expensive in some cases.
     for (size_t i = 0; i < operands().size(); ++i) {
@@ -583,7 +592,7 @@ class HloInstruction {
       }
     }
 
-    return IdenticalSlowPath(other, eq_computations);
+    return IdenticalSlowPath(other, eq_computations, eq_shapes);
   }
 
   // Returns whether the instruction has a constant operand.
@@ -1232,10 +1241,14 @@ class HloInstruction {
   class FusionReusesParamElements;
 
   // See comments on Identical().
+  // eq_shapes() is used to check shapes for equality, and would normally be
+  // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on
+  // whether we want a layout-sensitive check or not.
   bool IdenticalSlowPath(
       const HloInstruction& other,
       const std::function<bool(const HloComputation*, const HloComputation*)>&
-          eq_computations) const;
+          eq_computations,
+      const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const;
 
   // Creates an n-ary elementwise operation.
   static std::unique_ptr<HloInstruction> CreateNary(