[mlir][spirv] Add smul_extended expansion for WebGPU
authorJakub Kuderski <kubak@google.com>
Fri, 6 Jan 2023 01:11:46 +0000 (20:11 -0500)
committerJakub Kuderski <kubak@google.com>
Fri, 6 Jan 2023 01:11:47 +0000 (20:11 -0500)
We need this because WGSL does not support extended multiplication ops.

Fixes: https://github.com/llvm/llvm-project/issues/59563

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D141096

mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
mlir/test/mlir-vulkan-runner/smul_extended.mlir [new file with mode: 0644]

index 5f8426b..1ed71db 100644 (file)
@@ -15,7 +15,9 @@
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -45,90 +47,126 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
   return SplatElementsAttr::get(type, sizedValue);
 }
 
+Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
+                                  Value lhs, Value rhs,
+                                  bool signExtendArguments) {
+  Location loc = mulOp->getLoc();
+  Type argTy = lhs.getType();
+  // Emulate 64-bit multiplication by splitting each input element of type i32
+  // into 2 16-bit digits of type i32. This is so that the intermediate
+  // multiplications and additions do not overflow. We extract these 16-bit
+  // digits from i32 vector elements by masking (low digit) and shifting right
+  // (high digit).
+  //
+  // The multiplication algorithm used is the standard (long) multiplication.
+  // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
+  // digits.
+  //   - With zero-extended arguments, we end up emitting only 4 multiplications
+  //     and 4 additions after constant folding.
+  //   - With sign-extended arguments, we end up emitting 8 multiplications and
+  //     and 12 additions after CSE.
+  Value cstLowMask = rewriter.create<ConstantOp>(
+      loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
+  auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
+    return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
+  };
+
+  Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+                                            getScalarOrSplatAttr(argTy, 16));
+  auto getHighDigit = [&rewriter, loc, cst16](Value val) {
+    return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
+  };
+
+  auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
+    // We only need to shift arithmetically by 15, but the extra
+    // sign-extension bit will be truncated by the logical shift, so this is
+    // fine. We do not have to introduce an extra constant since any
+    // value in [15, 32) would do.
+    return getHighDigit(
+        rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
+  };
+
+  Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+                                           getScalarOrSplatAttr(argTy, 0));
+
+  Value lhsLow = getLowDigit(lhs);
+  Value lhsHigh = getHighDigit(lhs);
+  Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
+  Value rhsLow = getLowDigit(rhs);
+  Value rhsHigh = getHighDigit(rhs);
+  Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
+
+  std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
+  std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
+  std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
+
+  for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
+    for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
+      if (i + j >= resultDigits.size())
+        continue;
+
+      if (lhsDigit == cst0 || rhsDigit == cst0)
+        continue;
+
+      Value &thisResDigit = resultDigits[i + j];
+      Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
+      Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
+      thisResDigit = getLowDigit(current);
+
+      if (i + j + 1 != resultDigits.size()) {
+        Value &nextResDigit = resultDigits[i + j + 1];
+        Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
+                                                    getHighDigit(current));
+        nextResDigit = carry;
+      }
+    }
+  }
+
+  auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
+    Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
+    return rewriter.create<BitwiseOrOp>(loc, low, highBits);
+  };
+  Value low = combineDigits(resultDigits[0], resultDigits[1]);
+  Value high = combineDigits(resultDigits[2], resultDigits[3]);
+
+  return rewriter.create<CompositeConstructOp>(
+      loc, mulOp->getResultTypes().front(), llvm::makeArrayRef({low, high}));
+}
+
 //===----------------------------------------------------------------------===//
 // Rewrite Patterns
 //===----------------------------------------------------------------------===//
