From: Jakub Kuderski
Date: Wed, 4 Jan 2023 18:29:46 +0000 (-0500)
Subject: [mlir][spirv] Add pattern to expand UMulExtended for WebGPU
X-Git-Tag: upstream/17.0.6~22140
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c957fe0f6088485beea3b03cf4b4084c41226328;p=platform%2Fupstream%2Fllvm.git

[mlir][spirv] Add pattern to expand UMulExtended for WebGPU

This is needed because WGSL does not yet support extended multiplication ops.

Set up the pattern and pass infrastructure and handle the first op:
`UMulExtended`. `SMulExtended` handling will go in a separate patch.

Issue: https://github.com/llvm/llvm-project/issues/59563

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D140995
---

diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
index 5cd3dac..f24b1ae 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -46,4 +46,9 @@ def SPIRVUpdateVCE : Pass<"spirv-update-vce", "spirv::ModuleOp"> {
   let constructor = "mlir::spirv::createUpdateVersionCapabilityExtensionPass()";
 }
 
+def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
+  let summary = "Prepare SPIR-V to target WebGPU by expanding unsupported ops "
+                "and replacing with supported ones";
+}
+
 #endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
new file mode 100644
index 0000000..ac4d38e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
@@ -0,0 +1,30 @@
+//===- SPIRVWebGPUTransforms.h - WebGPU-specific Transforms ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines SPIR-V transforms used when targeting WebGPU.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRV_WEBGPU_TRANSFORMS_H
+#define MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRV_WEBGPU_TRANSFORMS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace spirv {
+
+/// Appends to a pattern list additional patterns to expand extended
+/// multiplication ops into regular arithmetic ops. Extended multiplication ops
+/// are not supported by the WebGPU Shading Language (WGSL).
+void populateSPIRVExpandExtendedMultiplicationPatterns(
+    RewritePatternSet &patterns);
+
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRV_WEBGPU_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 56996e2..821f82e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ set(LLVM_OPTIONAL_SOURCES
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
   SPIRVConversion.cpp
+  SPIRVWebGPUTransforms.cpp
   UnifyAliasedResourcePass.cpp
   UpdateVCEPass.cpp
 )
@@ -25,6 +26,7 @@ add_mlir_dialect_library(MLIRSPIRVTransforms
   DecorateCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
+  SPIRVWebGPUTransforms.cpp
   UnifyAliasedResourcePass.cpp
   UpdateVCEPass.cpp
 
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
new file mode 100644
index 0000000..6cf127e
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -0,0 +1,135 @@
+//===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements SPIR-V transforms used when targeting WebGPU.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace mlir {
+namespace spirv {
+#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
+} // namespace spirv
+} // namespace mlir
+
+namespace mlir {
+namespace spirv {
+namespace {
+//===----------------------------------------------------------------------===//
+// Helpers
+//===----------------------------------------------------------------------===//
+Attribute getScalarOrSplatAttr(Type type, int64_t value) {
+  APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
+  if (auto intTy = type.dyn_cast<IntegerType>())
+    return IntegerAttr::get(intTy, sizedValue);
+
+  return SplatElementsAttr::get(type.cast<ShapedType>(), sizedValue);
+}
+
+//===----------------------------------------------------------------------===//
+// Rewrite Patterns
+//===----------------------------------------------------------------------===//
+struct ExpandUMulExtendedPattern final : OpRewritePattern<UMulExtendedOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(UMulExtendedOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value lhs = op.getOperand1();
+    Value rhs = op.getOperand2();
+    Type argTy = lhs.getType();
+
+    // Currently, WGSL only supports 32-bit integer types. Any other integer
+    // types should already have been promoted/demoted to i32.
+    auto elemTy = getElementTypeOrSelf(argTy).cast<IntegerType>();
+    if (elemTy.getIntOrFloatBitWidth() != 32)
+      return rewriter.notifyMatchFailure(
+          loc,
+          llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
+
+    // Calculate the 'low' and the 'high' result separately, using long
+    // multiplication:
+    //
+    // lhs = [0 0] [a b]
+    // rhs = [0 0] [c d]
+    // --lhs * rhs--
+    // =   [ a * c ] [ b * d ] +
+    //     [ 0 ] [a * d + b * c] [ 0 ]
+    //
+    // ==> high = (a * c) + (a * d + b * c) >> 16
+    Value low = rewriter.create<IMulOp>(loc, lhs, rhs);
+
+    Value cstLowMask = rewriter.create<ConstantOp>(
+        loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
+    auto getLowHalf = [&rewriter, loc, cstLowMask](Value val) {
+      return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
+    };
+
+    Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
+                                              getScalarOrSplatAttr(argTy, 16));
+    auto getHighHalf = [&rewriter, loc, cst16](Value val) {
+      return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
+    };
+
+    Value lhsLow = getLowHalf(lhs);
+    Value lhsHigh = getHighHalf(lhs);
+    Value rhsLow = getLowHalf(rhs);
+    Value rhsHigh = getHighHalf(rhs);
+
+    Value high0 = rewriter.create<IMulOp>(loc, lhsHigh, rhsHigh);
+    Value mid = rewriter.create<IAddOp>(
+        loc, rewriter.create<IMulOp>(loc, lhsHigh, rhsLow),
+        rewriter.create<IMulOp>(loc, lhsLow, rhsHigh));
+    Value high1 = getHighHalf(mid);
+    Value high = rewriter.create<IAddOp>(loc, high0, high1);
+
+    rewriter.replaceOpWithNewOp<CompositeConstructOp>(
+        op, op.getType(), llvm::makeArrayRef({low, high}));
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+class WebGPUPreparePass
+    : public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
+public:
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
+
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface
+//===----------------------------------------------------------------------===//
+void populateSPIRVExpandExtendedMultiplicationPatterns(
+    RewritePatternSet &patterns) {
+  // WGSL currently does not support extended multiplication ops, see:
+  // https://github.com/gpuweb/gpuweb/issues/1565.
+  // TODO(https://github.com/llvm/llvm-project/issues/59563): Add SMulExtended
+  // expansion.
+  patterns.add<ExpandUMulExtendedPattern>(patterns.getContext());
+}
+
+} // namespace spirv
+} // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
new file mode 100644
index 0000000..b2f93aa
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics --spirv-webgpu-prepare %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 {
+
+// CHECK-LABEL: func @umul_extended_i32
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32)
+// CHECK-DAG: [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
+// CHECK-DAG: [[CST16:%.+]] = spirv.Constant 16 : i32
+// CHECK-NEXT: [[RESLOW:%.+]] = spirv.IMul [[ARG0]], [[ARG1]] : i32
+// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : i32
+// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : i32
+// CHECK-NEXT: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : i32
+// CHECK-NEXT: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : i32
+// CHECK-NEXT: [[RESHI0:%.+]] = spirv.IMul [[LHSHI]], [[RHSHI]] : i32
+// CHECK-NEXT: [[MID0:%.+]] = spirv.IMul [[LHSHI]], [[RHSLOW]] : i32
+// CHECK-NEXT: [[MID1:%.+]] = spirv.IMul [[LHSLOW]], [[RHSHI]] : i32
+// CHECK-NEXT: [[MID:%.+]] = spirv.IAdd [[MID0]], [[MID1]] : i32
+// CHECK-NEXT: [[RESHI1:%.+]] = spirv.ShiftRightLogical [[MID]], [[CST16]] : i32
+// CHECK-NEXT: [[RESHI:%.+]] = spirv.IAdd [[RESHI0]], [[RESHI1]] : i32
+// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW]], [[RESHI]] : (i32, i32) -> !spirv.struct<(i32, i32)>
+// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
+spirv.func @umul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i32)> "None" {
+  %0 = spirv.UMulExtended %arg0, %arg1 : !spirv.struct<(i32, i32)>
+  spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: func @umul_extended_vector_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi32>, [[ARG1:%.+]]: vector<3xi32>)
+// CHECK-DAG: [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
+// CHECK-DAG: [[CST16:%.+]] = spirv.Constant dense<16> : vector<3xi32>
+// CHECK-NEXT: [[RESLOW:%.+]] = spirv.IMul [[ARG0]], [[ARG1]] : vector<3xi32>
+// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : vector<3xi32>
+// CHECK-NEXT: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[RESHI0:%.+]] = spirv.IMul [[LHSHI]], [[RHSHI]] : vector<3xi32>
+// CHECK-NEXT: [[MID0:%.+]] = spirv.IMul [[LHSHI]], [[RHSLOW]] : vector<3xi32>
+// CHECK-NEXT: [[MID1:%.+]] = spirv.IMul [[LHSLOW]], [[RHSHI]] : vector<3xi32>
+// CHECK-NEXT: [[MID:%.+]] = spirv.IAdd [[MID0]], [[MID1]] : vector<3xi32>
+// CHECK-NEXT: [[RESHI1:%.+]] = spirv.ShiftRightLogical [[MID]], [[CST16]] : vector<3xi32>
+// CHECK-NEXT: [[RESHI:%.+]] = spirv.IAdd [[RESHI0]], [[RESHI1]] : vector<3xi32>
+// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW]], [[RESHI]]
+// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+spirv.func @umul_extended_vector_i32(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>)
+    -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
+  %0 = spirv.UMulExtended %arg0, %arg1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  spirv.ReturnValue %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// CHECK-LABEL: func @umul_extended_i16
+// CHECK-NEXT: spirv.UMulExtended
+// CHECK-NEXT: spirv.ReturnValue
+spirv.func @umul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
+  %0 = spirv.UMulExtended %arg, %arg : !spirv.struct<(i16, i16)>
+  spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
+}
+
+} // end module
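
For reference, a condensed sketch of what the new pattern produces for a scalar i32 `spirv.UMulExtended`, derived from the CHECK lines in webgpu-prepare.mlir above. The value names here are illustrative only; the pass itself emits numbered SSA values, and the shift ops print both operand types in the SPIR-V dialect assembly.

  // Before running --spirv-webgpu-prepare:
  %res = spirv.UMulExtended %lhs, %rhs : !spirv.struct<(i32, i32)>

  // After the expansion (same op order as the CHECK-NEXT lines):
  %mask   = spirv.Constant 65535 : i32
  %c16    = spirv.Constant 16 : i32
  %low    = spirv.IMul %lhs, %rhs : i32
  %lhsLow = spirv.BitwiseAnd %lhs, %mask : i32
  %lhsHi  = spirv.ShiftRightLogical %lhs, %c16 : i32, i32
  %rhsLow = spirv.BitwiseAnd %rhs, %mask : i32
  %rhsHi  = spirv.ShiftRightLogical %rhs, %c16 : i32, i32
  %hi0    = spirv.IMul %lhsHi, %rhsHi : i32
  %mid0   = spirv.IMul %lhsHi, %rhsLow : i32
  %mid1   = spirv.IMul %lhsLow, %rhsHi : i32
  %mid    = spirv.IAdd %mid0, %mid1 : i32
  %hi1    = spirv.ShiftRightLogical %mid, %c16 : i32, i32
  %hi     = spirv.IAdd %hi0, %hi1 : i32
  %res    = spirv.CompositeConstruct %low, %hi : (i32, i32) -> !spirv.struct<(i32, i32)>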