[mlir][linalg] Add decomposition from conv_2d_nchw_fchw to conv_1d_ncw_fcw
author Stanley Winata <stanley@nod-labs.com>
Fri, 9 Sep 2022 23:00:33 +0000 (16:00 -0700)
committer Hanhan Wang <hanchung@google.com>
Fri, 9 Sep 2022 23:00:37 +0000 (16:00 -0700)
Decompose conv_2d_nchw_fchw -> conv_1d_ncw_fcw when one of the two spatial
window dimensions, together with the matching output dimension, has static
size 1, mirroring the existing conv_2d_nhwc_hwcf -> conv_1d_nwc_wcf
downscaling.
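
Schematically (slice offsets/sizes elided; shapes taken from the new test
case, attribute element type per getI64VectorAttr), the rewrite turns

  %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter : tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
    outs (%init : tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>

into tensor.extract_slice rank-reductions of all three operands feeding

  %1 = linalg.conv_1d_ncw_fcw {dilations = dense<1> : vector<1xi64>,
                               strides = dense<1> : vector<1xi64>}
     ins (%in, %ker : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
    outs (%out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>

whose result is re-inserted into the original %init via tensor.insert_slice.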

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D133551

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-op-decompose.mlir
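
The new test drives the pattern through transform.structured.decompose; a
minimal sketch of such a transform script (the sequence and matcher
boilerplate here is an assumption based on contemporary transform-dialect
syntax, not copied from the test):

  // Illustrative: match the named conv op, then decompose it to 1-D.
  transform.sequence failures(propagate) {
  ^bb0(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1
    %1 = transform.structured.decompose %0
  }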

index abd0243..e6a1b07 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -754,20 +754,19 @@ private:
 
 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
 /// convolution ops.
+template <typename Conv2DOp, typename Conv1DOp>
 struct DownscaleSizeOneWindowed2DConvolution final
-    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
+    : public OpRewritePattern<Conv2DOp> {
   DownscaleSizeOneWindowed2DConvolution(
       MLIRContext *context,
       LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
-      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
-        filter(std::move(f)) {}
+      : OpRewritePattern<Conv2DOp>(context, benefit), filter(std::move(f)) {}
 
-  FailureOr<Conv1DNwcWcfOp>
-  returningMatchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
-                           PatternRewriter &rewriter) const;
+  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+                                               PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+  LogicalResult matchAndRewrite(Conv2DOp convOp,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
index 80e4d31..f4241f4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -76,10 +76,18 @@ DiagnosedSilenceableFailure
 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
                                    SmallVectorImpl<Operation *> &results,
                                    transform::TransformState &state) {
-  FailureOr<LinalgOp> windowed =
-      tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
-  if (succeeded(windowed)) {
-    results.push_back(*windowed);
+  FailureOr<LinalgOp> windowedNhwc =
+      tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
+                                                     Conv1DNwcWcfOp>>(target);
+  if (succeeded(windowedNhwc)) {
+    results.push_back(*windowedNhwc);
+    return DiagnosedSilenceableFailure(success());
+  }
+  FailureOr<LinalgOp> windowedNchw =
+      tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
+                                                     Conv1DNcwFcwOp>>(target);
+  if (succeeded(windowedNchw)) {
+    results.push_back(*windowedNchw);
     return DiagnosedSilenceableFailure(success());
   }
   FailureOr<LinalgOp> depthwise =
index 2fcbe68..1c4ceaa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -828,9 +828,9 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
 // and then turning back to named ops. But for now it's fine to have a few
 // patterns matching special ops to get started.
 
-FailureOr<Conv1DNwcWcfOp>
-DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
-    linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
+template <typename Conv2DOp, typename Conv1DOp>
+FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
+    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
   if (failed(filter.checkAndNotify(rewriter, convOp)))
     return failure();
   if (convOp.hasBufferSemantics())
@@ -847,10 +847,30 @@ DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
   auto kernelShape = kernelType.getShape();
   auto outputShape = outputType.getShape();
 
+  // Get domain indices based on conv2D layout.
+  int khIndex, kwIndex, ohIndex, owIndex;
+
+  TypeSwitch<Operation *>(convOp)
+      .Case([&](linalg::Conv2DNhwcHwcfOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::Conv2DNchwFchwOp op) {
+        khIndex = 2;
+        kwIndex = 3;
+        ohIndex = 2;
+        owIndex = 3;
+      })
+      .Default([&](Operation *op) {
+        llvm_unreachable("unexpected conv2d operation.");
+      });
+
   // Only handle the case where at least one of the window dimensions is
   // of size 1. Other cases can rely on tiling to reduce to such cases.
-  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
-  int64_t ohSize = outputShape[1], owSize = outputShape[2];
+  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
+  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
   bool removeH = (khSize == 1 && ohSize == 1);
   bool removeW = (kwSize == 1 && owSize == 1);
   if (!removeH && !removeW)
@@ -860,11 +880,11 @@ DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
   // dimension.
   using RTTBuilder = RankedTensorType::Builder;
   RankedTensorType newInputType =
-      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
   RankedTensorType newKernelType =
-      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
   RankedTensorType newOutputType =
-      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
 
   // Rank-reduce operands.
   Location loc = convOp.getLoc();
@@ -877,16 +897,17 @@ DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
 
   // Rank-reduce strides and dilations too.
   // TODO: dropDim 1-liner helper.
-  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
+  auto strides =
+      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
   strides.erase(strides.begin() + (removeH ? 0 : 1));
   auto stridesAttr = rewriter.getI64VectorAttr(strides);
 
   auto dilations =
-      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
+      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
-  auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
+  auto conv1DOp = rewriter.create<Conv1DOp>(
       loc, newOutputType, ValueRange{newInput, newKernel},
       ValueRange{newOutput}, stridesAttr, dilationsAttr);
 
@@ -973,7 +994,10 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
 void linalg::populateDecomposeConvolutionPatterns(
     RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
     PatternBenefit benefit) {
-  patterns.add<DownscaleSizeOneWindowed2DConvolution,
+  patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
+                                                     Conv1DNwcWcfOp>,
+               DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
+                                                     Conv1DNcwFcwOp>,
                DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                   benefit);
 }
index 988d706..7348289 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -18,6 +18,24 @@ func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x
   return %0 : tensor<?x1x?x?xf32>
 }
 
+// CHECK-LABEL: @conv_2d_nchw_fchw
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_ncw_fcw
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
+                                 strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
+    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x?x1x?xf32>
+}
+
 // CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>