equivalent_instructions;
for (HloInstruction* user : operand->users()) {
if (user != instruction &&
- user->Identical(*instruction, eq_instructions, eq_computations) &&
- (!is_layout_sensitive_ ||
- ShapeUtil::Equal(user->shape(), instruction->shape()))) {
+ user->Identical(*instruction, eq_instructions, eq_computations,
+ is_layout_sensitive_)) {
equivalent_instructions.push_back(user);
}
}
bool HloInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const {
+ eq_computations,
+ const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const {
// Perform opcode specific checks.
switch (opcode()) {
// The result of these instructions only depend upon their opcode and
return parameter_number() == other.parameter_number() &&
// Check the shape too because `this` and `other` may be in
// different HloComputations.
- ShapeUtil::Compatible(shape(), other.shape());
+ eq_shapes(shape(), other.shape());
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
protobuf_util::ProtobufEquals(window(), other.window());
case HloOpcode::kReshape:
- return ShapeUtil::Compatible(shape(), other.shape());
+ return eq_shapes(shape(), other.shape());
// Transpose result is determined by the final shape and the permutation.
case HloOpcode::kTranspose:
- return ShapeUtil::Compatible(shape(), other.shape()) &&
+ return eq_shapes(shape(), other.shape()) &&
dimensions() == other.dimensions();
// Remaining instructions with special values.
case HloOpcode::kBitcast:
- return ShapeUtil::Equal(shape(), other.shape());
+ return eq_shapes(shape(), other.shape());
case HloOpcode::kBroadcast:
- return ShapeUtil::Compatible(shape(), other.shape()) &&
+ return eq_shapes(shape(), other.shape()) &&
dimensions() == other.dimensions();
case HloOpcode::kConcatenate:
return dimensions() == other.dimensions();
slice_limits_ == other.slice_limits_ &&
slice_strides_ == other.slice_strides_;
case HloOpcode::kDynamicSlice:
- return ShapeUtil::Compatible(shape(), other.shape()) &&
+ return eq_shapes(shape(), other.shape()) &&
dynamic_slice_sizes_ == other.dynamic_slice_sizes_;
case HloOpcode::kDynamicUpdateSlice:
- return ShapeUtil::Compatible(shape(), other.shape());
+ return eq_shapes(shape(), other.shape());
case HloOpcode::kCall:
case HloOpcode::kMap:
return eq_computations(to_apply(), other.to_apply());
}
// Returns true if "other" performs the same computation as this instruction.
- // Layout of the instructions' output array is not considered.
bool Identical(
const HloInstruction& other,
const std::function<bool(const HloInstruction*, const HloInstruction*)>&
eq_operands = std::equal_to<const HloInstruction*>(),
const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations = std::equal_to<const HloComputation*>()) const {
+ eq_computations = std::equal_to<const HloComputation*>(),
+ bool layout_sensitive = true) const {
// An instruction is always identical to itself.
if (this == &other) {
return true;
}
- // Identical instruction must have the same opcode and identical operands.
- // In general, there is no need to check shape because shape is inferred
- // from the shape of the operands.
+ // Identical instruction must have the same opcode, shape, and identical
+ // operands.
if (opcode() != other.opcode()) {
return false;
}
+ auto eq_shapes = layout_sensitive
+ ? [](const Shape& a,
+ const Shape& b) { return ShapeUtil::Equal(a, b); }
+ : [](const Shape& a, const Shape& b) {
+ return ShapeUtil::Compatible(a, b);
+ };
+ if (!eq_shapes(shape(), other.shape())) {
+ return false;
+ }
if (operands().size() != other.operands().size()) {
return false;
}
+
// Use an explicit loop rather than ContainerEquals, because copying around
// std::functions may be too expensive in some cases.
for (size_t i = 0; i < operands().size(); ++i) {
}
}
- return IdenticalSlowPath(other, eq_computations);
+ return IdenticalSlowPath(other, eq_computations, eq_shapes);
}
// Returns whether the instruction has a constant operand.
class FusionReusesParamElements;
// See comments on Identical().
+ // eq_shapes() is used to check shapes for equality, and would normally be
+ // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on
+ // whether we want a layout-sensitive check or not.
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const;
+ eq_computations,
+ const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const;
// Creates an n-ary elementwise operation.
static std::unique_ptr<HloInstruction> CreateNary(