-struct ExpandUMulExtendedPattern final : OpRewritePattern<UMulExtendedOp> {
-  using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(UMulExtendedOp op,
+template <typename MulExtendedOp, bool SignExtendArguments>
+struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
+  using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulExtendedOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     Value lhs = op.getOperand1();
     Value rhs = op.getOperand2();
-    Type argTy = lhs.getType();
 
     // Currently, WGSL only supports 32-bit integer types. Any other integer
     // types should already have been promoted/demoted to i32.
-    auto elemTy = getElementTypeOrSelf(argTy).cast<IntegerType>();
+    auto elemTy = getElementTypeOrSelf(lhs.getType()).cast<IntegerType>();
     if (elemTy.getIntOrFloatBitWidth() != 32)
       return rewriter.notifyMatchFailure(
           loc,
           llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
 
-    // Emulate 64-bit multiplication by splitting each input element of type i32
-    // into 2 16-bit digits of type i32. This is so that the intermediate
-    // multiplications and additions do not overflow. We extract these 16-bit
-    // digits from i32 vector elements by masking (low digit) and shifting right
-    // (high digit).
-    //
-    // The multiplication algorithm used is the standard (long) multiplication.
-    // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
-    // digits. After constant-folding, we end up emitting only 4 multiplications
-    // and 4 additions.
-    Value cstLowMask = rewriter.create<ConstantOp>(
-        loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
-    auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
-      return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
-    };
-
-    Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
-                                              getScalarOrSplatAttr(argTy, 16));
-    auto getHighDigit = [&rewriter, loc, cst16](Value val) {
-      return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
-    };
-
-    Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
-                                             getScalarOrSplatAttr(argTy, 0));
-
-    Value lhsLow = getLowDigit(lhs);
-    Value lhsHigh = getHighDigit(lhs);
-    Value rhsLow = getLowDigit(rhs);
-    Value rhsHigh = getHighDigit(rhs);
-
-    std::array<Value, 2> lhsDigits = {lhsLow, lhsHigh};
-    std::array<Value, 2> rhsDigits = {rhsLow, rhsHigh};
-    std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
-
-    for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
-      for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
-        Value &thisResDigit = resultDigits[i + j];
-        Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
-        Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
-        thisResDigit = getLowDigit(current);
-
-        if (i + j + 1 != resultDigits.size()) {
-          Value &nextResDigit = resultDigits[i + j + 1];
-          Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
-                                                      getHighDigit(current));
-          nextResDigit = carry;
-        }
-      }
-    }
-
-    auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
-      Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
-      return rewriter.create<BitwiseOrOp>(loc, low, highBits);
-    };
-    Value low = combineDigits(resultDigits[0], resultDigits[1]);
-    Value high = combineDigits(resultDigits[2], resultDigits[3]);
-
-    rewriter.replaceOpWithNewOp<CompositeConstructOp>(
-        op, op.getType(), llvm::makeArrayRef({low, high}));
+    Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
+                                            SignExtendArguments);
+    rewriter.replaceOp(op, mul);
     return success();
   }
 };
 
+using ExpandSMulExtendedPattern =
+    ExpandMulExtendedPattern<SMulExtendedOp, true>;
+using ExpandUMulExtendedPattern =
+    ExpandMulExtendedPattern<UMulExtendedOp, false>;
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
@@ -153,9 +191,8 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
     RewritePatternSet &patterns) {
   // WGSL currently does not support extended multiplication ops, see:
   // https://github.com/gpuweb/gpuweb/issues/1565.
-  // TODO(https://github.com/llvm/llvm-project/issues/59563): Add SMulExtended
-  // expansion.
-  patterns.add<ExpandUMulExtendedPattern>(patterns.getContext());
+  patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern>(
+      patterns.getContext());
 }
 } // namespace spirv
 } // namespace mlir
index d0720a3..91eeeda 100644 (file)
@@ -70,4 +70,79 @@ spirv.func @umul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
   spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @smul_extended_i32
