From: Rob Suderman
Date: Tue, 8 Nov 2022 00:18:39 +0000 (-0800)
Subject: [mlir][linalg] Fix vectorization of linalg depthwise conv for int types
X-Git-Tag: upstream/17.0.6~28170
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9c923f4e58357af870925781f0b33d1ee958b9d8;p=platform%2Fupstream%2Fllvm.git

[mlir][linalg] Fix vectorization of linalg depthwise conv for int types

Vectorization of Linalg's depthwise convolution previously supported only
floating-point types, since the generated inner-loop computation assumed
floating-point operations. This version checks whether the computation is
integer or floating point and adjusts the inner-loop computation accordingly.
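
For example, an i8 depthwise convolution that accumulates into i32 now
lowers each inner-loop slice roughly as follows (illustrative IR; the value
names are invented here, and the shapes are taken from the new test below):

    %ext_in  = arith.extsi %in_slice : vector<3x2x4xi8> to vector<3x2x4xi32>
    %b_filt  = vector.broadcast %filter : vector<4xi8> to vector<3x2x4xi8>
    %ext_f   = arith.extsi %b_filt : vector<3x2x4xi8> to vector<3x2x4xi32>
    %mul     = arith.muli %ext_in, %ext_f : vector<3x2x4xi32>
    %acc     = arith.addi %mul, %out_slice : vector<3x2x4xi32>

The floating-point path is unchanged and still lowers to a single vector.fma.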

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D137595
---

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index cedec72b9cb3..2cf74a67df20 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1746,11 +1746,17 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = depthwiseConv1dSliceAsFma(
+        resVals[w] = depthwiseConv1dSliceAsMulAcc(
             builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
             resVals[w]);
       }
     }
+    // It's possible we failed to create the multiply-accumulate.
+    for (auto v : resVals) {
+      if (!v)
+        return failure();
+    }
+
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
@@ -1770,11 +1776,45 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
         .getOperation();
   }
 
-  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma.
-  Value depthwiseConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
-                                  Value rhs, Value res) {
-    Value bcast = builder.create<vector::BroadcastOp>(loc, res.getType(), rhs);
-    return b.create<vector::FMAOp>(loc, lhs, bcast, res);
+  // Take a value of element type T and widen to the destination type.
+  Value promote(OpBuilder &b, Location loc, Value val, Type ty) {
+    if (val.getType() == ty)
+      return val;
+
+    const int64_t srcWidth =
+        getElementTypeOrSelf(val.getType()).getIntOrFloatBitWidth();
+    const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth();
+
+    if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth < destWidth)
+      return builder.create<arith::ExtFOp>(loc, ty, val);
+
+    if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth < destWidth)
+      return builder.create<arith::ExtSIOp>(loc, ty, val);
+
+    return nullptr;
+  }
+
+  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc.
+  Value depthwiseConv1dSliceAsMulAcc(OpBuilder &b, Location loc, Value lhs,
+                                     Value rhs, Value res) {
+    auto rhsTy = rhs.getType().cast<ShapedType>();
+    auto resTy = res.getType().cast<ShapedType>();
+
+    // TODO(suderman): Change this to use a vector.fma intrinsic.
+    lhs = promote(b, loc, lhs, resTy);
+
+    rhs = builder.create<vector::BroadcastOp>(
+        loc, resTy.clone(rhsTy.getElementType()), rhs);
+    rhs = promote(b, loc, rhs, resTy);
+
+    if (!lhs || !rhs)
+      return nullptr;
+
+    if (resTy.getElementType().isa<FloatType>())
+      return b.create<vector::FMAOp>(loc, lhs, rhs, res);
+
+    auto mul = b.create<arith::MulIOp>(loc, lhs, rhs);
+    return b.create<arith::AddIOp>(loc, mul, res);
   }
 
   /// Entry point that transposes into the common form:
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 1374c996128a..f1f00cf16d1b 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -463,7 +463,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 
 // -----
 
-func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
   linalg.depthwise_conv_1d_nwc_wc
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
     ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
@@ -471,7 +471,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt
   return
 }
 
-// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4_memref
+// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref
 // CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
 
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
@@ -502,6 +502,51 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt
 
 // CHECK: vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %filter: memref<2x4xi8>, %output: memref<3x2x4xi32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
+    outs(%output : memref<3x2x4xi32>)
+  return
+}
+
+// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref
+// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xi8>, %[[FILTER:[0-9a-z]+]]: memref<2x4xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xi32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+
+/// Read the whole data in one shot.
+// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xi8>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xi8>
+
+/// w == 0, kw = 0
+// CHECK: %[[EXT_INPUT_0:.*]] = arith.extsi %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8>
+// CHECK: %[[EXT_FILTER_0:.*]] = arith.extsi %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+// CHECK: %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x2x4xi32>
+// CHECK: %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[V_OUTPUT_R]] : vector<3x2x4xi32>
+
+/// w == 0, kw = 1
+// CHECK: %[[EXT_INPUT_1:.*]] = arith.extsi %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8>
+// CHECK: %[[EXT_FILTER_1:.*]] = arith.extsi %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+// CHECK: %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x2x4xi32>
+// CHECK: %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x2x4xi32>
+
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[ADD_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
 // -----
 
 func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) {