[mlir][spirv] Add pattern to expand UMulExtended for WebGPU
authorJakub Kuderski <kubak@google.com>
Wed, 4 Jan 2023 18:29:46 +0000 (13:29 -0500)
committerJakub Kuderski <kubak@google.com>
Wed, 4 Jan 2023 18:29:47 +0000 (13:29 -0500)
This is needed because WGSL does not yet support extended multiplication
ops.

Set up pattern/pass stuff and handle the first op: `UMulExtended`.
`SMulExtended` handling will go to a separate patch.

Issue: https://github.com/llvm/llvm-project/issues/59563

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D140995

mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h [new file with mode: 0644]
mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp [new file with mode: 0644]
mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir [new file with mode: 0644]

index 5cd3dac..f24b1ae 100644 (file)
@@ -46,4 +46,9 @@ def SPIRVUpdateVCE : Pass<"spirv-update-vce", "spirv::ModuleOp"> {
   let constructor = "mlir::spirv::createUpdateVersionCapabilityExtensionPass()";
 }
 
+def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
+  let summary = "Prepare SPIR-V to target WebGPU by expanding unsupported ops "
+                "and replacing with supported ones";
+}
+
 #endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
new file mode 100644 (file)
index 0000000..ac4d38e
--- /dev/null
@@ -0,0 +1,30 @@
+//===- SPIRVWebGPUTransforms.h - WebGPU-specific Transforms -*- C++ -----*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines SPIR-V transforms used when targetting WebGPU.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRV_WEBGPU_TRANSFORMS_H
+#define MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRV_WEBGPU_TRANSFORMS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace spirv {
+
+/// Appends to a pattern list additional patterns to expand extended
+/// multiplication ops into regular arithmetic ops. Extended multiplication ops
+/// are not supported by the WebGPU Shading Language (WGSL).
+void populateSPIRVExpandExtendedMultiplicationPatterns(
+    RewritePatternSet &patterns);
+
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRV_WEBGPU_TRANSFORMS_H
index 56996e2..821f82e 100644 (file)
@@ -4,6 +4,7 @@ set(LLVM_OPTIONAL_SOURCES
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
   SPIRVConversion.cpp
+  SPIRVWebGPUTransforms.cpp
   UnifyAliasedResourcePass.cpp
   UpdateVCEPass.cpp
 )
