[mlir][linalg] Fix vectorization of linalg depthwise conv for int types
authorRob Suderman <suderman@google.com>
Tue, 8 Nov 2022 00:18:39 +0000 (16:18 -0800)
committerRob Suderman <suderman@google.com>
Wed, 9 Nov 2022 00:21:05 +0000 (16:21 -0800)
Vectorization of Linalg's depthwise convolution only supports floating
point types. Previous version assumed floating point operations would
work. This version checks whether the computation is integer or floating
point and adjust the inner loop computation.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D137595

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution.mlir

index cedec72..2cf74a6 100644 (file)
@@ -1746,11 +1746,17 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = depthwiseConv1dSliceAsFma(
+        resVals[w] = depthwiseConv1dSliceAsMulAcc(
             builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
       }
     }
 
+    // Its possible we failed to create the Fma
+    for (auto v : resVals) {
+      if (!v)
+        return failure();
+    }
+
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
@@ -1770,11 +1776,45 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
         .getOperation();
   }
 
-  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma.
-  Value depthwiseConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
-                                  Value rhs, Value res) {
-    Value bcast = builder.create<vector::BroadcastOp>(loc, res.getType(), rhs);
-    return b.create<vector::FMAOp>(loc, lhs, bcast, res);
+  // Take a value of element type T and widen to the destination type.
+  Value promote(OpBuilder &b, Location loc, Value val, Type ty) {
+    if (val.getType() == ty)
+      return val;
+
+    const int64_t srcWidth =
+        getElementTypeOrSelf(val.getType()).getIntOrFloatBitWidth();
+    const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth();
+
+    if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth < destWidth)
+      return builder.create<arith::ExtFOp>(loc, ty, val);
+
+    if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth < destWidth)
+      return builder.create<arith::ExtSIOp>(loc, ty, val);
+
+    return nullptr;
+  }
+
+  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
+  Value depthwiseConv1dSliceAsMulAcc(OpBuilder &b, Location loc, Value lhs,
+                                     Value rhs, Value res) {
+    auto rhsTy = rhs.getType().cast<ShapedType>();
+    auto resTy = res.getType().cast<ShapedType>();
+
+    // TODO(suderman): Change this to use a vector.ima intrinsic.
+    lhs = promote(b, loc, lhs, resTy);
+
+    rhs = builder.create<vector::BroadcastOp>(
+        loc, resTy.clone(rhsTy.getElementType()), rhs);
+    rhs = promote(b, loc, rhs, resTy);
+
+    if (!lhs || !rhs)
+      return nullptr;
+
+    if (resTy.getElementType().isa<FloatType>())
+      return b.create<vector::FMAOp>(loc, lhs, rhs, res);
+
+    auto mul = b.create<arith::MulIOp>(loc, lhs, rhs);
+    return b.create<arith::AddIOp>(loc, mul, res);
   }
 
   /// Entry point that transposes into the common form:
index 1374c99..f1f00cf 100644 (file)
@@ -463,7 +463,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 
 // -----
 
-func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
   linalg.depthwise_conv_1d_nwc_wc
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
     ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
@@ -471,7 +471,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt
   return
 }
 
-//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4_memref
+//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref
 //  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
 
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
@@ -504,6 +504,51 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt
 
 // -----
 
+func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %filter: memref<2x4xi8>, %output: memref<3x2x4xi32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
+    outs(%output : memref<3x2x4xi32>)
+  return
+}
+
+//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref
+//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xi8>, %[[FILTER:[0-9a-z]+]]: memref<2x4xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xi32>)
+
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+
+/// Read the whole data in one shot.
+//      CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+//      CHECK:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xi8>
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xi8>
+
+/// w == 0, kw = 
+//      CHECK:  %[[EXT_INPUT_0:.*]] = arith.extsi %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+//      CHECK:  %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8>
+//      CHECK:  %[[EXT_FILTER_0:.*]] = arith.extsi %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+//      CHECK:  %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x2x4xi32>
+//      CHECK:  %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[V_OUTPUT_R]] : vector<3x2x4xi32>
+
+/// w == 0, kw = 1
+//      CHECK:  %[[EXT_INPUT_1:.*]] = arith.extsi %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+//      CHECK:  %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8>
+//      CHECK:  %[[EXT_FILTER_1:.*]] = arith.extsi %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+//      CHECK:  %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x2x4xi32>
+//      CHECK:  %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x2x4xi32>
+
+// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[ADD_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+// -----
+
 func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) {
   linalg.conv_1d_nwc_wcf
   {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}