+// CHECK-SAME:       ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32)
+// CHECK-DAG:        [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
+// CHECK-DAG:        [[CST16:%.+]]   = spirv.Constant 16 : i32
+// CHECK-NEXT:       [[LHSLOW:%.+]]  = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : i32
+// CHECK-NEXT:       [[LHSHI:%.+]]   = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : i32
+// CHECK-NEXT:       [[LHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG0]], [[CST16]] : i32
+// CHECK-NEXT:       [[LHSEXT:%.+]]  = spirv.ShiftRightLogical [[LHSSIGN]], [[CST16]] : i32
+// CHECK-NEXT:       [[RHSLOW:%.+]]  = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : i32
+// CHECK-NEXT:       [[RHSHI:%.+]]   = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : i32
+// CHECK-NEXT:       [[RHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG1]], [[CST16]] : i32
+// CHECK-NEXT:       [[RHSEXT:%.+]]  = spirv.ShiftRightLogical [[RHSSIGN]], [[CST16]] : i32
+// CHECK-DAG:                          spirv.IMul [[LHSLOW]], [[RHSLOW]]
+// CHECK-DAG:                          spirv.IMul [[LHSLOW]], [[RHSHI]]
+// CHECK-DAG:                          spirv.IMul [[LHSLOW]], [[RHSEXT]]
+// CHECK-DAG:                          spirv.IMul [[LHSHI]],  [[RHSLOW]]
+// CHECK-DAG:                          spirv.IMul [[LHSHI]],  [[RHSHI]]
+// CHECK-DAG:                          spirv.IMul [[LHSHI]],  [[RHSEXT]]
+// CHECK-DAG:                          spirv.IMul [[LHSEXT]], [[RHSLOW]]
+// CHECK-DAG:                          spirv.IMul [[LHSEXT]], [[RHSHI]]
+// CHECK:                              spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
+// CHECK:                              spirv.BitwiseOr
+// CHECK:                              spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
+// CHECK:                              spirv.BitwiseOr
+// CHECK:            [[RES:%.+]]     = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (i32, i32) -> !spirv.struct<(i32, i32)>
+// CHECK-NEXT:       spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
+spirv.func @smul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i32)> "None" {
+  %0 = spirv.SMulExtended %arg0, %arg1 : !spirv.struct<(i32, i32)>
+  spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: func @smul_extended_vector_i32
+// CHECK-SAME:       ([[ARG0:%.+]]: vector<3xi32>, [[ARG1:%.+]]: vector<3xi32>)
+// CHECK-DAG:        [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
+// CHECK-DAG:        [[CST16:%.+]]   = spirv.Constant dense<16> : vector<3xi32>
+// CHECK-NEXT:       [[LHSLOW:%.+]]  = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT:       [[LHSHI:%.+]]   = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:       [[LHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG0]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:       [[LHSEXT:%.+]]  = spirv.ShiftRightLogical [[LHSSIGN]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:       [[RHSLOW:%.+]]  = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT:       [[RHSHI:%.+]]   = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:       [[RHSSIGN:%.+]] = spirv.ShiftRightArithmetic [[ARG1]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:       [[RHSEXT:%.+]]  = spirv.ShiftRightLogical [[RHSSIGN]], [[CST16]] : vector<3xi32>
+// CHECK-DAG:                          spirv.IMul [[LHSLOW]], [[RHSLOW]]
+// CHECK-DAG:                          spirv.IMul [[LHSLOW]], [[RHSHI]]
+// CHECK-DAG:                          spirv.IMul [[LHSLOW]], [[RHSEXT]]
+// CHECK-DAG:                          spirv.IMul [[LHSHI]],  [[RHSLOW]]
+// CHECK-DAG:                          spirv.IMul [[LHSHI]],  [[RHSHI]]
+// CHECK-DAG:                          spirv.IMul [[LHSHI]],  [[RHSEXT]]
+// CHECK-DAG:                          spirv.IMul [[LHSEXT]], [[RHSLOW]]
+// CHECK-DAG:                          spirv.IMul [[LHSEXT]], [[RHSHI]]
+// CHECK:                              spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
+// CHECK:                              spirv.BitwiseOr
+// CHECK:                              spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
+// CHECK:                              spirv.BitwiseOr
+// CHECK-NEXT:       [[RES:%.+]]     = spirv.CompositeConstruct [[RESLOW:%.+]], [[RESHI:%.+]]
+// CHECK-NEXT:       spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+spirv.func @smul_extended_vector_i32(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>)
+  -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
+  %0 = spirv.SMulExtended %arg0, %arg1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  spirv.ReturnValue %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// CHECK-LABEL: func @smul_extended_i16
+// CHECK-NEXT:       spirv.SMulExtended
+// CHECK-NEXT:       spirv.ReturnValue
+spirv.func @smul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
+  %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(i16, i16)>
+  spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
+}
+
 } // end module
