[mlir][Linalg] Refactor conv vectorization to decouple memory from vector ops.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 29 Oct 2021 10:17:24 +0000 (10:17 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 3 Nov 2021 08:03:40 +0000 (08:03 +0000)
This refactoring prepares conv1d vectorization for a future integration into
the generic codegen path.
Once transfer_read / transfer_write vectorization also supports sliding windows,
the special pattern for conv can disappear.
This will also likely need a vector.conv operation.

Differential Revision: https://reviews.llvm.org/D112797

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution.mlir

index b7520f1..0678563 100644 (file)
@@ -1396,8 +1396,7 @@ namespace {
 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
 /// ```
-/// w and kw are unrolled.
-/// TODO: do not unroll w (resp. kw) when the strideW ( resp. dilationW) is > 1.
+/// kw is unrolled, w is unrolled iff dilationW > 1.
 struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
   Conv1D_NWC_WCF_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
                            int dilationW)
@@ -1455,52 +1454,58 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
     vector::TransferWriteOp write;
     Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
 
+    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
+    // When strideW == 1, we can batch the contiguous loads and avoid unrolling
     int64_t wSizeStep = strideW == 1 ? wSize : 1;
 
+    VectorType lhsType = VectorType::get({nSize, wSizeStep, cSize},
+                                         lhsShapedType.getElementType());
+    VectorType rhsType =
+        VectorType::get({cSize, fSize}, rhsShapedType.getElementType());
+    VectorType resType = VectorType::get({nSize, wSizeStep, fSize},
+                                         resShapedType.getElementType());
+
+    SmallVector<Value> lhsVals, rhsVals, resVals;
     // Unroll along kw and read slices of lhs and rhs.
     // Alternatively we could preload both 3-d slices and extract smaller slices
     // iteratively without touching memory. But this will quickly spill.
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       // Read rhs slice of size {c, f} @ [kw, 0, 0].
       Value kwVal = builder.create<arith::ConstantIndexOp>(loc, kw);
-      VectorType rhsType =
-          VectorType::get({cSize, fSize}, rhsShapedType.getElementType());
-      Value rhs = builder.create<vector::TransferReadOp>(
-          loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero});
+      rhsVals.push_back(builder.create<vector::TransferReadOp>(
+          loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero}));
 
       for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
         // Read lhs slice of size {n, wSizeStep, c}
         //   @ [0, sw * w + dw * kw, 0].
         Value lhsStridedIdx = builder.create<arith::ConstantIndexOp>(
             loc, strideW * w_iv + dilationW * kw);
-        VectorType lhsType = VectorType::get({nSize, wSizeStep, cSize},
-                                             lhsShapedType.getElementType());
-        Value lhs = builder.create<vector::TransferReadOp>(
-            loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero});
+        lhsVals.push_back(builder.create<vector::TransferReadOp>(
+            loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero}));
 
         // Read res slice: {n, wSizeStep, f} @ [0, w, 0].
         Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv);
-        VectorType resType = VectorType::get({nSize, wSizeStep, fSize},
-                                             resShapedType.getElementType());
         // When operating on tensors, reading from the updated value is required
         // for vector.transfer_read/write hoisting to function as expected.
-        Value res = builder.create<vector::TransferReadOp>(
-            loc, resType, resShaped, ValueRange{zero, wVal, zero});
-
+        resVals.push_back(builder.create<vector::TransferReadOp>(
+            loc, resType, resShaped, ValueRange{zero, wVal, zero}));
+      }
+    }
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
         // Compute contraction: I{n, w, c} * F{c, f} -> O{n, w, f}
-        StringRef par = Par().strRef, red = Red().strRef;
-        AffineExpr n, w, f, c;
-        bindDims(ctx, n, w, f, c);
-        // clang-format off
-        res = builder.create<vector::ContractionOp>(
-          loc, lhs, rhs, res,
-          /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
-          /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
-        // clang-format on
-
+        resVals[kw * (wSize / wSizeStep) + w_iv] = conv1dSliceAsContraction(
+            builder, loc, lhsVals[kw * (wSize / wSizeStep) + w_iv], rhsVals[kw],
+            resVals[kw * (wSize / wSizeStep) + w_iv]);
+      }
+    }
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) {
+        Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv);
         // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
         write = builder.create<vector::TransferWriteOp>(
-            loc, res, resShaped, ValueRange{zero, wVal, zero});
+            loc, resVals[kw * (wSize / wSizeStep) + w_iv], resShaped,
+            ValueRange{zero, wVal, zero});
         if (write.getNumResults() == 1)
           resShaped = write->getResult(0);
       }
@@ -1509,6 +1514,19 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
     return write.getOperation();
   }
 
