From 807e5467f3e1b115f53377ea36ecad5625ce8280 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Mon, 15 Feb 2021 21:10:07 -0800 Subject: [PATCH] [mlir] Add canonicalization for tensor_cast + tensor_to_memref This helps bufferization passes by removing tensor_cast operations. Differential Revision: https://reviews.llvm.org/D96745 --- mlir/include/mlir/Dialect/StandardOps/IR/Ops.td | 1 + mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 31 +++++++++++++++++++++++++ mlir/test/Dialect/Standard/canonicalize.mlir | 12 ++++++++++ 3 files changed, 44 insertions(+) diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index dd760af..4e6ff2e 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -3078,6 +3078,7 @@ def TensorToMemrefOp : Std_Op<"tensor_to_memref", let assemblyFormat = "$tensor attr-dict `:` type($memref)"; let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 3ef48ce..4908291 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -3558,6 +3558,37 @@ OpFoldResult TensorToMemrefOp::fold(ArrayRef) { return {}; } +namespace { +/// Replace tensor_cast + tensor_to_memref by tensor_to_memref + memref_cast. +struct TensorCastToMemref : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef, + PatternRewriter &rewriter) const final { + auto tensorCastOperand = + tensorToMemRef.getOperand().getDefiningOp(); + if (!tensorCastOperand) + return failure(); + auto srcTensorType = + tensorCastOperand.getOperand().getType().dyn_cast(); + if (!srcTensorType) + return failure(); + auto memrefType = MemRefType::get(srcTensorType.getShape(), + srcTensorType.getElementType()); + Value memref = rewriter.create( + tensorToMemRef.getLoc(), memrefType, tensorCastOperand.getOperand()); + rewriter.replaceOpWithNewOp(tensorToMemRef, + tensorToMemRef.getType(), memref); + return success(); + } +}; +} // namespace + +void TensorToMemrefOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir index 8187c2f..7b54938 100644 --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -131,3 +131,15 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) { %2 = dim %0, %c1 : tensor return %1, %2: index, index } + +// CHECK-LABEL: func @tensor_cast_to_memref +// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8> +// CHECK: %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8> +// CHECK: %[[M1:.+]] = memref_cast %[[M]] : memref<4x6x16x32xi8> to memref +// CHECK: return %[[M1]] : memref +func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) -> + memref { + %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor + %1 = tensor_to_memref %0 : memref + return %1 : memref +} -- 2.7.4