From: Mahesh Ravishankar
Date: Tue, 27 Aug 2019 17:49:53 +0000 (-0700)
Subject: Enhance GPU To SPIR-V conversion to support builtins and load/store ops.
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4ced99c085e62f681276a06e0ce842b190900d32;p=platform%2Fupstream%2Fllvm.git

Enhance GPU To SPIR-V conversion to support builtins and load/store ops.

To support conversion of a simple load-compute-store kernel from the GPU
dialect to the SPIR-V dialect, conversion of operations like "gpu.block_dim"
and "gpu.thread_id", which allow threads to get the launch configuration, is
needed. In SPIR-V these are specified as global variables with builtin
attributes. This CL adds support to specify builtin variables in the SPIR-V
conversion framework. This is used to convert the relevant operations from the
GPU dialect to the SPIR-V dialect.

Also add support for conversion of load/store operations in the Standard
dialect to the SPIR-V dialect. To simplify the conversion, add a method to
build a spv.AccessChain operation that automatically determines the return
type based on the base pointer type and the indices provided.

PiperOrigin-RevId: 265718525
---

diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
index adfd83b..25a710f 100644
--- a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
+++ b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
@@ -24,10 +24,15 @@
 #ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
 #define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
 
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Support/StringExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
+class LoadOp;
+class ReturnOp;
+class StoreOp;
 namespace spirv {
 class SPIRVDialect;
 }
@@ -63,7 +68,7 @@ public:
   LogicalResult convertSignatureArg(unsigned inputNo, Type type,
                                     SignatureConversion &result) override;
 
-  /// Get the basic type converter.
+  /// Gets the basic type converter.
   SPIRVBasicTypeConverter *getBasicTypeConverter() const {
     return basicTypeConverter;
   }
@@ -80,17 +85,98 @@ public:
         typeConverter(typeConverter) {}
 
 protected:
-  // Type lowering class.
+  /// Gets the global variable associated with a builtin and adds it if it
+  /// doesn't exist.
+  Value *loadFromBuiltinVariable(Operation *op, spirv::BuiltIn builtin,
+                                 ConversionPatternRewriter &rewriter) const {
+    auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
+    if (!moduleOp) {
+      op->emitError("expected operation to be within a SPIR-V module");
+      return nullptr;
+    }
+    auto varOp =
+        getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, rewriter);
+    auto ptr = rewriter
+                   .create<spirv::AddressOfOp>(op->getLoc(), varOp.type(),
+                                               rewriter.getSymbolRefAttr(varOp))
+                   .pointer();
+    return rewriter.create<spirv::LoadOp>(
+        op->getLoc(),
+        ptr->getType().template cast<spirv::PointerType>().getPointeeType(),
+        ptr, /*memory_access =*/nullptr, /*alignment =*/nullptr);
+  }
+
+  /// Type lowering class.
   SPIRVTypeConverter &typeConverter;
+
+private:
+  /// Look through all global variables in `moduleOp` and check if there is a
+  /// spv.globalVariable that has the same `builtin` attribute.
+  spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
+                                             spirv::BuiltIn builtin) const {
+    for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
+      if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(convertToSnakeCase(
+              stringifyDecoration(spirv::Decoration::BuiltIn)))) {
+        auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
+        if (varBuiltIn && varBuiltIn.getValue() == builtin) {
+          return varOp;
+        }
+      }
+    }
+    return nullptr;
+  }
+
+  /// Gets the name of the global variable for a builtin.
+  std::string getBuiltinVarName(spirv::BuiltIn builtin) const {
+    return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() +
+           "__";
+  }
+
+  /// Gets or inserts a global variable for a builtin within a module.
+  spirv::GlobalVariableOp
+  getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
+                             spirv::BuiltIn builtin,
+                             ConversionPatternRewriter &builder) const {
+    if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
+      return varOp;
+    }
+    auto ip = builder.saveInsertionPoint();
+    builder.setInsertionPointToStart(&moduleOp.getBlock());
+    auto name = getBuiltinVarName(builtin);
+    spirv::GlobalVariableOp newVarOp;
+    switch (builtin) {
+    case spirv::BuiltIn::NumWorkgroups:
+    case spirv::BuiltIn::WorkgroupSize:
+    case spirv::BuiltIn::WorkgroupId:
+    case spirv::BuiltIn::LocalInvocationId:
+    case spirv::BuiltIn::GlobalInvocationId: {
+      auto ptrType = spirv::PointerType::get(
+          builder.getVectorType({3}, builder.getIntegerType(32)),
+          spirv::StorageClass::Input);
+      newVarOp = builder.create<spirv::GlobalVariableOp>(
+          loc, builder.getTypeAttr(ptrType), builder.getStringAttr(name),
+          nullptr);
+      newVarOp.setAttr(
+          convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)),
+          builder.getStringAttr(stringifyBuiltIn(builtin)));
+      break;
+    }
+    default:
+      emitError(loc, "unimplemented builtin variable generation for ")
+          << stringifyBuiltIn(builtin);
+    }
+    builder.restoreInsertionPoint(ip);
+    return newVarOp;
+  }
 };
 
-/// Method to legalize a function as a non-entry function.
+/// Legalizes a function as a non-entry function.
 LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
                             SPIRVTypeConverter *typeConverter,
                             ConversionPatternRewriter &rewriter,
                             FuncOp &newFuncOp);
 
-/// Method to legalize a function as an entry function.
+/// Legalizes a function as an entry function.
 LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
                                    SPIRVTypeConverter *typeConverter,
                                    ConversionPatternRewriter &rewriter,
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index 5fccf1b..6aad600 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -113,6 +113,9 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
   let results = (outs
     SPV_AnyPtr:$component_ptr
   );
+
+  let builders = [OpBuilder<[{Builder *builder, OperationState *state,
+                              Value *basePtr, ArrayRef<Value *> indices}]>];
 }
 
 // -----
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index ff6af83..06b2498 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -1,4 +1,4 @@
-//===- GPUToSPIRV.cp - MLIR SPIR-V lowering passes ------------------------===//
+//===- GPUToSPIRV.cpp - MLIR SPIR-V lowering passes -----------------------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -29,6 +29,18 @@
 using namespace mlir;
 
 namespace {
+/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
+/// builtin variables.
+template <typename OpTy, spirv::BuiltIn builtin>
+class LaunchConfigConversion : public SPIRVOpLowering<OpTy> {
+public:
+  using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
 /// attribute gpu.kernel) within a spv.module.
 class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
@@ -41,6 +53,33 @@ public:
 };
 } // namespace
 
+template <typename OpTy, spirv::BuiltIn builtin>
+PatternMatchResult LaunchConfigConversion<OpTy, builtin>::matchAndRewrite(
+    Operation *op, ArrayRef<Value *> operands,
+    ConversionPatternRewriter &rewriter) const {
+  auto dimAttr = op->getAttrOfType<StringAttr>("dimension");
+  if (!dimAttr) {
+    return this->matchFailure();
+  }
+  int32_t index = 0;
+  if (dimAttr.getValue() == "x") {
+    index = 0;
+  } else if (dimAttr.getValue() == "y") {
+    index = 1;
+  } else if (dimAttr.getValue() == "z") {
+    index = 2;
+  } else {
+    return this->matchFailure();
+  }
+
+  // SPIR-V invocation builtin variables are a vector of type <3xi32>.
+  auto spirvBuiltin = this->loadFromBuiltinVariable(op, builtin, rewriter);
+  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+      op, rewriter.getIntegerType(32), spirvBuiltin,
+      rewriter.getI32ArrayAttr({index}));
+  return this->matchSuccess();
+}
+
 PatternMatchResult
 KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                                     ConversionPatternRewriter &rewriter) const {
@@ -106,7 +145,13 @@ void GPUToSPIRVPass::runOnModule() {
   SPIRVBasicTypeConverter basicTypeConverter(context);
   SPIRVTypeConverter typeConverter(&basicTypeConverter);
   OwningRewritePatternList patterns;
-  patterns.insert<KernelFnConversion>(context, typeConverter);
+  patterns.insert<
+      KernelFnConversion,
+      LaunchConfigConversion<gpu::BlockDim, spirv::BuiltIn::WorkgroupSize>,
+      LaunchConfigConversion<gpu::BlockId, spirv::BuiltIn::WorkgroupId>,
+      LaunchConfigConversion<gpu::GridDim, spirv::BuiltIn::NumWorkgroups>,
+      LaunchConfigConversion<gpu::ThreadId, spirv::BuiltIn::LocalInvocationId>>(
+      context, typeConverter);
   populateStandardToSPIRVPatterns(context, patterns);
 
   ConversionTarget target(*context);
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index b7dfff4..e3bcc04 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -39,15 +39,22 @@ Type SPIRVBasicTypeConverter::convertType(Type t) {
     return t;
   }
 
+  if (auto indexType = t.dyn_cast<IndexType>()) {
+    // Return I32 for index types.
+    return IntegerType::get(32, t.getContext());
+  }
+
   if (auto memRefType = t.dyn_cast<MemRefType>()) {
     if (memRefType.hasStaticShape()) {
-      // Convert MemrefType to spv.array if size is known.
+      // Convert MemrefType to a multi-dimensional spv.array if size is known.
+      auto elementType = memRefType.getElementType();
+      for (auto size : reverse(memRefType.getShape())) {
+        elementType = spirv::ArrayType::get(elementType, size);
+      }
       // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
       // to support other Storage Classes.
-      return spirv::PointerType::get(
-          spirv::ArrayType::get(memRefType.getElementType(),
-                                memRefType.getNumElements()),
-          spirv::StorageClass::StorageBuffer);
+      return spirv::PointerType::get(elementType,
+                                     spirv::StorageClass::StorageBuffer);
     }
   }
   return Type();
@@ -68,8 +75,12 @@ SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
   if (!convertedType)
     return failure();
   // For arguments to entry functions, convert the type into a pointer type if
-  // it is already not one.
-  if (!convertedType.isa<spirv::PointerType>()) {
+  // it is already not one, unless the original type was an index type.
+  // TODO(ravishankarm): For arguments that are of index type, keep the
+  // arguments as the scalar converted type, i.e. i32. These are still not
+  // handled effectively. These are potentially best handled as specialization
+  // constants.
+  if (!convertedType.isa<spirv::PointerType>() && !type.isa<IndexType>()) {
     // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
     // to support other Storage classes.
     convertedType = spirv::PointerType::get(convertedType,
@@ -143,29 +154,40 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
   if (!module) {
     return funcOp.emitError("expected op to be within a spv.module");
   }
-  OpBuilder builder(module.getOperation()->getRegion(0));
+  auto ip = rewriter.saveInsertionPoint();
+  rewriter.setInsertionPointToStart(&module.getBlock());
   SmallVector<Attribute, 4> interface;
   for (auto &convertedArgType :
        llvm::enumerate(signatureConverter.getConvertedTypes())) {
+    // TODO(ravishankarm) : The arguments to the converted function are either
+    // spirv::PointerType or i32 type, the latter due to conversion of index
+    // type to i32. Eventually the entry function should be of signature
+    // void(void). Arguments converted to spirv::PointerType will be made
+    // variables and those converted to i32 will be made specialization
+    // constants. The latter is not implemented.
+    if (!convertedArgType.value().isa<spirv::PointerType>()) {
+      continue;
+    }
     std::string varName = funcOp.getName().str() + "_arg_" +
                           std::to_string(convertedArgType.index());
-    auto variableOp = builder.create<spirv::GlobalVariableOp>(
-        funcOp.getLoc(), builder.getTypeAttr(convertedArgType.value()),
-        builder.getStringAttr(varName), nullptr);
-    variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0));
+    auto variableOp = rewriter.create<spirv::GlobalVariableOp>(
+        funcOp.getLoc(), rewriter.getTypeAttr(convertedArgType.value()),
+        rewriter.getStringAttr(varName), nullptr);
+    variableOp.setAttr("descriptor_set", rewriter.getI32IntegerAttr(0));
     variableOp.setAttr("binding",
-                       builder.getI32IntegerAttr(convertedArgType.index()));
-    interface.push_back(builder.getSymbolRefAttr(variableOp.sym_name()));
+                       rewriter.getI32IntegerAttr(convertedArgType.index()));
+    interface.push_back(rewriter.getSymbolRefAttr(variableOp.sym_name()));
   }
 
   // Create an entry point instruction for this function.
   // TODO(ravishankarm) : Add execution mode for the entry function
-  builder.setInsertionPoint(&(module.getBlock().back()));
-  builder.create<spirv::EntryPointOp>(
+  rewriter.setInsertionPoint(&(module.getBlock().back()));
+  rewriter.create<spirv::EntryPointOp>(
       funcOp.getLoc(),
-      builder.getI32IntegerAttr(
+      rewriter.getI32IntegerAttr(
          static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
-      builder.getSymbolRefAttr(newFuncOp.getName()),
-      builder.getArrayAttr(interface));
+      rewriter.getSymbolRefAttr(newFuncOp.getName()),
+      rewriter.getArrayAttr(interface));
+  rewriter.restoreInsertionPoint(ip);
   return success();
 }
 } // namespace mlir
@@ -175,6 +197,56 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
 //===----------------------------------------------------------------------===//
 
 namespace {
+
+/// Convert integer binary operations to SPIR-V operations. Cannot use tablegen
+/// for this. If the integer operation is on variables of IndexType, the type of
+/// the return value of the replacement operation differs from that of the
+/// replaced operation. This is not handled in tablegen-based pattern
+/// specification.
+template <typename StdOp, typename SPIRVOp>
+class IntegerOpConversion final : public ConversionPattern {
+public:
+  IntegerOpConversion(MLIRContext *context)
+      : ConversionPattern(StdOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.template replaceOpWithNewOp<SPIRVOp>(
+        op, operands[0]->getType(), operands, ArrayRef<NamedAttribute>());
+    return this->matchSuccess();
+  }
+};
+
+/// Convert load -> spv.LoadOp. The operands of the replaced operation are of
+/// IndexType while those of the replacement operation are of type i32. This is
+/// not supported in tablegen-based pattern specification.
+// TODO(ravishankarm) : These could potentially be templated on the operation
+// being converted, since the same logic should work for linalg.load.
+class LoadOpConversion final : public ConversionPattern {
+public:
+  LoadOpConversion(MLIRContext *context)
+      : ConversionPattern(LoadOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    LoadOpOperandAdaptor loadOperands(operands);
+    auto basePtr = loadOperands.memref();
+    auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>();
+    if (!ptrType) {
+      return matchFailure();
+    }
+    auto loadPtr = rewriter.create<spirv::AccessChainOp>(
+        op->getLoc(), basePtr, loadOperands.indices());
+    auto loadPtrType = loadPtr.getType().cast<spirv::PointerType>();
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(
+        op, loadPtrType.getPointeeType(), loadPtr, /*memory_access =*/nullptr,
+        /*alignment =*/nullptr);
+    return matchSuccess();
+  }
+};
+
 /// Convert return -> spv.Return.
 class ReturnToSPIRVConversion : public ConversionPattern {
 public:
@@ -191,6 +263,35 @@ public:
   }
 };
 
+/// Convert store -> spv.StoreOp. The operands of the replaced operation are of
+/// IndexType while those of the replacement operation are of type i32. This is
+/// not supported in tablegen-based pattern specification.
+// TODO(ravishankarm) : These could potentially be templated on the operation
+// being converted, since the same logic should work for linalg.store.
+class StoreOpConversion final : public ConversionPattern {
+public:
+  StoreOpConversion(MLIRContext *context)
+      : ConversionPattern(StoreOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    StoreOpOperandAdaptor storeOperands(operands);
+    auto value = storeOperands.value();
+    auto basePtr = storeOperands.memref();
+    auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>();
+    if (!ptrType) {
+      return matchFailure();
+    }
+    auto storePtr = rewriter.create<spirv::AccessChainOp>(
+        op->getLoc(), basePtr, storeOperands.indices());
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(op, storePtr, value,
+                                                /*memory_access =*/nullptr,
+                                                /*alignment =*/nullptr);
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -203,6 +304,8 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      OwningRewritePatternList &patterns) {
   populateWithGenerated(context, &patterns);
   // Add the return op conversion.
-  patterns.insert<ReturnToSPIRVConversion>(context);
+  patterns.insert<IntegerOpConversion<AddIOp, spirv::IAddOp>,
+                  IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion,
+                  ReturnToSPIRVConversion, StoreOpConversion>(context);
 }
 } // namespace mlir
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
index 4cfd559..b37eee8 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
@@ -43,6 +43,7 @@ multiclass BinaryOpPattern {
   }
 }
 
+defm : BinaryOpPattern<AddFOp, SPV_FAddOp>;
 defm : BinaryOpPattern<MulFOp, SPV_FMulOp>;
 
 #endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index fef9c0b..aaa7ed5 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -316,7 +316,7 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter *printer,
 static Type getElementPtrType(Type type, ArrayRef<Value *> indices,
                               Location baseLoc) {
-  if (!indices.size()) {
+  if (indices.empty()) {
     emitError(baseLoc, "'spv.AccessChain' op expected at least "
                        "one index ");
     return nullptr;
@@ -372,6 +372,13 @@ static Type getElementPtrType(Type type, ArrayRef<Value *> indices,
   return spirv::PointerType::get(resultType, resultStorageClass);
 }
 
+void spirv::AccessChainOp::build(Builder *builder, OperationState *state,
+                                 Value *basePtr, ArrayRef<Value *> indices) {
+  auto type = getElementPtrType(basePtr->getType(), indices, state->location);
+  assert(type && "Unable to deduce return type based on basePtr and indices");
+  build(builder, state, type, basePtr, indices);
+}
+
 static ParseResult parseAccessChainOp(OpAsmParser *parser,
                                       OperationState *state) {
   OpAsmParser::OperandType ptrInfo;
diff --git a/mlir/test/Conversion/GPUToSPIRV/builtins.mlir b/mlir/test/Conversion/GPUToSPIRV/builtins.mlir
new file mode 100644
index 0000000..ce9421e
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/builtins.mlir
@@ -0,0 +1,113 @@
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv %s -o - | FileCheck %s
+
+func @builtin() {
+  %c0 = constant 1 : index
+  "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_workgroup_id_x} : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId")
+func @builtin_workgroup_id_x()
+  attributes {gpu.kernel} {
+  // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPID]]
+  // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+  // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+  %0 = "gpu.block_id"() {dimension = "x"} : () -> index
+  return
+}
+
+// -----
+
+func @builtin() {
+  %c0 = constant 1 : index
+  "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_workgroup_id_y} : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId")
+func @builtin_workgroup_id_y()
+  attributes {gpu.kernel} {
+  // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPID]]
+  // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+  // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}1 : i32{{\]}}
+  %0 = "gpu.block_id"() {dimension = "y"} : () -> index
+  return
+}
+
+// -----
+
+func @builtin() {
+  %c0 = constant 1 : index
+  "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_workgroup_id_z} : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId")
+func @builtin_workgroup_id_z()
+  attributes {gpu.kernel} {
+  // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPID]]
+  // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+  // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}2 : i32{{\]}}
+  %0 = "gpu.block_id"() {dimension = "z"} : () -> index
+  return
+}
+
+// -----
+
+func @builtin() {
+  %c0 = constant 1 : index
+  "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_workgroup_size_x} : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[WORKGROUPSIZE:@.*]] built_in("WorkgroupSize")
+func @builtin_workgroup_size_x()
+  attributes {gpu.kernel} {
+  // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPSIZE]]
+  // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+  // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+  %0 = "gpu.block_dim"() {dimension = "x"} : () -> index
+  return
+}
+
+// -----
+
+func @builtin() {
+  %c0 = constant 1 : index
+  "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_local_id_x} : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[LOCALINVOCATIONID:@.*]] built_in("LocalInvocationId")
+func @builtin_local_id_x()
+  attributes {gpu.kernel} {
+  // CHECK: [[ADDRESS:%.*]] = spv._address_of [[LOCALINVOCATIONID]]
+  // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+  // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+  %0 = "gpu.thread_id"() {dimension = "x"} : () -> index
+  return
+}
+
+// -----
+
+func @builtin() {
+  %c0 = constant 1 : index
+  "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_num_workgroups_x} : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[NUMWORKGROUPS:@.*]] built_in("NumWorkgroups")
+func @builtin_num_workgroups_x()
+  attributes {gpu.kernel} {
+  // CHECK: [[ADDRESS:%.*]] = spv._address_of [[NUMWORKGROUPS]]
+  // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+  // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+  %0 = "gpu.grid_dim"() {dimension = "x"} : () -> index
+  return
+}
diff --git a/mlir/test/Conversion/GPUToSPIRV/load_store.mlir b/mlir/test/Conversion/GPUToSPIRV/load_store.mlir
new file mode 100644
index 0000000..cc8ed07
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/load_store.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
+
+func @load_store(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) {
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %0 = subi %c12, %c0 : index
+  %c1 = constant 1 : index
+  %c0_0 = constant 0 : index
+  %c4 = constant 4 : index
+  %1 = subi %c4, %c0_0 : index
+  %c1_1 = constant 1 : index
+  %c1_2 = constant 1 : index
+  "gpu.launch_func"(%0, %c1_2, %c1_2, %1, %c1_2, %c1_2, %arg0, %arg1, %arg2, %c0, %c0_0, %c1, %c1_1) {kernel = @load_store_kernel} : (index, index, index, index, index, index, memref<12x4xf32>, memref<12x4xf32>, memref<12x4xf32>, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable {{@.*}} bind(0, 0) : [[TYPE1:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
+// CHECK-NEXT: spv.globalVariable {{@.*}} bind(0, 1) : [[TYPE2:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
+// CHECK-NEXT: spv.globalVariable {{@.*}} bind(0, 2) : [[TYPE3:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
+// CHECK: func @load_store_kernel([[ARG0:%.*]]: [[TYPE1]], [[ARG1:%.*]]: [[TYPE2]], [[ARG2:%.*]]: [[TYPE3]], [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32, [[ARG5:%.*]]: i32, [[ARG6:%.*]]: i32)
+func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index)
+  attributes {gpu.kernel} {
+  %0 = "gpu.block_id"() {dimension = "x"} : () -> index
+  %1 = "gpu.block_id"() {dimension = "y"} : () -> index
+  %2 = "gpu.block_id"() {dimension = "z"} : () -> index
+  %3 = "gpu.thread_id"() {dimension = "x"} : () -> index
+  %4 = "gpu.thread_id"() {dimension = "y"} : () -> index
+  %5 = "gpu.thread_id"() {dimension = "z"} : () -> index
+  %6 = "gpu.grid_dim"() {dimension = "x"} : () -> index
+  %7 = "gpu.grid_dim"() {dimension = "y"} : () -> index
+  %8 = "gpu.grid_dim"() {dimension = "z"} : () -> index
+  %9 = "gpu.block_dim"() {dimension = "x"} : () -> index
+  %10 = "gpu.block_dim"() {dimension = "y"} : () -> index
+  %11 = "gpu.block_dim"() {dimension = "z"} : () -> index
+  // CHECK: [[INDEX1:%.*]] = spv.IAdd [[ARG3]], {{%.*}}
+  %12 = addi %arg3, %0 : index
+  // CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], {{%.*}}
+  %13 = addi %arg4, %3 : index
+  // CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+  // CHECK-NEXT: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[PTR1]]
+  %14 = load %arg0[%12, %13] : memref<12x4xf32>
+  // CHECK: [[PTR2:%.*]] = spv.AccessChain [[ARG1]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+  // CHECK-NEXT: [[VAL2:%.*]] = spv.Load "StorageBuffer" [[PTR2]]
+  %15 = load %arg1[%12, %13] : memref<12x4xf32>
+  // CHECK: [[VAL3:%.*]] = spv.FAdd [[VAL1]], [[VAL2]]
+  %16 = addf %14, %15 : f32
+  // CHECK: [[PTR3:%.*]] = spv.AccessChain [[ARG2]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+  // CHECK-NEXT: spv.Store "StorageBuffer" [[PTR3]], [[VAL3]]
+  store %16, %arg2[%12, %13] : memref<12x4xf32>
+  return
+}
\ No newline at end of file
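
For illustration, the net effect of the patterns added in this change on a small kernel body is sketched below. This is a hand-written summary based on the conversion patterns and the FileCheck expectations above, not actual mlir-opt output: the SSA value names are invented, trailing type annotations on the SPIR-V ops are omitted, and the 1-D memref shape is only an example.

Input, inside a function carrying the gpu.kernel attribute, with %arg0/%arg1/%arg2 of type memref<4xf32>:

  %i = "gpu.thread_id"() {dimension = "x"} : () -> index
  %0 = load %arg0[%i] : memref<4xf32>
  %1 = load %arg1[%i] : memref<4xf32>
  %2 = addf %0, %1 : f32
  store %2, %arg2[%i] : memref<4xf32>

After -convert-gpu-to-spirv, schematically: the memref arguments are rewritten to !spv.ptr<!spv.array<4 x f32>, StorageBuffer> values, the thread id is loaded from the generated builtin variable (named per getBuiltinVarName above), and loads/stores go through spv.AccessChain pointers whose type is deduced by the new builder:

  %addr = spv._address_of @__builtin_var_LocalInvocationId__
  %vec  = spv.Load "Input" %addr
  %i    = spv.CompositeExtract %vec[0 : i32]
  %p0   = spv.AccessChain %arg0[%i]
  %0    = spv.Load "StorageBuffer" %p0
  %p1   = spv.AccessChain %arg1[%i]
  %1    = spv.Load "StorageBuffer" %p1
  %2    = spv.FAdd %0, %1
  %p2   = spv.AccessChain %arg2[%i]
  spv.Store "StorageBuffer" %p2, %2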