+  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
+  vector::ContractionOp conv1dSliceAsContraction(OpBuilder &b, Location loc,
+                                                 Value lhs, Value rhs,
+                                                 Value res) {
+    StringRef par = Par().strRef, red = Red().strRef;
+    AffineExpr n, w, f, c;
+    bindDims(ctx, n, w, f, c);
+    return builder.create<vector::ContractionOp>(
+        loc, lhs, rhs, res,
+        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
+        /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
+  }
+
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
   FailureOr<Operation *> generateConv() {
index b1802fd..31be54e 100644 (file)
@@ -24,21 +24,26 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf
 //      CHECK:   %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
 //      CHECK:   %[[V_INPUT0:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
 //      CHECK:   %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+/// w == 1, kw == 0
+//      CHECK:   %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
 //      CHECK:   %[[CONTRACT0:.+]] = vector.contract {
+
+/// w == 0, kw == 0
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-//      CHECK:   vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-
 /// w == 1, kw == 0
-//      CHECK:   %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
 //      CHECK:   %[[CONTRACT1:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT3]], %[[V_FILTER]], %[[V_OUTPUT_1]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+
+/// w == 0, kw == 0
+//      CHECK:   vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+/// w == 1, kw == 0
 //      CHECK:   vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
 
 // -----
@@ -69,48 +74,53 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
 //      CHECK:   %[[V_FILTER_A:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
 //      CHECK:   %[[V_INPUT0_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
 //      CHECK:   %[[V_OUTPUT_0_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+/// w == 0, kw == 1
+//      CHECK:   %[[V_INPUT3_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_OUTPUT_1_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+/// w == 1, kw == 0
+//      CHECK:   %[[V_FILTER_B:.+]]   = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_INPUT0_B:.+]]   = vector.transfer_read  %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_OUTPUT_0_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+/// w == 1, kw == 1
+//      CHECK:     %[[V_INPUT3_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C5]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_OUTPUT_1_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+
+/// w == 0, kw == 0
 //      CHECK:   %[[CONTRACT0_A:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT0_A]], %[[V_FILTER_A]], %[[V_OUTPUT_0_A]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-//      CHECK:   vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-
 /// w == 0, kw == 1
-//      CHECK:   %[[V_INPUT3_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_OUTPUT_1_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
 //      CHECK:   %[[CONTRACT1_A:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT3_A]], %[[V_FILTER_A]], %[[V_OUTPUT_1_A]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-//      CHECK:   vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
-
 /// w == 1, kw == 0
-//      CHECK:   %[[V_FILTER_B:.+]]   = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_INPUT0_B:.+]]   = vector.transfer_read  %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_OUTPUT_0_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
 //      CHECK:   %[[CONTRACT0_B:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT0_B]], %[[V_FILTER_B]], %[[V_OUTPUT_0_B]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
-//      CHECK:   vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-
 /// w == 1, kw == 1
-//      CHECK:     %[[V_INPUT3_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C5]], %[[C0]]], %[[F0]]
-//      CHECK:   %[[V_OUTPUT_1_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
 //      CHECK:   %[[CONTRACT1_B:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT3_B]], %[[V_FILTER_B]], %[[V_OUTPUT_1_B]]
 // CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+
+/// w == 0, kw == 0
+//      CHECK:   vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+/// w == 0, kw == 1
+//      CHECK:   vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+/// w == 1, kw == 0
+//      CHECK:   vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+/// w == 1, kw == 1
 //      CHECK:   vector.transfer_write %[[CONTRACT1_B]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
 
 // -----
 
-
-
 // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 // CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -127,22 +137,27 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
 //      CHECK:   %[[V_FILTER_000:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
 //      CHECK:   %[[V_INPUT_000:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
 //      CHECK:   %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
+/// w == 0, kw == 1
+//      CHECK:   %[[V_FILTER_100:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
+//      CHECK:   %[[V_INPUT_020:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
+//      CHECK:   %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
+
+/// w == 0, kw == 0
 //      CHECK:   %[[CONTRACT0:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT_000]], %[[V_FILTER_000]], %[[V_OUTPUT_0]]
 // CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
-//      CHECK:   vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-
 /// w == 0, kw == 1
-//      CHECK:   %[[V_FILTER_100:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32>
-//      CHECK:   %[[V_INPUT_020:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32>
-//      CHECK:   %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32>
 //      CHECK:   %[[CONTRACT1:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     %[[V_INPUT_020]], %[[V_FILTER_100]], %[[V_OUTPUT_1]]
 // CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
+
+/// w == 0, kw == 0
+//      CHECK:   vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+/// w == 0, kw == 1
 //      CHECK:   vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
   linalg.conv_1d_nwc_wcf
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}