From 8e2ff05d31118724eb21c48b98cd45c64884e13c Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Thu, 17 May 2018 15:08:33 -0700 Subject: [PATCH] [XLA] Remove eq_shapes from Identical SlowPath since it is already checked in Identical. PiperOrigin-RevId: 197058888 --- tensorflow/compiler/xla/service/hlo_instruction.cc | 47 +++++++--------------- tensorflow/compiler/xla/service/hlo_instruction.h | 14 ++----- 2 files changed, 19 insertions(+), 42 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index d2fbc83..66ff111 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1689,24 +1689,27 @@ bool HloInstruction::HasConstantOperand() const { bool HloInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const { + eq_computations) const { // Perform opcode specific checks. switch (opcode()) { // The result of these instructions only depend upon their opcode and // operands. case HloOpcode::kAbs: case HloOpcode::kAtan2: - case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: + case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: case HloOpcode::kComplex: + case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kExpm1: @@ -1730,6 +1733,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: + case HloOpcode::kReshape: + case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: @@ -1741,6 +1746,12 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kTuple: return true; + // Broadcast, Concatenate, and Transpose need the same dimensions field. + case HloOpcode::kBroadcast: + case HloOpcode::kConcatenate: + case HloOpcode::kTranspose: + return dimensions() == other.dimensions(); + case HloOpcode::kFusion: return fusion_kind() == other.fusion_kind() && eq_computations(fused_instructions_computation(), @@ -1753,10 +1764,7 @@ bool HloInstruction::IdenticalSlowPath( return false; case HloOpcode::kParameter: - return parameter_number() == other.parameter_number() && - // Check the shape too because `this` and `other` may be in - // different HloComputations. - eq_shapes(shape(), other.shape()); + return parameter_number() == other.parameter_number(); case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: @@ -1768,12 +1776,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kConstant: return literal() == other.literal(); - // A convert result is determined by the primitive type that the operand is - // converted into. - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - return shape().element_type() == other.shape().element_type(); - // A reduce-precision operation is determined by the bit sizes. case HloOpcode::kReducePrecision: return exponent_bits() == other.exponent_bits() && @@ -1816,22 +1818,8 @@ bool HloInstruction::IdenticalSlowPath( eq_computations(scatter(), other.scatter()) && protobuf_util::ProtobufEquals(window(), other.window()); - case HloOpcode::kReshape: - return eq_shapes(shape(), other.shape()); - - // Transpose result is determined by the final shape and the permutation. - case HloOpcode::kTranspose: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); // Remaining instructions with special values. - case HloOpcode::kBitcast: - return eq_shapes(shape(), other.shape()); - case HloOpcode::kBroadcast: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); - case HloOpcode::kConcatenate: - return dimensions() == other.dimensions(); case HloOpcode::kGetTupleElement: return tuple_index() == other.tuple_index(); case HloOpcode::kPad: @@ -1841,11 +1829,6 @@ bool HloInstruction::IdenticalSlowPath( return slice_starts_ == other.slice_starts_ && slice_limits_ == other.slice_limits_ && slice_strides_ == other.slice_strides_; - case HloOpcode::kDynamicSlice: - return eq_shapes(shape(), other.shape()) && - dynamic_slice_sizes_ == other.dynamic_slice_sizes_; - case HloOpcode::kDynamicUpdateSlice: - return eq_shapes(shape(), other.shape()); case HloOpcode::kCall: case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 0831a54..db78539 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -746,10 +746,8 @@ class HloInstruction { if (opcode() != other.opcode()) { return false; } - using EqShapeFuncType = bool (*)(const Shape&, const Shape&); - EqShapeFuncType eq_shapes = - layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible; - if (!eq_shapes(shape(), other.shape())) { + if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape()) + : ShapeUtil::Compatible(shape(), other.shape()))) { return false; } if (operands().size() != other.operands().size()) { @@ -764,7 +762,7 @@ class HloInstruction { } } - return IdenticalSlowPath(other, eq_computations, eq_shapes); + return IdenticalSlowPath(other, eq_computations); } // Returns whether the instruction has a constant operand. @@ -1497,14 +1495,10 @@ class HloInstruction { class FusionReusesParamElements; // See comments on Identical(). - // eq_shapes() is used to check shapes for equality, and would normally be - // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on - // whether we want a layout-sensitive check or not. bool IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const; + eq_computations) const; // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( -- 2.7.4