[mlir][linalg] Fix vectorization of linalg depthwise conv for int types
authorRob Suderman <suderman@google.com>
Tue, 8 Nov 2022 00:18:39 +0000 (16:18 -0800)
committerRob Suderman <suderman@google.com>
Wed, 9 Nov 2022 00:21:05 +0000 (16:21 -0800)
Vectorization of Linalg's depthwise convolution only supports floating
point types. Previous version assumed floating point operations would
work. This version checks whether the computation is integer or floating
point and adjust the inner loop computation.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D137595

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution.mlir

index cedec72b9cb334ecbdd5317453f2758db0d71b1f..2cf74a67df20e765609ca4a2aab596d36506be5e 100644 (file)
@@ -1746,11 +1746,17 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = depthwiseConv1dSliceAsFma(
+        resVals[w] = depthwiseConv1dSliceAsMulAcc(
             builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
       }
     }
 
+    // Its possible we failed to create the Fma
+    for (auto v : resVals) {
+      if (!v)
+        return failure();
+    }
+
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
@@ -1770,11 +1776,45 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
         .getOperation();
   }
 
-  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma.
-  Value depthwiseConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
-                                  Value rhs, Value res) {
-    Value bcast = builder.create<vector::BroadcastOp>(loc, res.getType(), rhs);
-    return b.create<vector::FMAOp>(loc, lhs, bcast, res);
+  // Take a value of element type T and widen to the destination type.
+  Value promote(OpBuilder &b, Location loc, Value val, Type ty) {
+    if (val.getType() == ty)
+      return val;
+
+    const int64_t srcWidth =
+        getElementTypeOrSelf(val.getType()).getIntOrFloatBitWidth();
+    const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth();
+
+    if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth < destWidth)
+      return builder.create<arith::ExtFOp>(loc, ty, val);
+
+    if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth < destWidth)
+      return builder.create<arith::ExtSIOp>(loc, ty, val);
+
+    return nullptr;
+  }
+
+  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
+  Value depthwiseConv1dSliceAsMulAcc(OpBuilder &b, Location loc, Value lhs,
+                                     Value rhs, Value res) {
+    auto rhsTy = rhs.getType().cast<ShapedType>();
+    auto resTy = res.getType().cast<ShapedType>();
+
+    // TODO(suderman): Change this to use a vector.ima intrinsic.
+    lhs = promote(b, loc, lhs, resTy);
+
+    rhs = builder.create<vector::BroadcastOp>(
+        loc, resTy.clone(rhsTy.getElementType()), rhs);
+    rhs = promote(b, loc, rhs, resTy);
+
+    if (!lhs || !rhs)
+      return nullptr;
+
+    if (resTy.getElementType().isa<FloatType>())
+      return b.create<vector::FMAOp>(loc, lhs, rhs, res);
+
+    auto mul = b.create<arith::MulIOp>(loc, lhs, rhs);
+    return b.create<arith::AddIOp>(loc, mul, res);
   }
 
   /// Entry point that transposes into the common form:
index 1374c996128a17e3a06b7b0d936814d9aba16d38..f1f00cf16d1b06267254d4503686eba9544b55e3 100644 (file)
@@ -463,7 +463,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 
 // -----
 
-func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
   linalg.depthwise_conv_1d_nwc_wc
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
     ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
@@ -471,7 +471,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt
   return
 }
 
-//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4_memref
+//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref
 //  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
 
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
@@ -502,6 +502,51 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt
 //      CHECK:   vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
 
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %filter: memref<2x4xi8>, %output: memref<3x2x4xi32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
+    outs(%output : memref<3x2x4xi32>)
+  return
+}
+
+//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref
+//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xi8>, %[[FILTER:[0-9a-z]+]]: memref<2x4xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xi32>)
+
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+
+/// Read the whole data in one shot.
+//      CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+//      CHECK:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xi8>
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xi8>
+
+/// w == 0, kw = 
+//      CHECK:  %[[EXT_INPUT_0:.*]] = arith.extsi %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+//      CHECK:  %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8>
+//      CHECK:  %[[EXT_FILTER_0:.*]] = arith.extsi %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+//      CHECK:  %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x2x4xi32>
+//      CHECK:  %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[V_OUTPUT_R]] : vector<3x2x4xi32>
+
+/// w == 0, kw = 1
+//      CHECK:  %[[EXT_INPUT_1:.*]] = arith.extsi %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+//      CHECK:  %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8>
+//      CHECK:  %[[EXT_FILTER_1:.*]] = arith.extsi %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+//      CHECK:  %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x2x4xi32>
+//      CHECK:  %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x2x4xi32>
+
+// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[ADD_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
 // -----
 
 func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) {