let hasFolder = 1;
}
-def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
+def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
let summary = "Determines if all input shapes are equal";
let description = [{
Given 1 or more input shapes, determine if all shapes are the exact same.
let results = (outs Shape_WitnessType:$result);
let assemblyFormat = "$inputs attr-dict";
-
- let hasCanonicalizer = 1;
- let hasFolder = 1;
}
def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect]> {
}
//===----------------------------------------------------------------------===//
-// CstrEqOp
-//===----------------------------------------------------------------------===//
-
-void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
- MLIRContext *context) {
- // If inputs are equal, return passing witness
- patterns.insert<CstrEqEqOps>(context);
-}
-
-OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
- if (llvm::all_of(operands,
- [&](Attribute a) { return a && a == operands[0]; }))
- return BoolAttr::get(true, getContext());
-
- // Because a failing witness result here represents an eventual assertion
- // failure, we do not try to replace it with a constant witness. Similarly, we
- // cannot if there are any non-const inputs.
- return nullptr;
-}
-
-//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//
def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
-def AllInputShapesEq : Constraint<CPred< [{
- llvm::all_of($0, [&](mlir::Value val) {
- return $0[0] == val;
- })
-}]>>;
-
// Canonicalization patterns.
def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $lhs, $rhs),
(Shape_ConstWitnessOp ConstBoolAttrTrue),
[(EqualBinaryOperands $lhs, $rhs)]>;
-
-def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes),
- (Shape_ConstWitnessOp ConstBoolAttrTrue),
- [(AllInputShapesEq $shapes)]>;
return %0 : !shape.size
}
-
-// -----
-// cstr_eq with non-constant but known equal shapes can be removed.
-// CHECK-LABEL: func @f
-func @f(%arg0 : !shape.shape) {
- // CHECK-NEXT: shape.const_witness true
- // CHECK-NEXT: consume.witness
- // CHECK-NEXT: return
- %0 = shape.cstr_eq %arg0, %arg0, %arg0
- "consume.witness"(%0) : (!shape.witness) -> ()
- return
-}
-
-// -----
-// cstr_eq with equal const_shapes can be folded
-// CHECK-LABEL: func @f
-func @f() {
- // CHECK-NEXT: shape.const_witness true
- // CHECK-NEXT: consume.witness
- // CHECK-NEXT: return
- %cs0 = shape.const_shape [0, 1]
- %cs1 = shape.const_shape [0, 1]
- %cs2 = shape.const_shape [0, 1]
- %0 = shape.cstr_eq %cs0, %cs1, %cs2
- "consume.witness"(%0) : (!shape.witness) -> ()
- return
-}
-
-// -----
-// cstr_eq with unequal const_shapes cannot be folded
-// CHECK-LABEL: func @f
-func @f() {
- // CHECK-NEXT: shape.const_shape
- // CHECK-NEXT: shape.const_shape
- // CHECK-NEXT: shape.cstr_eq
- // CHECK-NEXT: consume.witness
- // CHECK-NEXT: return
- %cs0 = shape.const_shape [0, 1]
- %cs1 = shape.const_shape [3, 1]
- %0 = shape.cstr_eq %cs0, %cs1
- "consume.witness"(%0) : (!shape.witness) -> ()
- return
-}
-
-// -----
-// cstr_eq without const_shapes cannot be folded
-// CHECK-LABEL: func @f
-func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
- // CHECK-NEXT: shape.cstr_eq
- // CHECK-NEXT: consume.witness
- // CHECK-NEXT: return
- %0 = shape.cstr_eq %arg0, %arg1
- "consume.witness"(%0) : (!shape.witness) -> ()
- return
-}
-
// -----
// assuming_all with known passing witnesses can be folded
// CHECK-LABEL: func @f