namespace mlir {
class SPIRVTypeConverter;
/// Appends to a pattern list additional patterns for translating GPU Ops to
-/// SPIR-V ops.
+/// SPIR-V ops. Needs the workgroup size as input since SPIR-V/Vulkan requires
+/// the workgroup size to be statically specified.
void populateGPUToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns,
+ ArrayRef<int64_t> workGroupSize);
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRV_H
#ifndef MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H
#define MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H
+#include "mlir/Support/LLVM.h"
+
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OpPassBase;
-/// Pass to convert GPU Ops to SPIR-V ops.
-std::unique_ptr<OpPassBase<ModuleOp>> createConvertGPUToSPIRVPass();
+/// Pass to convert GPU Ops to SPIR-V ops. Needs the workgroup size as input
+/// since SPIR-V/Vulkan requires the workgroup size to be statically specified.
+std::unique_ptr<OpPassBase<ModuleOp>>
+createConvertGPUToSPIRVPass(ArrayRef<int64_t> workGroupSize);
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H
MLIRSPIRV
MLIRStandardOps
MLIRStandardToSPIRVTransforms
+ MLIRSupport
MLIRTransforms
)
/// attribute gpu.kernel) within a spv.module.
class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
public:
- using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
+ KernelFnConversion(MLIRContext *context, SPIRVTypeConverter &converter,
+ ArrayRef<int64_t> workGroupSize,
+ PatternBenefit benefit = 1)
+ : SPIRVOpLowering<FuncOp>(context, converter, benefit) {
+ auto config = workGroupSize.take_front(3);
+ workGroupSizeAsInt32.assign(config.begin(), config.end());
+ workGroupSizeAsInt32.resize(3, 1);
+ }
PatternMatchResult
matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override;
+
+private:
+ SmallVector<int32_t, 3> workGroupSizeAsInt32;
};
} // namespace
argABI.push_back(spirv::getInterfaceVarABIAttr(
0, argNum, spirv::StorageClass::StorageBuffer, rewriter.getContext()));
}
- // TODO(ravishankarm) : For now set this to {32, 1, 1}. This is incorrect. The
- // actual workgroup size needs to be plumbed through.
+
auto context = rewriter.getContext();
- auto entryPointAttr = spirv::getEntryPointABIAttr({32, 1, 1}, context);
+ auto entryPointAttr =
+ spirv::getEntryPointABIAttr(workGroupSizeAsInt32, context);
FuncOp newFuncOp = spirv::lowerAsEntryFunction(
funcOp, typeConverter, rewriter, argABI, entryPointAttr);
if (!newFuncOp) {
namespace mlir {
void populateGPUToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns,
+ ArrayRef<int64_t> workGroupSize) {
+ patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize);
patterns.insert<
- ForOpConversion, KernelFnConversion,
+ ForOpConversion,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
using namespace mlir;
///
/// 2) Lower the body of the spirv::ModuleOp.
class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
+public:
+ GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize)
+ : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
void runOnModule() override;
+
+private:
+ SmallVector<int64_t, 3> workGroupSize;
+};
+
+/// Command line option to specify the workgroup size.
+struct GPUToSPIRVPassOptions : public PassOptions<GPUToSPIRVPassOptions> {
+ List<unsigned> workGroupSize{
+ *this, "workgroup-size",
+ llvm::cl::desc(
+ "Workgroup Sizes in the SPIR-V module for x, followed by y, followed "
+ "by z dimension of the dispatch (others will be ignored)"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
};
} // namespace
/// Dialect conversion to lower the functions with the spirv::ModuleOps.
SPIRVTypeConverter typeConverter;
OwningRewritePatternList patterns;
- populateGPUToSPIRVPatterns(context, typeConverter, patterns);
+ populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
ConversionTarget target(*context);
}
}
-std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertGPUToSPIRVPass() {
- return std::make_unique<GPUToSPIRVPass>();
+std::unique_ptr<OpPassBase<ModuleOp>>
+mlir::createConvertGPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) {
+ return std::make_unique<GPUToSPIRVPass>(workGroupSize);
}
-static PassRegistration<GPUToSPIRVPass>
- pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
+static PassRegistration<GPUToSPIRVPass, GPUToSPIRVPassOptions>
+ pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect",
+ [](const GPUToSPIRVPassOptions &passOptions) {
+ SmallVector<int64_t, 3> workGroupSize;
+ workGroupSize.assign(passOptions.workGroupSize.begin(),
+ passOptions.workGroupSize.end());
+ return std::make_unique<GPUToSPIRVPass>(workGroupSize);
+ });
-// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -pass-pipeline='convert-gpu-to-spirv{workgroup-size=32,4}' %s -o - | FileCheck %s
module attributes {gpu.container_module} {
// CHECK-LABEL: func @kernel_1
// CHECK-SAME: {{%.*}}: f32 {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+ // CHECK-SAME: spirv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32>)
attributes { gpu.kernel } {
// CHECK: spv.Return