[mlir] Folding and canonicalization of shape.cstr_eq
authorTres Popp <tpopp@google.com>
Wed, 20 May 2020 13:54:57 +0000 (15:54 +0200)
committerTres Popp <tpopp@google.com>
Fri, 5 Jun 2020 09:00:20 +0000 (11:00 +0200)
In the case of all inputs being constant and equal, cstr_eq will be
replaced with a true_witness.

Differential Revision: https://reviews.llvm.org/D80303

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
mlir/test/Dialect/Shape/canonicalize.mlir

index a05273f..6fb7cbf 100644 (file)
@@ -555,7 +555,7 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
   let hasFolder = 1;
 }
 
-def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
+def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
   let summary = "Determines if all input shapes are equal";
   let description = [{
     Given 1 or more input shapes, determine if all shapes are the exact same.
@@ -572,6 +572,9 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
   let results = (outs Shape_WitnessType:$result);
 
   let assemblyFormat = "$inputs attr-dict";
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect]> {
index 3a8831c..e12e23b 100644 (file)
@@ -291,6 +291,27 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
 }
 
 //===----------------------------------------------------------------------===//
+// CstrEqOp
+//===----------------------------------------------------------------------===//
+
+void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                           MLIRContext *context) {
+  // If inputs are equal, return passing witness
+  patterns.insert<CstrEqEqOps>(context);
+}
+
+OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
+  if (llvm::all_of(operands,
+                   [&](Attribute a) { return a && a == operands[0]; }))
+    return BoolAttr::get(true, getContext());
+
+  // Because a failing witness result here represents an eventual assertion
+  // failure, we do not try to replace it with a constant witness. Similarly, we
+  // cannot if there are any non-const inputs.
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
 // ConstSizeOp
 //===----------------------------------------------------------------------===//
 
index 9a73a88..78c9119 100644 (file)
@@ -2,7 +2,17 @@ include "mlir/Dialect/Shape/IR/ShapeOps.td"
 
 def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
 
+def AllInputShapesEq : Constraint<CPred< [{
+  llvm::all_of($0, [&](mlir::Value val) {
+    return $0[0] == val;
+  })
+}]>>;
+
 // Canonicalization patterns.
 def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $lhs, $rhs),
   (Shape_ConstWitnessOp ConstBoolAttrTrue),
   [(EqualBinaryOperands $lhs, $rhs)]>;
+
+def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes),
+  (Shape_ConstWitnessOp ConstBoolAttrTrue),
+  [(AllInputShapesEq $shapes)]>;
index 93ce36a..32fa496 100644 (file)
@@ -213,6 +213,62 @@ func @not_const(%arg0: !shape.shape) -> !shape.size {
   return %0 : !shape.size
 }
 
+
+// -----
+// cstr_eq with non-constant but known equal shapes can be removed.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.cstr_eq %arg0, %arg0, %arg0
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// cstr_eq with equal const_shapes can be folded
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [0, 1]
+  %cs1 = shape.const_shape [0, 1]
+  %cs2 = shape.const_shape [0, 1]
+  %0 = shape.cstr_eq %cs0, %cs1, %cs2
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// cstr_eq with unequal const_shapes cannot be folded
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_shape
+  // CHECK-NEXT: shape.const_shape
+  // CHECK-NEXT: shape.cstr_eq
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [0, 1]
+  %cs1 = shape.const_shape [3, 1]
+  %0 = shape.cstr_eq %cs0, %cs1
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// cstr_eq without const_shapes cannot be folded
+// CHECK-LABEL: func @f
+func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
+  // CHECK-NEXT: shape.cstr_eq
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.cstr_eq %arg0, %arg1
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
 // -----
 // assuming_all with known passing witnesses can be folded
 // CHECK-LABEL: func @f