[mlir] Propagate arith.index_cast past tensor.extract
authorRob Suderman <rob.suderman@gmail.com>
Wed, 26 Jan 2022 06:15:55 +0000 (22:15 -0800)
committerRob Suderman <rob.suderman@gmail.com>
Wed, 26 Jan 2022 06:16:07 +0000 (22:16 -0800)
If we are extracting it is more useful to push the index_cast past the
extraction. This increases the chance the tensor.extract can evaluated at
compile time.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D118204

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir

index 5ae13a6..bba7f97 100644 (file)
@@ -425,11 +425,51 @@ struct ExtractElementFromTensorFromElements
   }
 };
 
+// Pushes the index_casts that occur before extractions to after the extract.
+// This minimizes type conversion in some cases and enables the extract
+// canonicalizer. This changes:
+//
+// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
+// %extract = tensor.extract %cast[%index] : tensor<1xindex>
+//
+// to the following:
+//
+// %extract = tensor.extract %tensor[%index] : tensor<1xindex>
+// %cast = arith.index_cast %extract : i32 to index
+//
+// to just %element.
+//
+// Consider expanding this to a template and handle all tensor cast operations.
+struct ExtractElementFromIndexCast
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+                                PatternRewriter &rewriter) const final {
+    Location loc = extract.getLoc();
+    auto indexCast = extract.tensor().getDefiningOp<arith::IndexCastOp>();
+    if (!indexCast)
+      return failure();
+
+    Type elementTy = getElementTypeOrSelf(indexCast.getIn());
+
+    auto newExtract = rewriter.create<tensor::ExtractOp>(
+        loc, elementTy, indexCast.getIn(), extract.indices());
+
+    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
+                                                    newExtract);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add<ExtractElementFromTensorFromElements>(context);
+  results
+      .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
index e0ea5d7..3084a26 100644 (file)
@@ -1200,3 +1200,17 @@ func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
   %1 = tensor.expand_shape %0 [] : tensor<i32> into tensor<1xi32>
   return %1 : tensor<1xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @propogate_index_cast
+func @propogate_index_cast(%arg0: tensor<1xi32>) -> index {
+  // CHECK: %[[IDX:.+]] = arith.constant 0
+  // CHECK: %[[EXT:.+]] = tensor.extract %arg0[%[[IDX]]] : tensor<1xi32>
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[EXT]]
+  // CHECK: return %[[CAST]] : index
+  %c0 = arith.constant 0 : index
+  %0 = arith.index_cast %arg0 : tensor<1xi32> to tensor<1xindex>
+  %1 = tensor.extract %0[%c0] : tensor<1xindex>
+  return %1 : index
+}