[mlir][linalg] Downscale 2D pooling with unit dimensions for height to 1D pooling
author: Murali Vijayaraghavan <muralivi@google.com>
Fri, 16 Dec 2022 04:51:01 +0000 (04:51 +0000)
committer: Murali Vijayaraghavan <muralivi@google.com>
Mon, 19 Dec 2022 22:34:43 +0000 (22:34 +0000)
Differential Revision: https://reviews.llvm.org/D140187

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-op-decompose.mlir

index 5f1a4ec..347c530 100644 (file)
@@ -67,26 +67,31 @@ DiagnosedSilenceableFailure
 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
                                    SmallVectorImpl<Operation *> &results,
                                    transform::TransformState &state) {
-  FailureOr<LinalgOp> windowedNhwc =
-      tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
-                                                     Conv1DNwcWcfOp>>(target);
-  if (succeeded(windowedNhwc)) {
-    results.push_back(*windowedNhwc);
-    return DiagnosedSilenceableFailure::success();
-  }
-  FailureOr<LinalgOp> windowedNchw =
-      tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
-                                                     Conv1DNcwFcwOp>>(target);
-  if (succeeded(windowedNchw)) {
-    results.push_back(*windowedNchw);
-    return DiagnosedSilenceableFailure::success();
-  }
-  FailureOr<LinalgOp> depthwise =
-      tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
-  if (succeeded(depthwise)) {
-    results.push_back(*depthwise);
-    return DiagnosedSilenceableFailure::success();
-  }
+#define DOWNSCALE(trans) \
+    { \
+      FailureOr<LinalgOp> res = tryApply<trans>(target); \
+      if (succeeded(res)) { \
+        results.push_back(*res); \
+        return DiagnosedSilenceableFailure::success(); \
+      } \
+    }
+
+#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
+#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
+
+  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
+  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
+  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
+  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
+  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
+  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
+  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
+  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
+  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
+  DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
+#undef DOWNSCALE_NORMAL
+#undef DOWNSCALE_CALL
+#undef DOWNSCALE
   results.assign(1, nullptr);
   return emitDefaultSilenceableFailure(target);
 }
index b8c6115..77ea7af 100644 (file)
@@ -613,23 +613,39 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
   auto outputShape = outputType.getShape();
 
   // Get domain indices based on conv2D layout.
-  int khIndex, kwIndex, ohIndex, owIndex;
-
-  TypeSwitch<Operation *>(convOp)
+  auto [khIndex, kwIndex, ohIndex, owIndex] =
+      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t,
+                                         int64_t>>(convOp)
       .Case([&](linalg::Conv2DNhwcHwcfOp op) {
-        khIndex = 0;
-        kwIndex = 1;
-        ohIndex = 1;
-        owIndex = 2;
+        return std::make_tuple(0, 1, 1, 2);
       })
       .Case([&](linalg::Conv2DNchwFchwOp op) {
-        khIndex = 2;
-        kwIndex = 3;
-        ohIndex = 2;
-        owIndex = 3;
+        return std::make_tuple(2, 3, 2, 3);
+      })
+      .Case([&](linalg::PoolingNhwcSumOp op) {
+        return std::make_tuple(0, 1, 1, 2);
+      })
+      .Case([&](linalg::PoolingNchwSumOp op) {
+        return std::make_tuple(0, 1, 2, 3);
+      })
+      .Case([&](linalg::PoolingNhwcMaxOp op) {
+        return std::make_tuple(0, 1, 1, 2);
+      })
+      .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
+        return std::make_tuple(0, 1, 1, 2);
+      })
+      .Case([&](linalg::PoolingNhwcMinOp op) {
+        return std::make_tuple(0, 1, 1, 2);
+      })
+      .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
+        return std::make_tuple(0, 1, 1, 2);
+      })
+      .Case([&](linalg::PoolingNchwMaxOp op) {
+        return std::make_tuple(0, 1, 2, 3);
       })
       .Default([&](Operation *op) {
-        llvm_unreachable("unexpected conv2d operation.");
+        llvm_unreachable("unexpected conv2d/pool2d operation.");
+        return std::make_tuple(0, 0, 0, 0);
       });
 
   // Only handle the case where at least one of the window dimensions is
@@ -688,6 +704,20 @@ template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
                                                               Conv1DNwcWcfOp>;
 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
                                                               Conv1DNcwFcwOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
+                                                              PoolingNwcSumOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
+                                                              PoolingNcwSumOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
+                                                              PoolingNwcMaxOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<
+    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
+                                                              PoolingNwcMinOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<
+    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
+                                                              PoolingNcwMaxOp>;
 
 FailureOr<DepthwiseConv1DNwcWcOp>
 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
@@ -765,4 +795,15 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                      Conv1DNcwFcwOp>,
                DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
                                                   benefit);
+  patterns.add<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
+                                            PoolingNwcMaxUnsignedOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
+                                            PoolingNwcMinUnsignedOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
+      patterns.getContext(), benefit);
 }
index 81ee39d..2c873b2 100644 (file)
@@ -56,6 +56,132 @@ func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: t
   return %0: tensor<1x1x56x96xf32>
 }
 
+// CHECK-LABEL: @pooling_nhwc_sum
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_sum(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_sum
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nchw_sum
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_sum
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x?x1x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_max
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_max_unsigned
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_min
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_min_unsigned
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nchw_max
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_max
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x?x1x?xf32>
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match interface{LinalgOp} in %arg1