[MLIR][Shape] Canonicalize subsequent `index_to_size` and `size_to_index`
authorFrederik Gossen <frgossen@google.com>
Fri, 19 Jun 2020 14:25:10 +0000 (14:25 +0000)
committerFrederik Gossen <frgossen@google.com>
Thu, 25 Jun 2020 12:02:49 +0000 (12:02 +0000)
Eliminate the subsequent applications of `index_to_size` and `size_to_index`.

Differential Revision: https://reviews.llvm.org/D82082

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
mlir/test/Dialect/Shape/canonicalize.mlir

index 2430fe6..21d76a3 100644 (file)
@@ -392,6 +392,7 @@ def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
   let assemblyFormat = "$arg attr-dict";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Shape_YieldOp : Shape_Op<"yield",
index 2d95218..b1f8043 100644 (file)
@@ -536,6 +536,11 @@ OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+void SizeToIndexOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<IndexToSizeToIndexCanonicalization>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
index 78c9119..43ea27f 100644 (file)
@@ -9,6 +9,7 @@ def AllInputShapesEq : Constraint<CPred< [{
 }]>>;
 
 // Canonicalization patterns.
+
 def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $lhs, $rhs),
   (Shape_ConstWitnessOp ConstBoolAttrTrue),
   [(EqualBinaryOperands $lhs, $rhs)]>;
@@ -16,3 +17,8 @@ def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $lhs, $rhs),
 def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes),
   (Shape_ConstWitnessOp ConstBoolAttrTrue),
   [(AllInputShapesEq $shapes)]>;
+
+def IndexToSizeToIndexCanonicalization : Pat<
+  (Shape_SizeToIndexOp (Shape_IndexToSizeOp $arg)),
+  (replaceWithValue $arg)>;
+
index 9fb48e6..1da1b70 100644 (file)
@@ -492,3 +492,14 @@ func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
 %rank = shape.rank %shape
 return %rank : !shape.size
 }
+
+// Canonicalize redundant conversion from `index` to `size` and back.
+// CHECK-LABEL: @index_to_size_to_index
+// CHECK-SAME: (%[[IDX:.*]]: index) -> index
+func @index_to_size_to_index(%index : index) -> index {
+  // CHECK: return %[[IDX]] : index
+  %size = shape.index_to_size %index
+  %result = shape.size_to_index %size
+  return %result : index
+}
+