From 1a151fdc011dc422d740a954c346827613961350 Mon Sep 17 00:00:00 2001 From: Murali Vijayaraghavan Date: Fri, 16 Dec 2022 04:51:01 +0000 Subject: [PATCH] [mlir][linalg] Downscale 2D pooling with unit dimensions for height to 1D pooling Differential Revision: https://reviews.llvm.org/D140187 --- .../Linalg/TransformOps/LinalgTransformOps.cpp | 45 ++++---- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 65 +++++++++-- .../Dialect/Linalg/transform-op-decompose.mlir | 126 +++++++++++++++++++++ 3 files changed, 204 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 5f1a4ec..347c530 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -67,26 +67,31 @@ DiagnosedSilenceableFailure transform::DecomposeOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { - FailureOr windowedNhwc = - tryApply>(target); - if (succeeded(windowedNhwc)) { - results.push_back(*windowedNhwc); - return DiagnosedSilenceableFailure::success(); - } - FailureOr windowedNchw = - tryApply>(target); - if (succeeded(windowedNchw)) { - results.push_back(*windowedNchw); - return DiagnosedSilenceableFailure::success(); - } - FailureOr depthwise = - tryApply(target); - if (succeeded(depthwise)) { - results.push_back(*depthwise); - return DiagnosedSilenceableFailure::success(); - } +#define DOWNSCALE(trans) \ + { \ + FailureOr res = tryApply(target); \ + if (succeeded(res)) { \ + results.push_back(*res); \ + return DiagnosedSilenceableFailure::success(); \ + } \ + } + +#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution +#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b)) + + DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp) + DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp) + DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp) + DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp) + DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp) + DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp) + DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp) + DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp) + DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp) + DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp) +#undef DOWNSCALE_NORMAL +#undef DOWNSCALE_CALL +#undef DOWNSCALE results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index b8c6115..77ea7af 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -613,23 +613,39 @@ FailureOr DownscaleSizeOneWindowed2DConvolution:: auto outputShape = outputType.getShape(); // Get domain indices based on conv2D layout. - int khIndex, kwIndex, ohIndex, owIndex; - - TypeSwitch(convOp) + auto [khIndex, kwIndex, ohIndex, owIndex] = + TypeSwitch>(convOp) .Case([&](linalg::Conv2DNhwcHwcfOp op) { - khIndex = 0; - kwIndex = 1; - ohIndex = 1; - owIndex = 2; + return std::make_tuple(0, 1, 1, 2); }) .Case([&](linalg::Conv2DNchwFchwOp op) { - khIndex = 2; - kwIndex = 3; - ohIndex = 2; - owIndex = 3; + return std::make_tuple(2, 3, 2, 3); + }) + .Case([&](linalg::PoolingNhwcSumOp op) { + return std::make_tuple(0, 1, 1, 2); + }) + .Case([&](linalg::PoolingNchwSumOp op) { + return std::make_tuple(0, 1, 2, 3); + }) + .Case([&](linalg::PoolingNhwcMaxOp op) { + return std::make_tuple(0, 1, 1, 2); + }) + .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) { + return std::make_tuple(0, 1, 1, 2); + }) + .Case([&](linalg::PoolingNhwcMinOp op) { + return std::make_tuple(0, 1, 1, 2); + }) + .Case([&](linalg::PoolingNhwcMinUnsignedOp op) { + return std::make_tuple(0, 1, 1, 2); + }) + .Case([&](linalg::PoolingNchwMaxOp op) { + return std::make_tuple(0, 1, 2, 3); }) .Default([&](Operation *op) { - llvm_unreachable("unexpected conv2d operation."); + llvm_unreachable("unexpected conv2d/pool2d operation."); + return std::make_tuple(0, 0, 0, 0); }); // Only handle the case where at least one of the window dimensions is @@ -688,6 +704,20 @@ template struct linalg::DownscaleSizeOneWindowed2DConvolution; template struct linalg::DownscaleSizeOneWindowed2DConvolution; +template struct linalg::DownscaleSizeOneWindowed2DConvolution; +template struct linalg::DownscaleSizeOneWindowed2DConvolution; +template struct linalg::DownscaleSizeOneWindowed2DConvolution; +template struct linalg::DownscaleSizeOneWindowed2DConvolution< + PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>; +template struct linalg::DownscaleSizeOneWindowed2DConvolution; +template struct linalg::DownscaleSizeOneWindowed2DConvolution< + PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>; +template struct linalg::DownscaleSizeOneWindowed2DConvolution; FailureOr DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( @@ -765,4 +795,15 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, Conv1DNcwFcwOp>, DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), benefit); + patterns.add< + DownscaleSizeOneWindowed2DConvolution, + DownscaleSizeOneWindowed2DConvolution, + DownscaleSizeOneWindowed2DConvolution, + DownscaleSizeOneWindowed2DConvolution, + DownscaleSizeOneWindowed2DConvolution, + DownscaleSizeOneWindowed2DConvolution, + DownscaleSizeOneWindowed2DConvolution>( + patterns.getContext(), benefit); } diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir index 81ee39d..2c873b2 100644 --- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir @@ -56,6 +56,132 @@ func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: t return %0: tensor<1x1x56x96xf32> } +// CHECK-LABEL: @pooling_nhwc_sum +// CHECK-SAME: %[[ARG0:.+]]: tensor, +// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32> +// CHECK-SAME: %[[ARG2:.+]]: tensor +func.func @pooling_nhwc_sum(%input: tensor, %filter: tensor<1x?xf32>, %init: tensor) -> tensor { + // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]] + // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]] + // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]] + // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_sum + // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]] + %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor<1x?xf32>) + outs (%init: tensor) -> tensor + // CHECK: return %[[RES]] + return %0 : tensor +} + +// CHECK-LABEL: @pooling_nchw_sum +// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor, +// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>, +// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor) +func.func @pooling_nchw_sum(%input: tensor, %filter: tensor<1x?xf32>, %init: tensor) -> tensor { + // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]] + // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]] + // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]] + // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_sum + // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]] + %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor<1x?xf32>) + outs (%init: tensor) -> tensor + // CHECK: return %[[RES]] + return %0 : tensor +} + +// CHECK-LABEL: @pooling_nhwc_max +// CHECK-SAME: %[[ARG0:.+]]: tensor, +// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32> +// CHECK-SAME: %[[ARG2:.+]]: tensor +func.func @pooling_nhwc_max(%input: tensor, %filter: tensor<1x?xf32>, %init: tensor) -> tensor { + // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]] + // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]] + // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]] + // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max + // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]] + %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor<1x?xf32>) + outs (%init: tensor) -> tensor + // CHECK: return %[[RES]] + return %0 : tensor +} + +// CHECK-LABEL: @pooling_nhwc_max_unsigned +// CHECK-SAME: %[[ARG0:.+]]: tensor, +// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32> +// CHECK-SAME: %[[ARG2:.+]]: tensor +func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tensor<1x?xf32>, %init: tensor) -> tensor { + // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]] + // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]] + // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]] + // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned + // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]] + %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor<1x?xf32>) + outs (%init: tensor) -> tensor + // CHECK: return %[[RES]] + return %0 : tensor +} + +// CHECK-LABEL: @pooling_nhwc_min +// CHECK-SAME: %[[ARG0:.+]]: tensor, +// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32> +// CHECK-SAME: %[[ARG2:.+]]: tensor +func.func @pooling_nhwc_min(%input: tensor, %filter: tensor<1x?xf32>, %init: tensor) -> tensor { + // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]] + // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]] + // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]] + // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min + // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]] + %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor<1x?xf32>) + outs (%init: tensor) -> tensor + // CHECK: return %[[RES]] + return %0 : tensor +} + +// CHECK-LABEL: @pooling_nhwc_min_unsigned +// CHECK-SAME: %[[ARG0:.+]]: tensor, +// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32> +// CHECK-SAME: %[[ARG2:.+]]: tensor +func.func @pooling_nhwc_min_unsigned(%input: tensor, %filter: tensor<1x?xf32>, %init: tensor) -> tensor { + // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]] + // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]] + // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]] + // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned + // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]] + %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor<1x?xf32>) + outs (%init: tensor) -> tensor + // CHECK: return %[[RES]] + return %0 : tensor +} + +// CHECK-LABEL: @pooling_nchw_max +// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor, +// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>, +// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor) +func.func @pooling_nchw_max(%input: tensor, %filter: tensor<1x?xf32>, %init: tensor) -> tensor { + // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]] + // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]] + // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]] + // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_max + // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]] + %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor<1x?xf32>) + outs (%init: tensor) -> tensor + // CHECK: return %[[RES]] + return %0 : tensor +} + transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match interface{LinalgOp} in %arg1 -- 2.7.4