From 0a554e607ff6247b79d1c4f184999750e5ad53b9 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Wed, 20 May 2020 15:54:57 +0200 Subject: [PATCH] [mlir] Folding and canonicalization of shape.cstr_eq In the case of all inputs being constant and equal, cstr_eq will be replaced with a true_witness. Differential Revision: https://reviews.llvm.org/D80303 --- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 5 +- mlir/lib/Dialect/Shape/IR/Shape.cpp | 21 ++++++++ mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td | 10 ++++ mlir/test/Dialect/Shape/canonicalize.mlir | 56 ++++++++++++++++++++++ 4 files changed, 91 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index a05273f..6fb7cbf 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -555,7 +555,7 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> { let hasFolder = 1; } -def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> { +def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> { let summary = "Determines if all input shapes are equal"; let description = [{ Given 1 or more input shapes, determine if all shapes are the exact same. @@ -572,6 +572,9 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> { let results = (outs Shape_WitnessType:$result); let assemblyFormat = "$inputs attr-dict"; + + let hasCanonicalizer = 1; + let hasFolder = 1; } def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 3a8831c..e12e23b 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -291,6 +291,27 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef operands) { } //===----------------------------------------------------------------------===// +// CstrEqOp +//===----------------------------------------------------------------------===// + +void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, + MLIRContext *context) { + // If inputs are equal, return passing witness + patterns.insert(context); +} + +OpFoldResult CstrEqOp::fold(ArrayRef operands) { + if (llvm::all_of(operands, + [&](Attribute a) { return a && a == operands[0]; })) + return BoolAttr::get(true, getContext()); + + // Because a failing witness result here represents an eventual assertion + // failure, we do not try to replace it with a constant witness. Similarly, we + // cannot if there are any non-const inputs. + return nullptr; +} + +//===----------------------------------------------------------------------===// // ConstSizeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td index 9a73a88..78c9119 100644 --- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td +++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td @@ -2,7 +2,17 @@ include "mlir/Dialect/Shape/IR/ShapeOps.td" def EqualBinaryOperands : Constraint>; +def AllInputShapesEq : Constraint>; + // Canonicalization patterns. def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $lhs, $rhs), (Shape_ConstWitnessOp ConstBoolAttrTrue), [(EqualBinaryOperands $lhs, $rhs)]>; + +def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes), + (Shape_ConstWitnessOp ConstBoolAttrTrue), + [(AllInputShapesEq $shapes)]>; diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 93ce36a..32fa496 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -213,6 +213,62 @@ func @not_const(%arg0: !shape.shape) -> !shape.size { return %0 : !shape.size } + +// ----- +// cstr_eq with non-constant but known equal shapes can be removed. +// CHECK-LABEL: func @f +func @f(%arg0 : !shape.shape) { + // CHECK-NEXT: shape.const_witness true + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = shape.cstr_eq %arg0, %arg0, %arg0 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- +// cstr_eq with equal const_shapes can be folded +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: shape.const_witness true + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %cs0 = shape.const_shape [0, 1] + %cs1 = shape.const_shape [0, 1] + %cs2 = shape.const_shape [0, 1] + %0 = shape.cstr_eq %cs0, %cs1, %cs2 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- +// cstr_eq with unequal const_shapes cannot be folded +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: shape.const_shape + // CHECK-NEXT: shape.const_shape + // CHECK-NEXT: shape.cstr_eq + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %cs0 = shape.const_shape [0, 1] + %cs1 = shape.const_shape [3, 1] + %0 = shape.cstr_eq %cs0, %cs1 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- +// cstr_eq without const_shapes cannot be folded +// CHECK-LABEL: func @f +func @f(%arg0: !shape.shape, %arg1: !shape.shape) { + // CHECK-NEXT: shape.cstr_eq + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = shape.cstr_eq %arg0, %arg1 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + // ----- // assuming_all with known passing witnesses can be folded // CHECK-LABEL: func @f -- 2.7.4