From: Jakub Kuderski
Date: Fri, 6 Jan 2023 01:11:46 +0000 (-0500)
Subject: [mlir][spirv] Add smul_extended expansion for WebGPU
X-Git-Tag: upstream/17.0.6~21971
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1b8224537070337c2d983a204a08eb27bac1ded6;p=platform%2Fupstream%2Fllvm.git

[mlir][spirv] Add smul_extended expansion for WebGPU

We need this because WGSL does not support extended multiplication ops.

Fixes: https://github.com/llvm/llvm-project/issues/59563

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D141096
---

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 5f8426b..1ed71db 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -15,7 +15,9 @@
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -45,90 +47,126 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
   return SplatElementsAttr::get(type, sizedValue);
 }
 
+Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
+                                  Value lhs, Value rhs,
+                                  bool signExtendArguments) {
+  Location loc = mulOp->getLoc();
+  Type argTy = lhs.getType();
+  // Emulate 64-bit multiplication by splitting each input element of type i32
+  // into 2 16-bit digits of type i32. This is so that the intermediate
+  // multiplications and additions do not overflow. We extract these 16-bit
+  // digits from i32 vector elements by masking (low digit) and shifting right
+  // (high digit).
+  //
+  // The multiplication algorithm used is the standard (long) multiplication.
+  // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
+  // digits.
+  // - With zero-extended arguments, we end up emitting only 4 multiplications
+  //   and 4 additions after constant folding.
+  // - With sign-extended arguments, we end up emitting 8 multiplications and
+  //   12 additions after CSE.
+  Value cstLowMask = rewriter.create<ConstantOp>(
+      loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
+  auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
+    return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
+  };
+
+  Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+                                            getScalarOrSplatAttr(argTy, 16));
+  auto getHighDigit = [&rewriter, loc, cst16](Value val) {
+    return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
+  };
+
+  auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
+    // We only need to shift arithmetically by 15, but the extra
+    // sign-extension bit will be truncated by the logical shift, so this is
+    // fine. We do not have to introduce an extra constant since any
+    // value in [15, 32) would do.
+    return getHighDigit(
+        rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
+  };
+
+  Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+                                           getScalarOrSplatAttr(argTy, 0));
+
+  Value lhsLow = getLowDigit(lhs);
+  Value lhsHigh = getHighDigit(lhs);
+  Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
+  Value rhsLow = getLowDigit(rhs);
+  Value rhsHigh = getHighDigit(rhs);
+  Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
+
+  std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
+  std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
+  std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
+
+  for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
+    for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
+      if (i + j >= resultDigits.size())
+        continue;
+
+      if (lhsDigit == cst0 || rhsDigit == cst0)
+        continue;
+
+      Value &thisResDigit = resultDigits[i + j];
+      Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
+      Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
+      thisResDigit = getLowDigit(current);
+
+      if (i + j + 1 != resultDigits.size()) {
+        Value &nextResDigit = resultDigits[i + j + 1];
+        Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
+                                                    getHighDigit(current));
+        nextResDigit = carry;
+      }
+    }
+  }
+
+  auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
+    Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
+    return rewriter.create<BitwiseOrOp>(loc, low, highBits);
+  };
+  Value low = combineDigits(resultDigits[0], resultDigits[1]);
+  Value high = combineDigits(resultDigits[2], resultDigits[3]);
+
+  return rewriter.create<CompositeConstructOp>(
+      loc, mulOp->getResultTypes().front(), llvm::makeArrayRef({low, high}));
+}
+
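The following standalone sketch mirrors the digit-based long multiplication above on host integers and checks it against a native 64-bit multiply. It is an editor's illustration only, not part of the patch, and every name in it is invented:

#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

// Computes the {low, high} words of the 64-bit product of two 32-bit values
// using only 32-bit arithmetic, mirroring lowerExtendedMultiplication.
std::array<uint32_t, 2> mulExtended(uint32_t lhs, uint32_t rhs,
                                    bool signExtend) {
  auto lowDigit = [](uint32_t v) { return v & 0xFFFFu; };
  auto highDigit = [](uint32_t v) { return v >> 16; };
  // Arithmetic shift replicates the sign bit, so the following logical shift
  // leaves a digit that is all sign bits: 0x0000 or 0xFFFF.
  auto signDigit = [&](uint32_t v) {
    return highDigit(static_cast<uint32_t>(static_cast<int32_t>(v) >> 16));
  };

  uint32_t lhsExt = signExtend ? signDigit(lhs) : 0;
  uint32_t rhsExt = signExtend ? signDigit(rhs) : 0;
  std::array<uint32_t, 4> a = {lowDigit(lhs), highDigit(lhs), lhsExt, lhsExt};
  std::array<uint32_t, 4> b = {lowDigit(rhs), highDigit(rhs), rhsExt, rhsExt};
  std::array<uint32_t, 4> res = {0, 0, 0, 0};

  // Standard long multiplication over base-2^16 digits. The pattern also
  // skips multiplications by the zero digit; that only avoids emitting dead
  // ops and does not change the result.
  for (size_t i = 0; i != a.size(); ++i) {
    for (size_t j = 0; j != b.size(); ++j) {
      if (i + j >= res.size())
        continue; // Digits past the 64-bit result are discarded.
      uint32_t current = res[i + j] + a[i] * b[j];
      res[i + j] = lowDigit(current);
      if (i + j + 1 != res.size())
        res[i + j + 1] += highDigit(current); // Propagate the carry.
    }
  }
  return {res[0] | (res[1] << 16), res[2] | (res[3] << 16)};
}

int main() {
  // Signed: compare against a native 64-bit multiply.
  int32_t slhs = -65535, srhs = -1337;
  auto [lo, hi] = mulExtended(static_cast<uint32_t>(slhs),
                              static_cast<uint32_t>(srhs),
                              /*signExtend=*/true);
  int64_t sref = static_cast<int64_t>(slhs) * srhs;
  assert(lo == static_cast<uint32_t>(sref));
  assert(hi == static_cast<uint32_t>(sref >> 32));

  // Unsigned: reinterpret the same bit patterns.
  auto [ulo, uhi] = mulExtended(0xFFFF0001u, 0xFFFFu, /*signExtend=*/false);
  uint64_t uref = static_cast<uint64_t>(0xFFFF0001u) * 0xFFFFu;
  assert(ulo == static_cast<uint32_t>(uref));
  assert(uhi == static_cast<uint32_t>(uref >> 32));
  std::puts("ok");
}

The intermediate sums never overflow i32: the worst case is a digit product (at most 0xFFFF * 0xFFFF = 0xFFFE0001) plus a reduced result digit and one pending carry (at most 0xFFFF each), which comes to exactly 0xFFFFFFFF. This is why the expansion can get away with plain i32 adds without losing carries.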
 //===----------------------------------------------------------------------===//
 // Rewrite Patterns
 //===----------------------------------------------------------------------===//
 
-struct ExpandUMulExtendedPattern final : OpRewritePattern<UMulExtendedOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(UMulExtendedOp op,
+template <typename MulExtendedOp, bool SignExtendArguments>
+struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
+  using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulExtendedOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     Value lhs = op.getOperand1();
     Value rhs = op.getOperand2();
-    Type argTy = lhs.getType();
 
     // Currently, WGSL only supports 32-bit integer types. Any other integer
     // types should already have been promoted/demoted to i32.
-    auto elemTy = getElementTypeOrSelf(argTy).cast<IntegerType>();
+    auto elemTy = getElementTypeOrSelf(lhs.getType()).cast<IntegerType>();
     if (elemTy.getIntOrFloatBitWidth() != 32)
       return rewriter.notifyMatchFailure(
           loc,
           llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
 
-    // Emulate 64-bit multiplication by splitting each input element of type i32
-    // into 2 16-bit digits of type i32. This is so that the intermediate
-    // multiplications and additions do not overflow. We extract these 16-bit
-    // digits from i32 vector elements by masking (low digit) and shifting right
-    // (high digit).
-    //
-    // The multiplication algorithm used is the standard (long) multiplication.
-    // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
-    // digits. After constant-folding, we end up emitting only 4 multiplications
-    // and 4 additions.
-    Value cstLowMask = rewriter.create<ConstantOp>(
-        loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
-    auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
-      return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
-    };
-
-    Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
-                                              getScalarOrSplatAttr(argTy, 16));
-    auto getHighDigit = [&rewriter, loc, cst16](Value val) {
-      return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
-    };
-
-    Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
-                                             getScalarOrSplatAttr(argTy, 0));
-
-    Value lhsLow = getLowDigit(lhs);
-    Value lhsHigh = getHighDigit(lhs);
-    Value rhsLow = getLowDigit(rhs);
-    Value rhsHigh = getHighDigit(rhs);
-
-    std::array<Value, 2> lhsDigits = {lhsLow, lhsHigh};
-    std::array<Value, 2> rhsDigits = {rhsLow, rhsHigh};
-    std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
-
-    for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
-      for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
-        Value &thisResDigit = resultDigits[i + j];
-        Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
-        Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
-        thisResDigit = getLowDigit(current);
-
-        if (i + j + 1 != resultDigits.size()) {
-          Value &nextResDigit = resultDigits[i + j + 1];
-          Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
-                                                      getHighDigit(current));
-          nextResDigit = carry;
-        }
-      }
-    }
-
-    auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
-      Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
-      return rewriter.create<BitwiseOrOp>(loc, low, highBits);
-    };
-    Value low = combineDigits(resultDigits[0], resultDigits[1]);
-    Value high = combineDigits(resultDigits[2], resultDigits[3]);
-
-    rewriter.replaceOpWithNewOp<CompositeConstructOp>(
-        op, op.getType(), llvm::makeArrayRef({low, high}));
+    Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
+                                            SignExtendArguments);
+    rewriter.replaceOp(op, mul);
     return success();
   }
 };
 
+using ExpandSMulExtendedPattern =
+    ExpandMulExtendedPattern<SMulExtendedOp, true>;
+using ExpandUMulExtendedPattern =
+    ExpandMulExtendedPattern<UMulExtendedOp, false>;
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
@@ -153,9 +191,8 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
     RewritePatternSet &patterns) {
   // WGSL currently does not support extended multiplication ops, see:
   // https://github.com/gpuweb/gpuweb/issues/1565.
-  // TODO(https://github.com/llvm/llvm-project/issues/59563): Add SMulExtended
-  // expansion.
-  patterns.add<ExpandUMulExtendedPattern>(patterns.getContext());
+  patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern>(
+      patterns.getContext());
 }
 } // namespace spirv
 } // namespace mlir
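The comment in getSignDigit above claims that any arithmetic shift amount in [15, 32) would do. The reason: after an arithmetic shift right by 15 or more, bits 31..16 are all copies of the sign bit, so the subsequent logical shift right by 16 always yields 0x0000 or 0xFFFF. Below is a quick self-contained check of that claim; an editor's sketch, not part of the patch (C++20 defines right-shifting a negative value as arithmetic, and mainstream compilers have long implemented it that way):

#include <cassert>
#include <cstdint>

// Sign digit via "arithmetic shift by `shift`, then logical shift by 16",
// matching the structure of getSignDigit.
uint32_t signDigit(int32_t v, unsigned shift) {
  return static_cast<uint32_t>(v >> shift) >> 16;
}

int main() {
  const int32_t samples[] = {0, 1, -1, 65535, -65535, INT32_MAX, INT32_MIN};
  for (int32_t v : samples)
    for (unsigned s = 15; s < 32; ++s)
      assert(signDigit(v, s) == (v < 0 ? 0xFFFFu : 0u));
}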
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index d0720a3..91eeeda 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -70,4 +70,79 @@ spirv.func @umul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
   spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @smul_extended_i32
+// CHECK-SAME:  ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32)
+// CHECK-DAG:   [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
+// CHECK-DAG:   [[CST16:%.+]] = spirv.Constant 16 : i32
+// CHECK-NEXT:  [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : i32
+// CHECK-NEXT:  [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : i32
+// CHECK-NEXT:  [[LHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG0]], [[CST16]] : i32
+// CHECK-NEXT:  [[LHSEXT:%.+]] = spirv.ShiftRightLogical [[LHSSIGN]], [[CST16]] : i32
+// CHECK-NEXT:  [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : i32
+// CHECK-NEXT:  [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : i32
+// CHECK-NEXT:  [[RHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG1]], [[CST16]] : i32
+// CHECK-NEXT:  [[RHSEXT:%.+]] = spirv.ShiftRightLogical [[RHSSIGN]], [[CST16]] : i32
+// CHECK-DAG:   spirv.IMul [[LHSLOW]], [[RHSLOW]]
+// CHECK-DAG:   spirv.IMul [[LHSLOW]], [[RHSHI]]
+// CHECK-DAG:   spirv.IMul [[LHSLOW]], [[RHSEXT]]
+// CHECK-DAG:   spirv.IMul [[LHSHI]], [[RHSLOW]]
+// CHECK-DAG:   spirv.IMul [[LHSHI]], [[RHSHI]]
+// CHECK-DAG:   spirv.IMul [[LHSHI]], [[RHSEXT]]
+// CHECK-DAG:   spirv.IMul [[LHSEXT]], [[RHSLOW]]
+// CHECK-DAG:   spirv.IMul [[LHSEXT]], [[RHSHI]]
+// CHECK:       spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
+// CHECK:       spirv.BitwiseOr
+// CHECK:       spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
+// CHECK:       spirv.BitwiseOr
+// CHECK:       [[RES:%.+]] = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (i32, i32) -> !spirv.struct<(i32, i32)>
+// CHECK-NEXT:  spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
+spirv.func @smul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i32)> "None" {
+  %0 = spirv.SMulExtended %arg0, %arg1 : !spirv.struct<(i32, i32)>
+  spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: func @smul_extended_vector_i32
+// CHECK-SAME:  ([[ARG0:%.+]]: vector<3xi32>, [[ARG1:%.+]]: vector<3xi32>)
+// CHECK-DAG:   [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
+// CHECK-DAG:   [[CST16:%.+]] = spirv.Constant dense<16> : vector<3xi32>
+// CHECK-NEXT:  [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT:  [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:  [[LHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG0]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:  [[LHSEXT:%.+]] = spirv.ShiftRightLogical [[LHSSIGN]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:  [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT:  [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:  [[RHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG1]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:  [[RHSEXT:%.+]] = spirv.ShiftRightLogical [[RHSSIGN]], [[CST16]] : vector<3xi32>
+// CHECK-DAG:   spirv.IMul [[LHSLOW]], [[RHSLOW]]
+// CHECK-DAG:   spirv.IMul [[LHSLOW]], [[RHSHI]]
+// CHECK-DAG:   spirv.IMul [[LHSLOW]], [[RHSEXT]]
+// CHECK-DAG:   spirv.IMul [[LHSHI]], [[RHSLOW]]
+// CHECK-DAG:   spirv.IMul [[LHSHI]], [[RHSHI]]
+// CHECK-DAG:   spirv.IMul [[LHSHI]], [[RHSEXT]]
+// CHECK-DAG:   spirv.IMul [[LHSEXT]], [[RHSLOW]]
+// CHECK-DAG:   spirv.IMul [[LHSEXT]], [[RHSHI]]
+// CHECK:       spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
+// CHECK:       spirv.BitwiseOr
+// CHECK:       spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
+// CHECK:       spirv.BitwiseOr
+// CHECK-NEXT:  [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW:%.+]], [[RESHI:%.+]]
+// CHECK-NEXT:  spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+spirv.func @smul_extended_vector_i32(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>)
+    -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
+  %0 = spirv.SMulExtended %arg0, %arg1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  spirv.ReturnValue %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// CHECK-LABEL: func @smul_extended_i16
+// CHECK-NEXT:  spirv.SMulExtended
+// CHECK-NEXT:  spirv.ReturnValue
+spirv.func @smul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
+  %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(i16, i16)>
+  spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
+}
+
 } // end module
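The CHECK-DAG lines above pin down the op counts promised in the code comments: eight spirv.IMul ops for the signed expansion versus four for the unsigned one. This sketch (an editor's illustration, not part of the patch) recounts them the same way the pattern does, skipping digit pairs past the 64-bit result, dropping multiplies by the zero digit, and merging the two identical sign digits as CSE would:

#include <cstdio>
#include <set>
#include <string>
#include <utility>

int countMuls(bool signExtend) {
  // Digit names stand in for SSA values; equal pairs CSE into one multiply.
  const char *lhs[4] = {"lhsLow", "lhsHigh", "lhsExt", "lhsExt"};
  const char *rhs[4] = {"rhsLow", "rhsHigh", "rhsExt", "rhsExt"};
  std::set<std::pair<std::string, std::string>> muls;
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) {
      if (i + j >= 4)
        continue; // Digits past the 64-bit result are discarded.
      if (!signExtend && (i >= 2 || j >= 2))
        continue; // Zero-extended digits constant-fold the multiply away.
      muls.insert({lhs[i], rhs[j]});
    }
  }
  return static_cast<int>(muls.size());
}

int main() {
  // Prints "unsigned: 4, signed: 8".
  std::printf("unsigned: %d, signed: %d\n", countMuls(false), countMuls(true));
}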
diff --git a/mlir/test/mlir-vulkan-runner/smul_extended.mlir b/mlir/test/mlir-vulkan-runner/smul_extended.mlir
new file mode 100644
index 0000000..32ad477
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/smul_extended.mlir
@@ -0,0 +1,66 @@
+// Make sure that signed extended multiplication produces expected results
+// with and without expansion to primitive mul/add ops for WebGPU.
+
+// RUN: mlir-vulkan-runner %s \
+// RUN:   --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN:   --entry-point-result=void | FileCheck %s
+
+// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \
+// RUN:   --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN:   --entry-point-result=void | FileCheck %s
+
+// CHECK: [0, 1, -2, 1, 1048560, -87620295, -131071, 560969770]
+// CHECK: [0, 0, -1, 0, 0, -1, 0, -499807318]
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<>>
+} {
+  gpu.module @kernels {
+    gpu.func @kernel_add(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %arg2 : memref<8xi32>, %arg3 : memref<8xi32>)
+      kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+      %0 = gpu.block_id x
+      %lhs = memref.load %arg0[%0] : memref<8xi32>
+      %rhs = memref.load %arg1[%0] : memref<8xi32>
+      %low, %hi = arith.mulsi_extended %lhs, %rhs : i32
+      memref.store %low, %arg2[%0] : memref<8xi32>
+      memref.store %hi, %arg3[%0] : memref<8xi32>
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %buf0 = memref.alloc() : memref<8xi32>
+    %buf1 = memref.alloc() : memref<8xi32>
+    %buf2 = memref.alloc() : memref<8xi32>
+    %buf3 = memref.alloc() : memref<8xi32>
+    %i32_0 = arith.constant 0 : i32
+
+    // Initialize output buffers.
+    %buf4 = memref.cast %buf2 : memref<8xi32> to memref<?xi32>
+    %buf5 = memref.cast %buf3 : memref<8xi32> to memref<?xi32>
+    call @fillResource1DInt(%buf4, %i32_0) : (memref<?xi32>, i32) -> ()
+    call @fillResource1DInt(%buf5, %i32_0) : (memref<?xi32>, i32) -> ()
+
+    %idx_0 = arith.constant 0 : index
+    %idx_1 = arith.constant 1 : index
+    %idx_8 = arith.constant 8 : index
+
+    // Initialize input buffers.
+    %lhs_vals = arith.constant dense<[0, 1, -1, -1, 65535, 65535, -65535, 2088183954]> : vector<8xi32>
+    %rhs_vals = arith.constant dense<[0, 1, 2, -1, 16, -1337, -65535, -1028001427]> : vector<8xi32>
+    vector.store %lhs_vals, %buf0[%idx_0] : memref<8xi32>, vector<8xi32>
+    vector.store %rhs_vals, %buf1[%idx_0] : memref<8xi32>, vector<8xi32>
+
+    gpu.launch_func @kernels::@kernel_add
+        blocks in (%idx_8, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1)
+        args(%buf0 : memref<8xi32>, %buf1 : memref<8xi32>, %buf2 : memref<8xi32>, %buf3 : memref<8xi32>)
+
+    %buf_low = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
+    %buf_hi = memref.cast %buf5 : memref<?xi32> to memref<*xi32>
+    call @printMemrefI32(%buf_low) : (memref<*xi32>) -> ()
+    call @printMemrefI32(%buf_hi) : (memref<*xi32>) -> ()
+    return
+  }
+
+  func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+  func.func private @printMemrefI32(%ptr : memref<*xi32>)
+}
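As a cross-check on the runner test's expected output, the two CHECK vectors can be recomputed from the input vectors with a native 64-bit multiply standing in for arith.mulsi_extended. This is an editor's sketch, not part of the patch:

#include <cstdint>
#include <cstdio>

int main() {
  const int32_t lhs[8] = {0, 1, -1, -1, 65535, 65535, -65535, 2088183954};
  const int32_t rhs[8] = {0, 1, 2, -1, 16, -1337, -65535, -1028001427};

  // Low 32 bits (two's complement truncation of the 64-bit product):
  // prints [0, 1, -2, 1, 1048560, -87620295, -131071, 560969770].
  for (int i = 0; i < 8; ++i) {
    int64_t prod = static_cast<int64_t>(lhs[i]) * rhs[i];
    std::printf("%s%d", i == 0 ? "[" : ", ", static_cast<int32_t>(prod));
  }
  std::puts("]");

  // High 32 bits: prints [0, 0, -1, 0, 0, -1, 0, -499807318].
  for (int i = 0; i < 8; ++i) {
    int64_t prod = static_cast<int64_t>(lhs[i]) * rhs[i];
    std::printf("%s%d", i == 0 ? "[" : ", ", static_cast<int32_t>(prod >> 32));
  }
  std::puts("]");
}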