@@ -25,6 +26,7 @@ add_mlir_dialect_library(MLIRSPIRVTransforms
   DecorateCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
+  SPIRVWebGPUTransforms.cpp
   UnifyAliasedResourcePass.cpp
   UpdateVCEPass.cpp
 
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
new file mode 100644 (file)
index 0000000..6cf127e
--- /dev/null
@@ -0,0 +1,135 @@
+//===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements SPIR-V transforms used when targetting WebGPU.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace mlir {
+namespace spirv {
+#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
+} // namespace spirv
+} // namespace mlir
+
+namespace mlir {
+namespace spirv {
+namespace {
+//===----------------------------------------------------------------------===//
+// Helpers
+//===----------------------------------------------------------------------===//
+Attribute getScalarOrSplatAttr(Type type, int64_t value) {
+  APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
+  if (auto intTy = type.dyn_cast<IntegerType>())
+    return IntegerAttr::get(intTy, sizedValue);
+
+  return SplatElementsAttr::get(type, sizedValue);
+}
+
+//===----------------------------------------------------------------------===//
+// Rewrite Patterns
+//===----------------------------------------------------------------------===//
+struct ExpandUMulExtendedPattern final : OpRewritePattern<UMulExtendedOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(UMulExtendedOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value lhs = op.getOperand1();
+    Value rhs = op.getOperand2();
+    Type argTy = lhs.getType();
+
+    // Currently, WGSL only supports 32-bit integer types. Any other integer
+    // types should already have been promoted/demoted to i32.
+    auto elemTy = getElementTypeOrSelf(argTy).cast<IntegerType>();
+    if (elemTy.getIntOrFloatBitWidth() != 32)
+      return rewriter.notifyMatchFailure(
+          loc,
+          llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
+
+    // Calculate the 'low' and the 'high' result separately, using long
+    // multiplication:
+    //
+    // lhs = [0   0]  [a   b]
+    // rhs = [0   0]  [c   d]
+    // --lhs * rhs--
+    // =     [    a * c    ]   [    b * d    ] +
+    //       [ 0 ]    [a * d + b * c]    [ 0 ]
+    //
+    // ==> high = (a * c) + (a * d + b * c) >> 16
+    Value low = rewriter.create<IMulOp>(loc, lhs, rhs);
+
+    Value cstLowMask = rewriter.create<ConstantOp>(
+        loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
+    auto getLowHalf = [&rewriter, loc, cstLowMask](Value val) {
+      return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
+    };
+
+    Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+                                              getScalarOrSplatAttr(argTy, 16));
+    auto getHighHalf = [&rewriter, loc, cst16](Value val) {
+      return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
+    };
+
+    Value lhsLow = getLowHalf(lhs);
+    Value lhsHigh = getHighHalf(lhs);
+    Value rhsLow = getLowHalf(rhs);
+    Value rhsHigh = getHighHalf(rhs);
+
+    Value high0 = rewriter.create<IMulOp>(loc, lhsHigh, rhsHigh);
+    Value mid = rewriter.create<IAddOp>(
+        loc, rewriter.create<IMulOp>(loc, lhsHigh, rhsLow),
+        rewriter.create<IMulOp>(loc, lhsLow, rhsHigh));
+    Value high1 = getHighHalf(mid);
+    Value high = rewriter.create<IAddOp>(loc, high0, high1);
+
+    rewriter.replaceOpWithNewOp<CompositeConstructOp>(
+        op, op.getType(), llvm::makeArrayRef({low, high}));
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+class WebGPUPreparePass
+    : public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
+public:
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
+
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface
+//===----------------------------------------------------------------------===//
+void populateSPIRVExpandExtendedMultiplicationPatterns(
+    RewritePatternSet &patterns) {
+  // WGSL currently does not support extended multiplication ops, see:
+  // https://github.com/gpuweb/gpuweb/issues/1565.
+  // TODO(https://github.com/llvm/llvm-project/issues/59563): Add SMulExtended
+  // expansion.
+  patterns.add<ExpandUMulExtendedPattern>(patterns.getContext());
+}
+} // namespace spirv
+} // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
new file mode 100644 (file)
index 0000000..b2f93aa
--- /dev/null
@@ -0,0 +1,62 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics --spirv-webgpu-prepare %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 {
+
+// CHECK-LABEL: func @umul_extended_i32
+// CHECK-SAME:       ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32)
+// CHECK-DAG:        [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
+// CHECK-DAG:        [[CST16:%.+]]   = spirv.Constant 16 : i32
+// CHECK-NEXT:       [[RESLOW:%.+]]  = spirv.IMul [[ARG0]], [[ARG1]] : i32
+// CHECK-NEXT:       [[LHSLOW:%.+]]  = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : i32
+// CHECK-NEXT:       [[LHSHI:%.+]]   = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : i32
+// CHECK-NEXT:       [[RHSLOW:%.+]]  = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : i32
+// CHECK-NEXT:       [[RHSHI:%.+]]   = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : i32
+// CHECK-NEXT:       [[RESHI0:%.+]]  = spirv.IMul [[LHSHI]], [[RHSHI]] : i32
+// CHECK-NEXT:       [[MID0:%.+]]    = spirv.IMul [[LHSHI]], [[RHSLOW]] : i32
+// CHECK-NEXT:       [[MID1:%.+]]    = spirv.IMul [[LHSLOW]], [[RHSHI]] : i32
+// CHECK-NEXT:       [[MID:%.+]]     = spirv.IAdd [[MID0]], [[MID1]] : i32
+// CHECK-NEXT:       [[RESHI1:%.+]]  = spirv.ShiftRightLogical [[MID]], [[CST16]] : i32
+// CHECK-NEXT:       [[RESHI:%.+]]   = spirv.IAdd [[RESHI0]], [[RESHI1]] : i32
+// CHECK-NEXT:       [[RES:%.+]]     = spirv.CompositeConstruct [[RESLOW]], [[RESHI]] : (i32, i32) -> !spirv.struct<(i32, i32)>
+// CHECK-NEXT:       spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
+spirv.func @umul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i32)> "None" {
+  %0 = spirv.UMulExtended %arg0, %arg1 : !spirv.struct<(i32, i32)>
+  spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: func @umul_extended_vector_i32
+// CHECK-SAME:       ([[ARG0:%.+]]: vector<3xi32>, [[ARG1:%.+]]: vector<3xi32>)
+// CHECK-DAG:        [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
+// CHECK-DAG:        [[CST16:%.+]]   = spirv.Constant dense<16> : vector<3xi32>
+// CHECK-NEXT:       [[RESLOW:%.+]]  = spirv.IMul [[ARG0]], [[ARG1]] : vector<3xi32>
+// CHECK-NEXT:       [[LHSLOW:%.+]]  = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT:       [[LHSHI:%.+]]   = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:       [[RHSLOW:%.+]]  = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT:       [[RHSHI:%.+]]   = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:       [[RESHI0:%.+]]  = spirv.IMul [[LHSHI]], [[RHSHI]] : vector<3xi32>
+// CHECK-NEXT:       [[MID0:%.+]]    = spirv.IMul [[LHSHI]], [[RHSLOW]] : vector<3xi32>
+// CHECK-NEXT:       [[MID1:%.+]]    = spirv.IMul [[LHSLOW]], [[RHSHI]] : vector<3xi32>
+// CHECK-NEXT:       [[MID:%.+]]     = spirv.IAdd [[MID0]], [[MID1]] : vector<3xi32>
+// CHECK-NEXT:       [[RESHI1:%.+]]  = spirv.ShiftRightLogical [[MID]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT:       [[RESHI:%.+]]   = spirv.IAdd [[RESHI0]], [[RESHI1]] : vector<3xi32>
+// CHECK-NEXT:       [[RES:%.+]]     = spirv.CompositeConstruct [[RESLOW]], [[RESHI]]
+// CHECK-NEXT:       spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+spirv.func @umul_extended_vector_i32(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>)
+  -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
+  %0 = spirv.UMulExtended %arg0, %arg1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  spirv.ReturnValue %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// CHECK-LABEL: func @umul_extended_i16
+// CHECK-NEXT:       spirv.UMulExtended
+// CHECK-NEXT:       spirv.ReturnValue
+spirv.func @umul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
+  %0 = spirv.UMulExtended %arg, %arg : !spirv.struct<(i16, i16)>
+  spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
+}
+
+} // end module