[mlir] Add canonicalization for tensor_cast + tensor_to_memref
author    Thomas Raoux <thomasraoux@google.com>
          Tue, 16 Feb 2021 05:10:07 +0000 (21:10 -0800)
committer Thomas Raoux <thomasraoux@google.com>
          Tue, 16 Feb 2021 15:11:09 +0000 (07:11 -0800)
Canonicalize tensor_cast + tensor_to_memref into tensor_to_memref + memref_cast. This helps bufferization passes by removing tensor_cast operations.

Differential Revision: https://reviews.llvm.org/D96745
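
For illustration, a minimal sketch of the rewrite on hypothetical IR (the value
names and types below are chosen for the example, not taken from the patch):

  // Before canonicalization: the cast happens on the tensor side.
  %0 = tensor.cast %arg0 : tensor<4x8xf32> to tensor<?x?xf32>
  %1 = tensor_to_memref %0 : memref<?x?xf32>

  // After canonicalization: tensor_to_memref keeps the static source type and
  // the cast is performed on the resulting memref instead.
  %0 = tensor_to_memref %arg0 : memref<4x8xf32>
  %1 = memref_cast %0 : memref<4x8xf32> to memref<?x?xf32>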

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index dd760af..4e6ff2e 100644
@@ -3078,6 +3078,7 @@ def TensorToMemrefOp : Std_Op<"tensor_to_memref",
   let assemblyFormat = "$tensor attr-dict `:` type($memref)";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 3ef48ce..4908291 100644
@@ -3558,6 +3558,37 @@ OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) {
   return {};
 }
 
+namespace {
+/// Replace tensor_cast + tensor_to_memref by tensor_to_memref + memref_cast.
+struct TensorCastToMemref : public OpRewritePattern<TensorToMemrefOp> {
+  using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef,
+                                PatternRewriter &rewriter) const final {
+    auto tensorCastOperand =
+        tensorToMemRef.getOperand().getDefiningOp<tensor::CastOp>();
+    if (!tensorCastOperand)
+      return failure();
+    auto srcTensorType =
+        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
+    if (!srcTensorType)
+      return failure();
+    auto memrefType = MemRefType::get(srcTensorType.getShape(),
+                                      srcTensorType.getElementType());
+    Value memref = rewriter.create<TensorToMemrefOp>(
+        tensorToMemRef.getLoc(), memrefType, tensorCastOperand.getOperand());
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(tensorToMemRef,
+                                              tensorToMemRef.getType(), memref);
+    return success();
+  }
+};
+} // namespace
+
+void TensorToMemrefOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<TensorCastToMemref>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 8187c2f..7b54938 100644
@@ -131,3 +131,15 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
   %2 = dim %0, %c1 : tensor<?x?xf32>
   return %1, %2: index, index
 }
+
+// CHECK-LABEL: func @tensor_cast_to_memref
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+//       CHECK:   %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8>
+//       CHECK:   %[[M1:.+]] = memref_cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+//       CHECK:   return %[[M1]] : memref<?x?x16x32xi8>
+func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<?x?x16x32xi8> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = tensor_to_memref %0 : memref<?x?x16x32xi8>
+  return %1 : memref<?x?x16x32xi8>
+}
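
As a usage note, canonicalization patterns registered via getCanonicalizationPatterns
are exercised by mlir-opt's -canonicalize pass, so an assumed invocation matching the
test above would be

  mlir-opt -canonicalize canonicalize.mlir | FileCheck canonicalize.mlir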