diff --git a/mlir/test/mlir-vulkan-runner/smul_extended.mlir b/mlir/test/mlir-vulkan-runner/smul_extended.mlir
new file mode 100644 (file)
index 0000000..32ad477
--- /dev/null
@@ -0,0 +1,66 @@
+// Make sure that signed extended multiplication produces expected results
+// with and without expansion to primitive mul/add ops for WebGPU.
+
+// RUN: mlir-vulkan-runner %s \
+// RUN:  --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN:  --entry-point-result=void | FileCheck %s
+
+// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \
+// RUN:  --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN:  --entry-point-result=void | FileCheck %s
+
+// CHECK: [0, 1, -2,  1, 1048560, -87620295, -131071,  560969770]
+// CHECK: [0, 0, -1,  0,       0,        -1,       0, -499807318]
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+  gpu.module @kernels {
+    gpu.func @kernel_add(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %arg2 : memref<8xi32>, %arg3 : memref<8xi32>)
+      kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+      %0 = gpu.block_id x
+      %lhs = memref.load %arg0[%0] : memref<8xi32>
+      %rhs = memref.load %arg1[%0] : memref<8xi32>
+      %low, %hi = arith.mulsi_extended %lhs, %rhs : i32
+      memref.store %low, %arg2[%0] : memref<8xi32>
+      memref.store %hi, %arg3[%0] : memref<8xi32>
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %buf0 = memref.alloc() : memref<8xi32>
+    %buf1 = memref.alloc() : memref<8xi32>
+    %buf2 = memref.alloc() : memref<8xi32>
+    %buf3 = memref.alloc() : memref<8xi32>
+    %i32_0 = arith.constant 0 : i32
+
+    // Initialize output buffers.
+    %buf4 = memref.cast %buf2 : memref<8xi32> to memref<?xi32>
+    %buf5 = memref.cast %buf3 : memref<8xi32> to memref<?xi32>
+    call @fillResource1DInt(%buf4, %i32_0) : (memref<?xi32>, i32) -> ()
+    call @fillResource1DInt(%buf5, %i32_0) : (memref<?xi32>, i32) -> ()
+
+    %idx_0 = arith.constant 0 : index
+    %idx_1 = arith.constant 1 : index
+    %idx_8 = arith.constant 8 : index
+
+    // Initialize input buffers.
+    %lhs_vals = arith.constant dense<[0, 1, -1,  -1,  65535,  65535, -65535,  2088183954]> : vector<8xi32>
+    %rhs_vals = arith.constant dense<[0, 1,  2,  -1,     16,  -1337, -65535, -1028001427]> : vector<8xi32>
+    vector.store %lhs_vals, %buf0[%idx_0] : memref<8xi32>, vector<8xi32>
+    vector.store %rhs_vals, %buf1[%idx_0] : memref<8xi32>, vector<8xi32>
+
+    gpu.launch_func @kernels::@kernel_add
+        blocks in (%idx_8, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1)
+        args(%buf0 : memref<8xi32>, %buf1 : memref<8xi32>, %buf2 : memref<8xi32>, %buf3 : memref<8xi32>)
+    %buf_low = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
+    %buf_hi = memref.cast %buf5 : memref<?xi32> to memref<*xi32>
+    call @printMemrefI32(%buf_low) : (memref<*xi32>) -> ()
+    call @printMemrefI32(%buf_hi) : (memref<*xi32>) -> ()
+    return
+  }
+  func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+  func.func private @printMemrefI32(%ptr : memref<*xi32>)
+}