#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
#define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Support/StringExtras.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
+class LoadOp;
+class ReturnOp;
+class StoreOp;
namespace spirv {
class SPIRVDialect;
}
LogicalResult convertSignatureArg(unsigned inputNo, Type type,
SignatureConversion &result) override;
- /// Get the basic type converter.
+ /// Gets the basic type converter.
SPIRVBasicTypeConverter *getBasicTypeConverter() const {
return basicTypeConverter;
}
typeConverter(typeConverter) {}
protected:
- // Type lowering class.
+ /// Gets the global variable associated with a builtin and add
+ /// it if it doesnt exist.
+ Value *loadFromBuiltinVariable(Operation *op, spirv::BuiltIn builtin,
+ ConversionPatternRewriter &rewriter) const {
+ auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
+ if (!moduleOp) {
+ op->emitError("expected operation to be within a SPIR-V module");
+ return nullptr;
+ }
+ auto varOp =
+ getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, rewriter);
+ auto ptr = rewriter
+ .create<spirv::AddressOfOp>(op->getLoc(), varOp.type(),
+ rewriter.getSymbolRefAttr(varOp))
+ .pointer();
+ return rewriter.create<spirv::LoadOp>(
+ op->getLoc(),
+ ptr->getType().template cast<spirv::PointerType>().getPointeeType(),
+ ptr, /*memory_access =*/nullptr, /*alignment =*/nullptr);
+ }
+
+ /// Type lowering class.
SPIRVTypeConverter &typeConverter;
+
+private:
+ /// Look through all global variables in `moduleOp` and check if there is a
+ /// spv.globalVariable that has the same `builtin` attribute.
+ spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
+ spirv::BuiltIn builtin) const {
+ for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
+ if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(convertToSnakeCase(
+ stringifyDecoration(spirv::Decoration::BuiltIn)))) {
+ auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
+ if (varBuiltIn && varBuiltIn.getValue() == builtin) {
+ return varOp;
+ }
+ }
+ }
+ return nullptr;
+ }
+
+ /// Gets name of global variable for a buitlin.
+ std::string getBuiltinVarName(spirv::BuiltIn builtin) const {
+ return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() +
+ "__";
+ }
+
+ /// Gets or inserts a global variable for a builtin within a module.
+ spirv::GlobalVariableOp
+ getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
+ spirv::BuiltIn builtin,
+ ConversionPatternRewriter &builder) const {
+ if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
+ return varOp;
+ }
+ auto ip = builder.saveInsertionPoint();
+ builder.setInsertionPointToStart(&moduleOp.getBlock());
+ auto name = getBuiltinVarName(builtin);
+ spirv::GlobalVariableOp newVarOp;
+ switch (builtin) {
+ case spirv::BuiltIn::NumWorkgroups:
+ case spirv::BuiltIn::WorkgroupSize:
+ case spirv::BuiltIn::WorkgroupId:
+ case spirv::BuiltIn::LocalInvocationId:
+ case spirv::BuiltIn::GlobalInvocationId: {
+ auto ptrType = spirv::PointerType::get(
+ builder.getVectorType({3}, builder.getIntegerType(32)),
+ spirv::StorageClass::Input);
+ newVarOp = builder.create<spirv::GlobalVariableOp>(
+ loc, builder.getTypeAttr(ptrType), builder.getStringAttr(name),
+ nullptr);
+ newVarOp.setAttr(
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)),
+ builder.getStringAttr(stringifyBuiltIn(builtin)));
+ break;
+ }
+ default:
+ emitError(loc, "unimplemented builtin variable generation for ")
+ << stringifyBuiltIn(builtin);
+ }
+ builder.restoreInsertionPoint(ip);
+ return newVarOp;
+ }
};
-/// Method to legalize a function as a non-entry function.
+/// Legalizes a function as a non-entry function.
LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
SPIRVTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp);
-/// Method to legalize a function as an entry function.
+/// Legalizes a function as an entry function.
LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
SPIRVTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
let results = (outs
SPV_AnyPtr:$component_ptr
);
+
+ let builders = [OpBuilder<[{Builder *builder, OperationState *state,
+ Value *basePtr, ArrayRef<Value *> indices}]>];
}
// -----
-//===- GPUToSPIRV.cp - MLIR SPIR-V lowering passes ------------------------===//
+//===- GPUToSPIRV.cpp - MLIR SPIR-V lowering passes -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
namespace {
+/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
+/// builin variables.
+template <typename OpTy, spirv::BuiltIn builtin>
+class LaunchConfigConversion : public SPIRVOpLowering<OpTy> {
+public:
+ using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
/// attribute gpu.kernel) within a spv.module.
class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
};
} // namespace
+template <typename OpTy, spirv::BuiltIn builtin>
+PatternMatchResult LaunchConfigConversion<OpTy, builtin>::matchAndRewrite(
+ Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const {
+ auto dimAttr = op->getAttrOfType<StringAttr>("dimension");
+ if (!dimAttr) {
+ return this->matchFailure();
+ }
+ int32_t index = 0;
+ if (dimAttr.getValue() == "x") {
+ index = 0;
+ } else if (dimAttr.getValue() == "y") {
+ index = 1;
+ } else if (dimAttr.getValue() == "z") {
+ index = 2;
+ } else {
+ return this->matchFailure();
+ }
+
+ // SPIR-V invocation builtin variables are a vector of type <3xi32>
+ auto spirvBuiltin = this->loadFromBuiltinVariable(op, builtin, rewriter);
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+ op, rewriter.getIntegerType(32), spirvBuiltin,
+ rewriter.getI32ArrayAttr({index}));
+ return this->matchSuccess();
+}
+
PatternMatchResult
KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
SPIRVBasicTypeConverter basicTypeConverter(context);
SPIRVTypeConverter typeConverter(&basicTypeConverter);
OwningRewritePatternList patterns;
- patterns.insert<KernelFnConversion>(context, typeConverter);
+ patterns.insert<
+ KernelFnConversion,
+ LaunchConfigConversion<gpu::BlockDim, spirv::BuiltIn::WorkgroupSize>,
+ LaunchConfigConversion<gpu::BlockId, spirv::BuiltIn::WorkgroupId>,
+ LaunchConfigConversion<gpu::GridDim, spirv::BuiltIn::NumWorkgroups>,
+ LaunchConfigConversion<gpu::ThreadId, spirv::BuiltIn::LocalInvocationId>>(
+ context, typeConverter);
populateStandardToSPIRVPatterns(context, patterns);
ConversionTarget target(*context);
return t;
}
+ if (auto indexType = t.dyn_cast<IndexType>()) {
+ // Return I32 for index types.
+ return IntegerType::get(32, t.getContext());
+ }
+
if (auto memRefType = t.dyn_cast<MemRefType>()) {
if (memRefType.hasStaticShape()) {
- // Convert MemrefType to spv.array if size is known.
+ // Convert MemrefType to a multi-dimensional spv.array if size is known.
+ auto elementType = memRefType.getElementType();
+ for (auto size : reverse(memRefType.getShape())) {
+ elementType = spirv::ArrayType::get(elementType, size);
+ }
// TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
// to support other Storage Classes.
- return spirv::PointerType::get(
- spirv::ArrayType::get(memRefType.getElementType(),
- memRefType.getNumElements()),
- spirv::StorageClass::StorageBuffer);
+ return spirv::PointerType::get(elementType,
+ spirv::StorageClass::StorageBuffer);
}
}
return Type();
if (!convertedType)
return failure();
// For arguments to entry functions, convert the type into a pointer type if
- // it is already not one.
- if (!convertedType.isa<spirv::PointerType>()) {
+ // it is already not one, unless the original type was an index type.
+ // TODO(ravishankarm): For arguments that are of index type, keep the
+ // arguments as the scalar converted type, i.e. i32. These are still not
+ // handled effectively. These are potentially best handled as specialization
+ // constants.
+ if (!convertedType.isa<spirv::PointerType>() && !type.isa<IndexType>()) {
// TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
// to support other Storage classes.
convertedType = spirv::PointerType::get(convertedType,
if (!module) {
return funcOp.emitError("expected op to be within a spv.module");
}
- OpBuilder builder(module.getOperation()->getRegion(0));
+ auto ip = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPointToStart(&module.getBlock());
SmallVector<Attribute, 4> interface;
for (auto &convertedArgType :
llvm::enumerate(signatureConverter.getConvertedTypes())) {
+ // TODO(ravishankarm) : The arguments to the converted function are either
+ // spirv::PointerType or i32 type, the latter due to conversion of index
+ // type to i32. Eventually entry function should be of signature
+ // void(void). Arguments converted to spirv::PointerType, will be made
+ // variables and those converted to i32 will be made specialization
+ // constants. Latter is not implemented.
+ if (!convertedArgType.value().isa<spirv::PointerType>()) {
+ continue;
+ }
std::string varName = funcOp.getName().str() + "_arg_" +
std::to_string(convertedArgType.index());
- auto variableOp = builder.create<spirv::GlobalVariableOp>(
- funcOp.getLoc(), builder.getTypeAttr(convertedArgType.value()),
- builder.getStringAttr(varName), nullptr);
- variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0));
+ auto variableOp = rewriter.create<spirv::GlobalVariableOp>(
+ funcOp.getLoc(), rewriter.getTypeAttr(convertedArgType.value()),
+ rewriter.getStringAttr(varName), nullptr);
+ variableOp.setAttr("descriptor_set", rewriter.getI32IntegerAttr(0));
variableOp.setAttr("binding",
- builder.getI32IntegerAttr(convertedArgType.index()));
- interface.push_back(builder.getSymbolRefAttr(variableOp.sym_name()));
+ rewriter.getI32IntegerAttr(convertedArgType.index()));
+ interface.push_back(rewriter.getSymbolRefAttr(variableOp.sym_name()));
}
// Create an entry point instruction for this function.
// TODO(ravishankarm) : Add execution mode for the entry function
- builder.setInsertionPoint(&(module.getBlock().back()));
- builder.create<spirv::EntryPointOp>(
+ rewriter.setInsertionPoint(&(module.getBlock().back()));
+ rewriter.create<spirv::EntryPointOp>(
funcOp.getLoc(),
- builder.getI32IntegerAttr(
+ rewriter.getI32IntegerAttr(
static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
- builder.getSymbolRefAttr(newFuncOp.getName()),
- builder.getArrayAttr(interface));
+ rewriter.getSymbolRefAttr(newFuncOp.getName()),
+ rewriter.getArrayAttr(interface));
+ rewriter.restoreInsertionPoint(ip);
return success();
}
} // namespace mlir
//===----------------------------------------------------------------------===//
namespace {
+
+/// Convert integer binary operations to SPIR-V operations. Cannot use tablegen
+/// for this. If the integer operation is on variables of IndexType, the type of
+/// the return value of the replacement operation differs from that of the
+/// replaced operation. This is not handled in tablegen-based pattern
+/// specification.
+template <typename StdOp, typename SPIRVOp>
+class IntegerOpConversion final : public ConversionPattern {
+public:
+ IntegerOpConversion(MLIRContext *context)
+ : ConversionPattern(StdOp::getOperationName(), 1, context) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.template replaceOpWithNewOp<SPIRVOp>(
+ op, operands[0]->getType(), operands, ArrayRef<NamedAttribute>());
+ return this->matchSuccess();
+ }
+};
+
+/// Convert load -> spv.LoadOp. The operands of the replaced operation are of
+/// IndexType while that of the replacement operation are of type i32. This is
+/// not suppored in tablegen based pattern specification.
+// TODO(ravishankarm) : These could potentially be templated on the operation
+// being converted, since the same logic should work for linalg.load.
+class LoadOpConversion final : public ConversionPattern {
+public:
+ LoadOpConversion(MLIRContext *context)
+ : ConversionPattern(LoadOp::getOperationName(), 1, context) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ LoadOpOperandAdaptor loadOperands(operands);
+ auto basePtr = loadOperands.memref();
+ auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>();
+ if (!ptrType) {
+ return matchFailure();
+ }
+ auto loadPtr = rewriter.create<spirv::AccessChainOp>(
+ op->getLoc(), basePtr, loadOperands.indices());
+ auto loadPtrType = loadPtr.getType().cast<spirv::PointerType>();
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(
+ op, loadPtrType.getPointeeType(), loadPtr, /*memory_access =*/nullptr,
+ /*alignment =*/nullptr);
+ return matchSuccess();
+ }
+};
+
/// Convert return -> spv.Return.
class ReturnToSPIRVConversion : public ConversionPattern {
public:
}
};
+/// Convert store -> spv.StoreOp. The operands of the replaced operation are of
+/// IndexType while that of the replacement operation are of type i32. This is
+/// not suppored in tablegen based pattern specification.
+// TODO(ravishankarm) : These could potentially be templated on the operation
+// being converted, since the same logic should work for linalg.store.
+class StoreOpConversion final : public ConversionPattern {
+public:
+ StoreOpConversion(MLIRContext *context)
+ : ConversionPattern(StoreOp::getOperationName(), 1, context) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ StoreOpOperandAdaptor storeOperands(operands);
+ auto value = storeOperands.value();
+ auto basePtr = storeOperands.memref();
+ auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>();
+ if (!ptrType) {
+ return matchFailure();
+ }
+ auto storePtr = rewriter.create<spirv::AccessChainOp>(
+ op->getLoc(), basePtr, storeOperands.indices());
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(op, storePtr, value,
+ /*memory_access =*/nullptr,
+ /*alignment =*/nullptr);
+ return matchSuccess();
+ }
+};
+
} // namespace
namespace {
OwningRewritePatternList &patterns) {
populateWithGenerated(context, &patterns);
// Add the return op conversion.
- patterns.insert<ReturnToSPIRVConversion>(context);
+ patterns.insert<IntegerOpConversion<AddIOp, spirv::IAddOp>,
+ IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion,
+ ReturnToSPIRVConversion, StoreOpConversion>(context);
}
} // namespace mlir
}
}
+defm : BinaryOpPattern<AddFOp, SPV_FAddOp>;
defm : BinaryOpPattern<MulFOp, SPV_FMulOp>;
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD
static Type getElementPtrType(Type type, ArrayRef<Value *> indices,
Location baseLoc) {
- if (!indices.size()) {
+ if (indices.empty()) {
emitError(baseLoc, "'spv.AccessChain' op expected at least "
"one index ");
return nullptr;
return spirv::PointerType::get(resultType, resultStorageClass);
}
+void spirv::AccessChainOp::build(Builder *builder, OperationState *state,
+ Value *basePtr, ArrayRef<Value *> indices) {
+ auto type = getElementPtrType(basePtr->getType(), indices, state->location);
+ assert(type && "Unable to deduce return type based on basePtr and indices");
+ build(builder, state, type, basePtr, indices);
+}
+
static ParseResult parseAccessChainOp(OpAsmParser *parser,
OperationState *state) {
OpAsmParser::OperandType ptrInfo;
--- /dev/null
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv %s -o - | FileCheck %s
+
+func @builtin() {
+ %c0 = constant 1 : index
+ "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_workgroup_id_x} : (index, index, index, index, index, index) -> ()
+ return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId")
+func @builtin_workgroup_id_x()
+ attributes {gpu.kernel} {
+ // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPID]]
+ // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+ // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+ %0 = "gpu.block_id"() {dimension = "x"} : () -> index
+ return
+}
+
+// -----
+
+func @builtin() {
+ %c0 = constant 1 : index
+ "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_workgroup_id_y} : (index, index, index, index, index, index) -> ()
+ return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId")
+func @builtin_workgroup_id_y()
+ attributes {gpu.kernel} {
+ // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPID]]
+ // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+ // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}1 : i32{{\]}}
+ %0 = "gpu.block_id"() {dimension = "y"} : () -> index
+ return
+}
+
+// -----
+
+func @builtin() {
+ %c0 = constant 1 : index
+ "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_workgroup_id_z} : (index, index, index, index, index, index) -> ()
+ return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId")
+func @builtin_workgroup_id_z()
+ attributes {gpu.kernel} {
+ // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPID]]
+ // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+ // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}2 : i32{{\]}}
+ %0 = "gpu.block_id"() {dimension = "z"} : () -> index
+ return
+}
+
+// -----
+
+func @builtin() {
+ %c0 = constant 1 : index
+ "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_workgroup_size_x} : (index, index, index, index, index, index) -> ()
+ return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[WORKGROUPSIZE:@.*]] built_in("WorkgroupSize")
+func @builtin_workgroup_size_x()
+ attributes {gpu.kernel} {
+ // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPSIZE]]
+ // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+ // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+ %0 = "gpu.block_dim"() {dimension = "x"} : () -> index
+ return
+}
+
+// -----
+
+func @builtin() {
+ %c0 = constant 1 : index
+ "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_local_id_x} : (index, index, index, index, index, index) -> ()
+ return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[LOCALINVOCATIONID:@.*]] built_in("LocalInvocationId")
+func @builtin_local_id_x()
+ attributes {gpu.kernel} {
+ // CHECK: [[ADDRESS:%.*]] = spv._address_of [[LOCALINVOCATIONID]]
+ // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+ // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+ %0 = "gpu.thread_id"() {dimension = "x"} : () -> index
+ return
+}
+
+// -----
+
+func @builtin() {
+ %c0 = constant 1 : index
+ "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_num_workgroups_x} : (index, index, index, index, index, index) -> ()
+ return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[NUMWORKGROUPS:@.*]] built_in("NumWorkgroups")
+func @builtin_num_workgroups_x()
+ attributes {gpu.kernel} {
+ // CHECK: [[ADDRESS:%.*]] = spv._address_of [[NUMWORKGROUPS]]
+ // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+ // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+ %0 = "gpu.grid_dim"() {dimension = "x"} : () -> index
+ return
+}
--- /dev/null
+// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
+
+func @load_store(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) {
+ %c0 = constant 0 : index
+ %c12 = constant 12 : index
+ %0 = subi %c12, %c0 : index
+ %c1 = constant 1 : index
+ %c0_0 = constant 0 : index
+ %c4 = constant 4 : index
+ %1 = subi %c4, %c0_0 : index
+ %c1_1 = constant 1 : index
+ %c1_2 = constant 1 : index
+ "gpu.launch_func"(%0, %c1_2, %c1_2, %1, %c1_2, %c1_2, %arg0, %arg1, %arg2, %c0, %c0_0, %c1, %c1_1) {kernel = @load_store_kernel} : (index, index, index, index, index, index, memref<12x4xf32>, memref<12x4xf32>, memref<12x4xf32>, index, index, index, index) -> ()
+ return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable {{@.*}} bind(0, 0) : [[TYPE1:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
+// CHECK-NEXT: spv.globalVariable {{@.*}} bind(0, 1) : [[TYPE2:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
+// CHECK-NEXT: spv.globalVariable {{@.*}} bind(0, 2) : [[TYPE3:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
+// CHECK: func @load_store_kernel([[ARG0:%.*]]: [[TYPE1]], [[ARG1:%.*]]: [[TYPE2]], [[ARG2:%.*]]: [[TYPE3]], [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32, [[ARG5:%.*]]: i32, [[ARG6:%.*]]: i32)
+func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index)
+ attributes {gpu.kernel} {
+ %0 = "gpu.block_id"() {dimension = "x"} : () -> index
+ %1 = "gpu.block_id"() {dimension = "y"} : () -> index
+ %2 = "gpu.block_id"() {dimension = "z"} : () -> index
+ %3 = "gpu.thread_id"() {dimension = "x"} : () -> index
+ %4 = "gpu.thread_id"() {dimension = "y"} : () -> index
+ %5 = "gpu.thread_id"() {dimension = "z"} : () -> index
+ %6 = "gpu.grid_dim"() {dimension = "x"} : () -> index
+ %7 = "gpu.grid_dim"() {dimension = "y"} : () -> index
+ %8 = "gpu.grid_dim"() {dimension = "z"} : () -> index
+ %9 = "gpu.block_dim"() {dimension = "x"} : () -> index
+ %10 = "gpu.block_dim"() {dimension = "y"} : () -> index
+ %11 = "gpu.block_dim"() {dimension = "z"} : () -> index
+ // CHECK: [[INDEX1:%.*]] = spv.IAdd [[ARG3]], {{%.*}}
+ %12 = addi %arg3, %0 : index
+ // CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], {{%.*}}
+ %13 = addi %arg4, %3 : index
+ // CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+ // CHECK-NEXT: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[PTR1]]
+ %14 = load %arg0[%12, %13] : memref<12x4xf32>
+ // CHECK: [[PTR2:%.*]] = spv.AccessChain [[ARG1]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+ // CHECK-NEXT: [[VAL2:%.*]] = spv.Load "StorageBuffer" [[PTR2]]
+ %15 = load %arg1[%12, %13] : memref<12x4xf32>
+ // CHECK: [[VAL3:%.*]] = spv.FAdd [[VAL1]], [[VAL2]]
+ %16 = addf %14, %15 : f32
+ // CHECK: [[PTR3:%.*]] = spv.AccessChain [[ARG2]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+ // CHECK-NEXT: spv.Store "StorageBuffer" [[PTR3]], [[VAL3]]
+ store %16, %arg2[%12, %13] : memref<12x4xf32>
+ return
+}
\ No newline at end of file