From 991945f4410af9df33f0889bf3c0695fd45a28b1 Mon Sep 17 00:00:00 2001
From: Devajith Valaparambil Sreeramaswamy
Date: Wed, 8 Mar 2023 14:30:16 -0800
Subject: [PATCH] [mlir][linalg] Downscale 2D convolution with unit dimensions
 to 1D convolution

Decompose conv_2d -> conv_1d. This patch follows an approach similar to
https://reviews.llvm.org/D112928.

It adds support for converting a conv_2d operation with either unit height
or unit width into a conv_1d operation. This is useful when a 2D convolution
has been tiled so that either the height or the width becomes a unit
dimension: the resulting conv_2d can then be decomposed into a conv_1d and
vectorized. https://reviews.llvm.org/D145160 adds vectorization support for
linalg.conv_1d, which allows us to vectorize linalg.conv_2d after proper
tiling and decomposition.

The missing feature was reported at
https://discourse.llvm.org/t/vectorization-of-convolution-op/60458.
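For illustration, a rough sketch of the rewrite on a small static example
(the shapes below are invented for this note; the test added in this patch
exercises the dynamic-shape case):

  // Before: kernel and output both have unit height, so dimension 0 can be
  // dropped.
  %0 = linalg.conv_2d
         ins(%input, %filter : tensor<1x10xf32>, tensor<1x3xf32>)
         outs(%init : tensor<1x8xf32>) -> tensor<1x8xf32>

  // After (schematically): rank-reduce the operands, convolve in 1D, and
  // insert the result back into the original output tensor.
  %in  = tensor.extract_slice %input[0, 0] [1, 10] [1, 1]
           : tensor<1x10xf32> to tensor<10xf32>
  %ker = tensor.extract_slice %filter[0, 0] [1, 3] [1, 1]
           : tensor<1x3xf32> to tensor<3xf32>
  %out = tensor.extract_slice %init[0, 0] [1, 8] [1, 1]
           : tensor<1x8xf32> to tensor<8xf32>
  %1 = linalg.conv_1d
         ins(%in, %ker : tensor<10xf32>, tensor<3xf32>)
         outs(%out : tensor<8xf32>) -> tensor<8xf32>
  %2 = tensor.insert_slice %1 into %init[0, 0] [1, 8] [1, 1]
         : tensor<8xf32> into tensor<1x8xf32>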
Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D145162
---
 .../mlir/Dialect/Linalg/Transforms/Transforms.h    | 13 +++++
 .../Linalg/TransformOps/LinalgTransformOps.cpp     |  1 +
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp  | 61 +++++++++++++++++++++-
 .../Dialect/Linalg/transform-op-decompose.mlir     | 17 ++++++
 4 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4dd7641..eaf9fec 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1041,6 +1041,19 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
   }
 };
 
+struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> {
+  DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<Conv2DOp>(context, benefit) {}
+
+  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+                                               PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(Conv2DOp convOp,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(convOp, rewriter);
+  }
+};
+
 ///
 /// Linalg generalization pattern.
 ///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 600cdde..3e6e1df 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -266,6 +266,7 @@ transform::DecomposeOp::applyToOne(LinalgOp target,
   DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
   DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
   DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
+  DOWNSCALE(DownscaleConv2DOp)
 #undef DOWNSCALE_NORMAL
 #undef DOWNSCALE_CALL
 #undef DOWNSCALE
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 01f2c17..9de0f76 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1361,14 +1361,71 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
   return conv1DOp;
 }
 
+FailureOr<Conv1DOp>
+DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
+                                            PatternRewriter &rewriter) const {
+  if (convOp.hasBufferSemantics())
+    return failure(); // To be implemented.
+
+  Value input = convOp.getInputs().front();
+  Value kernel = convOp.getInputs().back();
+  Value output = convOp.getOutputs().front();
+
+  auto inputType = input.getType().dyn_cast<RankedTensorType>();
+  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+  auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+  auto kernelShape = kernelType.getShape();
+  auto outputShape = outputType.getShape();
+
+  // Only handle the case where at least one of the window dimensions is
+  // of size 1. Other cases can rely on tiling to reduce to such cases.
+  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+  int64_t ohSize = outputShape[0], owSize = outputShape[1];
+  bool removeH = (khSize == 1 && ohSize == 1);
+  bool removeW = (kwSize == 1 && owSize == 1);
+  if (!removeH && !removeW)
+    return failure();
+
+  // Get new shapes and types for all operands by removing the size-1
+  // dimension.
+  using RTTBuilder = RankedTensorType::Builder;
+  RankedTensorType newInputType =
+      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
+  RankedTensorType newKernelType =
+      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+  RankedTensorType newOutputType =
+      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
+
+  // Rank-reduce operands.
+  Location loc = convOp.getLoc();
+  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, input, newInputType);
+  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, kernel, newKernelType);
+  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, output, newOutputType);
+
+  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
+                                            ValueRange{newInput, newKernel},
+                                            ValueRange{newOutput});
+
+  // Insert back.
+  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, conv1DOp.getResult(0), output);
+  rewriter.replaceOp(convOp, inserted);
+
+  return conv1DOp;
+}
+
 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit) {
   patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
                                                      Conv1DNwcWcfOp>,
                DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
                                                      Conv1DNcwFcwOp>,
-               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
-                                                  benefit);
+               DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
+      patterns.getContext(), benefit);
   patterns.add<
       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
       DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index e023e64..82795ec 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -56,6 +56,23 @@ func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: t
   return %0: tensor<1x1x56x96xf32>
 }
 
+// CHECK-LABEL: @conv_2d
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<1x?xf32>)
+func.func @conv_2d(%input: tensor<1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.conv_2d
+         ins (%input, %filter: tensor<1x?xf32>, tensor<1x?xf32>)
+         outs (%init: tensor<1x?xf32>) -> tensor<1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<1x?xf32>
+}
+
 // CHECK-LABEL: @pooling_nhwc_sum
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-- 
2.7.4
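A complementary sketch of the unit-width branch of the same pattern (the
removeW case in the code above; shapes invented for illustration), which the
added test does not cover:

  // Kernel and output both have unit width, so dimension 1 is dropped.
  %0 = linalg.conv_2d
         ins(%input, %filter : tensor<12x1xf32>, tensor<3x1xf32>)
         outs(%init : tensor<10x1xf32>) -> tensor<10x1xf32>
  // Decomposes into a conv_1d over the height dimension:
  //   linalg.conv_1d ins(%in, %ker : tensor<12xf32>, tensor<3xf32>)
  //                  outs(%out : tensor<10xf32>) -> tensor<10xf32>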