[XLA] Remove eq_shapes from Identical SlowPath since it is already checked in
authorBlake Hechtman <blakehechtman@google.com>
Thu, 17 May 2018 22:08:33 +0000 (15:08 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 22:11:21 +0000 (15:11 -0700)
Identical.

PiperOrigin-RevId: 197058888

tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h

index d2fbc83..66ff111 100644 (file)
@@ -1689,24 +1689,27 @@ bool HloInstruction::HasConstantOperand() const {
 bool HloInstruction::IdenticalSlowPath(
     const HloInstruction& other,
     const std::function<bool(const HloComputation*, const HloComputation*)>&
-        eq_computations,
-    const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const {
+        eq_computations) const {
   // Perform opcode specific checks.
   switch (opcode()) {
     // The result of these instructions only depend upon their opcode and
     // operands.
     case HloOpcode::kAbs:
     case HloOpcode::kAtan2:
-    case HloOpcode::kRoundNearestAfz:
     case HloOpcode::kAdd:
+    case HloOpcode::kBitcast:
+    case HloOpcode::kBitcastConvert:
     case HloOpcode::kCeil:
     case HloOpcode::kClamp:
     case HloOpcode::kClz:
     case HloOpcode::kComplex:
+    case HloOpcode::kConvert:
     case HloOpcode::kCopy:
     case HloOpcode::kCos:
     case HloOpcode::kCrossReplicaSum:
     case HloOpcode::kDivide:
+    case HloOpcode::kDynamicSlice:
+    case HloOpcode::kDynamicUpdateSlice:
     case HloOpcode::kEq:
     case HloOpcode::kExp:
     case HloOpcode::kExpm1:
@@ -1730,6 +1733,8 @@ bool HloInstruction::IdenticalSlowPath(
     case HloOpcode::kPower:
     case HloOpcode::kReal:
     case HloOpcode::kRemainder:
+    case HloOpcode::kReshape:
+    case HloOpcode::kRoundNearestAfz:
     case HloOpcode::kSelect:
     case HloOpcode::kShiftLeft:
     case HloOpcode::kShiftRightArithmetic:
@@ -1741,6 +1746,12 @@ bool HloInstruction::IdenticalSlowPath(
     case HloOpcode::kTuple:
       return true;
 
+    // Broadcast, Concatenate, and Transpose need the same dimensions field.
+    case HloOpcode::kBroadcast:
+    case HloOpcode::kConcatenate:
+    case HloOpcode::kTranspose:
+      return dimensions() == other.dimensions();
+
     case HloOpcode::kFusion:
       return fusion_kind() == other.fusion_kind() &&
              eq_computations(fused_instructions_computation(),
@@ -1753,10 +1764,7 @@ bool HloInstruction::IdenticalSlowPath(
       return false;
 
     case HloOpcode::kParameter:
-      return parameter_number() == other.parameter_number() &&
-             // Check the shape too because `this` and `other` may be in
-             // different HloComputations.
-             eq_shapes(shape(), other.shape());
+      return parameter_number() == other.parameter_number();
 
     case HloOpcode::kBatchNormTraining:
     case HloOpcode::kBatchNormInference:
@@ -1768,12 +1776,6 @@ bool HloInstruction::IdenticalSlowPath(
     case HloOpcode::kConstant:
       return literal() == other.literal();
 
-    // A convert result is determined by the primitive type that the operand is
-    // converted into.
-    case HloOpcode::kConvert:
-    case HloOpcode::kBitcastConvert:
-      return shape().element_type() == other.shape().element_type();
-
     // A reduce-precision operation is determined by the bit sizes.
     case HloOpcode::kReducePrecision:
       return exponent_bits() == other.exponent_bits() &&
@@ -1816,22 +1818,8 @@ bool HloInstruction::IdenticalSlowPath(
              eq_computations(scatter(), other.scatter()) &&
              protobuf_util::ProtobufEquals(window(), other.window());
 
-    case HloOpcode::kReshape:
-      return eq_shapes(shape(), other.shape());
-
-    // Transpose result is determined by the final shape and the permutation.
-    case HloOpcode::kTranspose:
-      return eq_shapes(shape(), other.shape()) &&
-             dimensions() == other.dimensions();
 
     // Remaining instructions with special values.
-    case HloOpcode::kBitcast:
-      return eq_shapes(shape(), other.shape());
-    case HloOpcode::kBroadcast:
-      return eq_shapes(shape(), other.shape()) &&
-             dimensions() == other.dimensions();
-    case HloOpcode::kConcatenate:
-      return dimensions() == other.dimensions();
     case HloOpcode::kGetTupleElement:
       return tuple_index() == other.tuple_index();
     case HloOpcode::kPad:
@@ -1841,11 +1829,6 @@ bool HloInstruction::IdenticalSlowPath(
       return slice_starts_ == other.slice_starts_ &&
              slice_limits_ == other.slice_limits_ &&
              slice_strides_ == other.slice_strides_;
-    case HloOpcode::kDynamicSlice:
-      return eq_shapes(shape(), other.shape()) &&
-             dynamic_slice_sizes_ == other.dynamic_slice_sizes_;
-    case HloOpcode::kDynamicUpdateSlice:
-      return eq_shapes(shape(), other.shape());
     case HloOpcode::kCall:
     case HloOpcode::kMap:
       return eq_computations(to_apply(), other.to_apply());
index 0831a54..db78539 100644 (file)
@@ -746,10 +746,8 @@ class HloInstruction {
     if (opcode() != other.opcode()) {
       return false;
     }
-    using EqShapeFuncType = bool (*)(const Shape&, const Shape&);
-    EqShapeFuncType eq_shapes =
-        layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible;
-    if (!eq_shapes(shape(), other.shape())) {
+    if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
+                           : ShapeUtil::Compatible(shape(), other.shape()))) {
       return false;
     }
     if (operands().size() != other.operands().size()) {
@@ -764,7 +762,7 @@ class HloInstruction {
       }
     }
 
-    return IdenticalSlowPath(other, eq_computations, eq_shapes);
+    return IdenticalSlowPath(other, eq_computations);
   }
 
   // Returns whether the instruction has a constant operand.
@@ -1497,14 +1495,10 @@ class HloInstruction {
   class FusionReusesParamElements;
 
   // See comments on Identical().
-  // eq_shapes() is used to check shapes for equality, and would normally be
-  // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on
-  // whether we want a layout-sensitive check or not.
   bool IdenticalSlowPath(
       const HloInstruction& other,
       const std::function<bool(const HloComputation*, const HloComputation*)>&
-          eq_computations,
-      const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const;
+          eq_computations) const;
 
   // Creates an n-ary elementwise operation.
   static std::unique_ptr<HloInstruction> CreateNary(