From 787bf5e383a32b3ebc87332ff9e868db8f937056 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Fri, 2 Oct 2020 05:40:52 -0400
Subject: [PATCH] [mlir] Add canonicalization for the `subtensor` op

Differential revision: https://reviews.llvm.org/D88656
---
 mlir/include/mlir/Dialect/StandardOps/IR/Ops.td |  2 +-
 mlir/lib/Dialect/StandardOps/IR/Ops.cpp         | 60 ++++++++++++++++---------
 mlir/test/Transforms/canonicalize.mlir          | 29 ++++++++++++
 3 files changed, 68 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index dbc3e4c..3d9daee 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3164,7 +3164,7 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
       ArrayRef<int64_t> staticStrides);
   }];
 
-  // let hasCanonicalizer = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index d684a4b..5548274 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2989,50 +2989,59 @@ void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
   }
 }
 
+static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op,
+                             SubViewOp newOp) {
+  rewriter.replaceOpWithNewOp<MemRefCastOp>(op, newOp, op.getType());
+}
+
+static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op,
+                             SubTensorOp newOp) {
+  rewriter.replaceOpWithNewOp<TensorCastOp>(op, newOp, op.getType());
+}
+
 /// Pattern to rewrite a subview op with constant arguments.
-class SubViewOpConstantArgumentFolder final
-    : public OpRewritePattern<SubViewOp> {
+template <typename OpType>
+class OpWithOffsetSizesAndStridesConstantArgumentFolder final
+    : public OpRewritePattern<OpType> {
 public:
-  using OpRewritePattern<SubViewOp>::OpRewritePattern;
+  using OpRewritePattern<OpType>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(SubViewOp subViewOp,
+  LogicalResult matchAndRewrite(OpType op,
                                 PatternRewriter &rewriter) const override {
     // No constant operand, just return;
-    if (llvm::none_of(subViewOp.getOperands(), [](Value operand) {
+    if (llvm::none_of(op.getOperands(), [](Value operand) {
           return matchPattern(operand, m_ConstantIndex());
         }))
      return failure();
 
    // At least one of offsets/sizes/strides is a new constant.
    // Form the new list of operands and constant attributes from the existing.
-    SmallVector<Value, 8> newOffsets(subViewOp.offsets());
+    SmallVector<Value, 8> newOffsets(op.offsets());
     SmallVector<int64_t, 8> newStaticOffsets =
-        extractFromI64ArrayAttr(subViewOp.static_offsets());
-    assert(newStaticOffsets.size() == subViewOp.getSourceRank());
+        extractFromI64ArrayAttr(op.static_offsets());
+    assert(newStaticOffsets.size() == op.getSourceRank());
     canonicalizeSubViewPart(newOffsets, newStaticOffsets,
                             ShapedType::isDynamicStrideOrOffset);
 
-    SmallVector<Value, 8> newSizes(subViewOp.sizes());
+    SmallVector<Value, 8> newSizes(op.sizes());
     SmallVector<int64_t, 8> newStaticSizes =
-        extractFromI64ArrayAttr(subViewOp.static_sizes());
-    assert(newStaticOffsets.size() == subViewOp.getSourceRank());
+        extractFromI64ArrayAttr(op.static_sizes());
+    assert(newStaticOffsets.size() == op.getSourceRank());
     canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);
 
-    SmallVector<Value, 8> newStrides(subViewOp.strides());
+    SmallVector<Value, 8> newStrides(op.strides());
     SmallVector<int64_t, 8> newStaticStrides =
-        extractFromI64ArrayAttr(subViewOp.static_strides());
-    assert(newStaticOffsets.size() == subViewOp.getSourceRank());
+        extractFromI64ArrayAttr(op.static_strides());
+    assert(newStaticOffsets.size() == op.getSourceRank());
     canonicalizeSubViewPart(newStrides, newStaticStrides,
                             ShapedType::isDynamicStrideOrOffset);
 
     // Create the new op in canonical form.
-    auto newSubViewOp = rewriter.create<SubViewOp>(
-        subViewOp.getLoc(), subViewOp.source(), newStaticOffsets,
-        newStaticSizes, newStaticStrides, newOffsets, newSizes, newStrides);
+    auto newOp = rewriter.create<OpType>(
+        op.getLoc(), op.source(), newStaticOffsets, newStaticSizes,
+        newStaticStrides, newOffsets, newSizes, newStrides);
 
-    // Insert a memref_cast for compatibility of the uses of the op.
-    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
-                                              subViewOp.getType());
+    replaceWithNewOp(rewriter, op, newOp);
 
     return success();
   }
@@ -3183,8 +3192,8 @@ public:
 
 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  results.insert<SubViewOpConstantArgumentFolder, SubViewOpMemRefCastFolder>(
-      context);
+  results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubViewOp>,
+                 SubViewOpMemRefCastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -3275,6 +3284,13 @@ static LogicalResult verify(SubTensorOp op) {
   return success();
 }
 
+void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results
+      .insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<SubTensorOp>>(
+          context);
+}
+
 //===----------------------------------------------------------------------===//
 // TensorCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 3603c47..dc7be09 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1110,3 +1110,32 @@ func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
   // CHECK-NEXT: return %[[C2]]
   return %1 : tensor<8x4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @subtensor
+//  CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index
+func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
+  -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c7 = constant 7 : index
+  %c11 = constant 11 : index
+
+  // CHECK: subtensor %{{.*}}[0, 0, 0] [7, 11, 2] [1, 1, 1] :
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<7x11x2xf32>
+  // CHECK: tensor_cast %{{.*}} : tensor<7x11x2xf32> to tensor<?x?x?xf32>
+  %1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
+    : tensor<8x16x4xf32> to tensor<?x?x?xf32>
+
+  // Test: subtensor with one dynamic operand can also be folded.
+  // CHECK: subtensor %{{.*}}[0, 0, 0] [2, %[[ARG0]], 2] [1, 1, 1] :
+  // CHECK-SAME: tensor<?x?x?xf32> to tensor<2x?x2xf32>
+  // CHECK: tensor_cast %{{.*}} : tensor<2x?x2xf32> to tensor<?x?x?xf32>
+  %2 = subtensor %1[%c0, %c0, %c0] [%c2, %arg0, %c2] [%c1, %c1, %c1]
+    : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+
+  return %2 : tensor<?x?x?xf32>
+}
-- 
2.7.4
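
As a quick illustration of the new fold (a sketch based on the test added above, not part of the patch; SSA names in the "after" form are illustrative): running the canonicalizer, e.g. `mlir-opt -canonicalize`, over a `subtensor` whose offsets, sizes, and strides all come from `constant` ops folds those constants into the op's static attributes and inserts a `tensor_cast` back to the original dynamic result type:

  // Before: all offsets/sizes/strides are SSA constants.
  %1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
      : tensor<8x16x4xf32> to tensor<?x?x?xf32>

  // After canonicalization: static offsets/sizes/strides, then a cast
  // back to the dynamically shaped type expected by existing uses.
  %folded = subtensor %t[0, 0, 0] [7, 11, 2] [1, 1, 1]
      : tensor<8x16x4xf32> to tensor<7x11x2xf32>
  %1 = tensor_cast %folded : tensor<7x11x2xf32> to tensor<?x?x?xf32>

The trailing `tensor_cast` mirrors the `memref_cast` that the existing SubViewOp folder already inserts, which is why the templated pattern dispatches through the overloaded `replaceWithNewOp` helpers.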