From 5bb8b0fe7dfae791d0542fb8e7e33ab40b74e285 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 2 Sep 2022 17:14:53 -0400 Subject: [PATCH] [mlir][spirv] Add support for converting gpu.shuffle xor Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D133054 --- .../Dialect/SPIRV/Transforms/SPIRVConversion.h | 4 +- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 52 +++++++++++++++++++++- .../Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 4 -- mlir/test/Conversion/GPUToSPIRV/shuffle.mlir | 48 ++++++++++++++++++++ 4 files changed, 102 insertions(+), 6 deletions(-) create mode 100644 mlir/test/Conversion/GPUToSPIRV/shuffle.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index d0f09ea..11928c9 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -69,8 +69,10 @@ public: /// Gets the SPIR-V correspondence for the standard index type. Type getIndexType() const; + const spirv::TargetEnv &getTargetEnv() const { return targetEnv; } + /// Returns the options controlling the SPIR-V type converter. 
- const Options &getOptions() const; + const Options &getOptions() const { return options; } private: spirv::TargetEnv targetEnv; diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 0e0083b60..2b83894 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -12,12 +12,14 @@ #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -120,6 +122,16 @@ public: ConversionPatternRewriter &rewriter) const override; }; +/// Pattern to convert a gpu.shuffle op into a spv.GroupNonUniformShuffle op. 
+class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -363,6 +375,44 @@ LogicalResult GPUBarrierConversion::matchAndRewrite(
 }
 
 //===----------------------------------------------------------------------===//
+// Shuffle
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPUShuffleConversion::matchAndRewrite(
+    gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Require the shuffle width to be the same as the target's subgroup size,
+  // given that for SPIR-V non-uniform subgroup ops, we cannot select
+  // participating invocations.
+  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
+  unsigned subgroupSize =
+      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
+  IntegerAttr widthAttr;
+  if (!matchPattern(shuffleOp.width(), m_Constant(&widthAttr)) ||
+      widthAttr.getValue().getZExtValue() != subgroupSize)
+    return rewriter.notifyMatchFailure(
+        shuffleOp, "shuffle width and target subgroup size mismatch");
+
+  Location loc = shuffleOp.getLoc();
+  Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
+                                            shuffleOp.getLoc(), rewriter);
+  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+  Value result;
+
+  switch (shuffleOp.mode()) {
+  case gpu::ShuffleMode::XOR:
+    result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
+        loc, scope, adaptor.value(), adaptor.offset());
+    break;
+  default:
+    return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
+  }
+
+  rewriter.replaceOp(shuffleOp, {result, trueVal});
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===// @@ -370,7 +420,7 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion, - GPUModuleEndConversion, GPUReturnOpConversion, + GPUModuleEndConversion, GPUReturnOpConversion, GPUShuffleConversion, LaunchConfigConversion, LaunchConfigConversion, LaunchConfigConversion, diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 9154c81..c2e2286 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -118,10 +118,6 @@ Type SPIRVTypeConverter::getIndexType() const { return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32); } -const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const { - return options; -} - MLIRContext *SPIRVTypeConverter::getContext() const { return targetEnv.getAttr().getContext(); } diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir new file mode 100644 index 0000000..6a7b38c --- /dev/null +++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s + +module attributes { + gpu.container_module, + spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits> +} { + +gpu.module @kernels { + // CHECK-LABEL: spv.func @shuffle_xor() + gpu.func @shuffle_xor() kernel + attributes {spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>>} { + %mask = arith.constant 8 : i32 + %width = arith.constant 16 : i32 + %val = arith.constant 42.0 : f32 + + // CHECK: %[[MASK:.+]] = spv.Constant 8 : i32 + // CHECK: %[[VAL:.+]] = spv.Constant 4.200000e+01 : f32 + // CHECK: %{{.+}} = spv.Constant true + // CHECK: %{{.+}} = 
spv.GroupNonUniformShuffleXor %[[VAL]], %[[MASK]] : f32, i32 + %result, %valid = gpu.shuffle xor %val, %mask, %width : f32 + gpu.return + } +} + +} + +// ----- + +module attributes { + gpu.container_module, + spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits> +} { + +gpu.module @kernels { + gpu.func @shuffle_xor() kernel + attributes {spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>>} { + %mask = arith.constant 8 : i32 + %width = arith.constant 16 : i32 + %val = arith.constant 42.0 : f32 + + // Cannot convert due to shuffle width and target subgroup size mismatch + // expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}} + %result, %valid = gpu.shuffle xor %val, %mask, %width : f32 + gpu.return + } +} + +} -- 2.7.4