public:
using Size = uint64_t;
- /// Returns a new type with layout info. Assigns the type size in bytes to the
- /// `size`. Assigns the type alignment in bytes to the `alignment`.
- static Type decorateType(spirv::StructType structType, Size &size,
- Size &alignment);
+ /// Returns a new StructType with layout info. Assigns the type size in bytes
+ /// to the `size`. Assigns the type alignment in bytes to the `alignment`.
+ static spirv::StructType decorateType(spirv::StructType structType,
+ Size &size, Size &alignment);
/// Checks whether a type is legal in terms of Vulkan layout info
/// decoration. A type is dynamically illegal if it's a composite type in the
/// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant Storage
namespace mlir {
namespace spirv {
-// Creates a module pass that converts composite types used by objects in the
-// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage
-// classes with layout information.
-//
-// Right now this pass only supports Vulkan layout rules.
+class ModuleOp;
+/// Creates a module pass that converts composite types used by objects in the
+/// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage
+/// classes with layout information.
+/// Right now this pass only supports Vulkan layout rules.
std::unique_ptr<OpPassBase<mlir::ModuleOp>>
createDecorateSPIRVCompositeTypeLayoutPass();
+/// Creates a module pass that lowers the ABI attributes specified during SPIR-V
+/// Lowering. Specifically,
+/// 1) Creates the global variables for arguments of entry point function using
+/// the specification in the ABI attributes for each argument.
+/// 2) Inserts the EntryPointOp and the ExecutionModeOp for entry point
+/// functions using the specification in the EntryPointAttr.
+std::unique_ptr<OpPassBase<spirv::ModuleOp>> createLowerABIAttributesPass();
+
} // namespace spirv
} // namespace mlir
#define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/Support/StringExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SetVector.h"
namespace mlir {
-/// Type conversion from Standard Types to SPIR-V Types.
-class SPIRVBasicTypeConverter : public TypeConverter {
-public:
- /// Converts types to SPIR-V supported types.
- virtual Type convertType(Type t);
-};
-
/// Converts a function type according to the requirements of a SPIR-V entry
/// function. The arguments need to be converted to spv.GlobalVariables of
/// spv.ptr types so that they could be bound by the runtime.
class SPIRVTypeConverter final : public TypeConverter {
public:
- explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
- : basicTypeConverter(basicTypeConverter) {}
+ using TypeConverter::TypeConverter;
/// Converts types to SPIR-V types using the basic type converter.
Type convertType(Type t) override;
-
- /// Gets the basic type converter.
- Type convertBasicType(Type t) { return basicTypeConverter->convertType(t); }
-
-private:
- SPIRVBasicTypeConverter *basicTypeConverter;
};
/// Base class to define a conversion pattern to translate Ops into SPIR-V.
private:
};
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h.inc"
+
namespace spirv {
/// Returns a value that represents a builtin variable value within the SPIR-V
/// module.
OpBuilder &builder);
/// Legalizes a function as an entry function.
-LogicalResult lowerAsEntryFunction(FuncOp funcOp,
- SPIRVTypeConverter *typeConverter,
- ConversionPatternRewriter &rewriter,
- FuncOp &newFuncOp);
-
-/// Finalizes entry function legalization. Inserts the spv.EntryPoint and
-/// spv.ExecutionMode ops.
-LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder);
+FuncOp lowerAsEntryFunction(FuncOp funcOp, SPIRVTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter,
+ ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo,
+ spirv::EntryPointABIAttr entryPointInfo);
+
+/// Attribute name for specifying argument ABI information.
+StringRef getInterfaceVarABIAttrName();
+
+/// Get the InterfaceVarABIAttr given its fields.
+InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet,
+ unsigned binding,
+ spirv::StorageClass storageClass,
+ MLIRContext *context);
+
+/// Attribute name for specifying entry point information.
+StringRef getEntryPointABIAttrName();
+
+/// Get the EntryPointABIAttr given its fields.
+EntryPointABIAttr getEntryPointABIAttr(ArrayRef<int32_t> localSize,
+ MLIRContext *context);
} // namespace spirv
} // namespace mlir
--- /dev/null
+//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the base file for supporting lowering to SPIR-V dialect. This
+// file defines SPIR-V attributes used for specifying the shader
+// interface or ABI. This is because SPIR-V module is expected to work in
+// an execution environment as specified by a client API. A SPIR-V module
+// needs to "link" correctly with the execution environment regarding the
+// resources that are used in the SPIR-V module and get populated with
+// data via the client API. The shader interface (or ABI) is passed into
+// SPIR-V lowering path via attributes defined in this file. A
+// compilation flow targeting SPIR-V is expected to attach such
+// attributes to resources and other suitable places.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_LOWERING
+#define SPIRV_LOWERING
+
+include "mlir/Dialect/SPIRV/SPIRVBase.td"
+
+// For arguments that eventually map to spv.globalVariable for the
+// shader interface, this attribute specifies the information regarding
+// the global variable :
+// 1) Descriptor Set.
+// 2) Binding number.
+// 3) Storage class.
+def SPV_InterfaceVarABIAttr:
+ StructAttr<"InterfaceVarABIAttr", SPV_Dialect,
+ [StructFieldAttr<"descriptor_set", I32Attr>,
+ StructFieldAttr<"binding", I32Attr>,
+ StructFieldAttr<"storage_class", SPV_StorageClassAttr>]>;
+
+// For entry functions, this attribute specifies information related to entry
+// points in the generated SPIR-V module:
+// 1) WorkGroup Size.
+def SPV_EntryPointABIAttr:
+ StructAttr<"EntryPointABIAttr", SPV_Dialect,
+ [StructFieldAttr<"local_size", I32ElementsAttr>]>;
+
+#endif // SPIRV_LOWERING
"TypeAttr type, ArrayRef<NamedAttribute> namedAttrs", [{
state.addAttribute("type", type);
state.addAttributes(namedAttrs);
- }]>
+ }]>,
+ OpBuilder<[{Builder *builder, OperationState &state,
+ Type type, StringRef name, unsigned descriptorSet,
+ unsigned binding}]>
];
let results = (outs);
CPred<"$_self.cast<ArrayAttr>().size() >= " # n>,
"with at least " # n # " elements">;
+class ArrayCount<int n> : AttrConstraint<
+ CPred<"$_self.cast<ArrayAttr>().size() == " #n>,
+ "with exactly " # n # " elements">;
+
class IntArrayNthElemEq<int index, int value> : AttrConstraint<
And<[
CPred<"$_self.cast<ArrayAttr>().size() > " # index>,
PatternMatchResult
KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
- FuncOp newFuncOp;
if (!gpu::GPUDialect::isKernel(funcOp)) {
return matchFailure();
}
- if (failed(spirv::lowerAsEntryFunction(funcOp, &typeConverter, rewriter,
- newFuncOp))) {
+ SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
+ for (auto argNum : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+ argABI.push_back(spirv::getInterfaceVarABIAttr(
+ 0, argNum, spirv::StorageClass::StorageBuffer, rewriter.getContext()));
+ }
+ // TODO(ravishankarm) : For now set this to {32, 1, 1}. This is incorrect. The
+ // actual workgroup size needs to be plumbed through.
+ auto context = rewriter.getContext();
+ auto entryPointAttr = spirv::getEntryPointABIAttr({32, 1, 1}, context);
+ FuncOp newFuncOp = spirv::lowerAsEntryFunction(
+ funcOp, typeConverter, rewriter, argABI, entryPointAttr);
+ if (!newFuncOp) {
return matchFailure();
}
+ newFuncOp.removeAttr(Identifier::get(gpu::GPUDialect::getKernelFuncAttrName(),
+ rewriter.getContext()));
return matchSuccess();
}
if (!gpu::GPUDialect::isKernel(funcOp)) {
return;
}
- OpBuilder builder(module.getBodyRegion());
+ OpBuilder builder(funcOp.getOperation());
// Create a new spirv::ModuleOp for this function, and clone the
// function into it.
// TODO : Generalize this to account for different extensions,
});
/// Dialect conversion to lower the functions with the spirv::ModuleOps.
- SPIRVBasicTypeConverter basicTypeConverter;
- SPIRVTypeConverter typeConverter(&basicTypeConverter);
+ SPIRVTypeConverter typeConverter;
OwningRewritePatternList patterns;
populateGPUToSPIRVPatterns(context, typeConverter, patterns);
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
ConversionTarget target(*context);
target.addLegalDialect<spirv::SPIRVDialect>();
- target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
- // TODO(ravishankarm) : Currently lowering does not support handling
- // function conversion of non-kernel functions. This is to be added.
-
- // For kernel functions, verify that the signature is void(void).
- return gpu::GPUDialect::isKernel(op) && op.getNumResults() == 0 &&
- op.getNumArguments() == 0;
- });
+ target.addDynamicallyLegalOp<FuncOp>(
+ [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
if (failed(applyFullConversion(spirvModules, target, patterns,
&typeConverter))) {
return signalPassFailure();
}
-
- // After the SPIR-V modules have been generated, some finalization is needed
- // for the entry functions. For example, adding spv.EntryPoint op,
- // spv.ExecutionMode op, etc.
- for (auto *spvModule : spirvModules) {
- for (auto op :
- cast<spirv::ModuleOp>(spvModule).getBlock().getOps<FuncOp>()) {
- if (gpu::GPUDialect::isKernel(op)) {
- OpBuilder builder(op.getContext());
- builder.setInsertionPointAfter(op);
- if (failed(spirv::finalizeEntryFunction(op, builder))) {
- return signalPassFailure();
- }
- op.getOperation()->removeAttr(Identifier::get(
- gpu::GPUDialect::getKernelFuncAttrName(), op.getContext()));
- }
- }
- }
}
OpPassBase<ModuleOp> *createConvertGPUToSPIRVPass() {
return matchFailure();
}
auto spirvConstType =
- typeConverter.convertBasicType(constIndexOp.getResult()->getType());
+ typeConverter.convertType(constIndexOp.getResult()->getType());
auto spirvConstVal =
rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
matchAndRewrite(StdOp operation, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto resultType =
- this->typeConverter.convertBasicType(operation.getResult()->getType());
+ this->typeConverter.convertType(operation.getResult()->getType());
rewriter.template replaceOpWithNewOp<SPIRVOp>(
operation, resultType, operands, ArrayRef<NamedAttribute>());
return this->matchSuccess();
OwningRewritePatternList patterns;
auto module = getModule();
- SPIRVBasicTypeConverter basicTypeConverter;
- SPIRVTypeConverter typeConverter(&basicTypeConverter);
+ SPIRVTypeConverter typeConverter;
populateStandardToSPIRVPatterns(module.getContext(), typeConverter, patterns);
ConversionTarget target(*(module.getContext()));
target.addLegalDialect<spirv::SPIRVDialect>();
using namespace mlir;
-Type VulkanLayoutUtils::decorateType(spirv::StructType structType,
- VulkanLayoutUtils::Size &size,
- VulkanLayoutUtils::Size &alignment) {
+spirv::StructType
+VulkanLayoutUtils::decorateType(spirv::StructType structType,
+ VulkanLayoutUtils::Size &size,
+ VulkanLayoutUtils::Size &alignment) {
if (structType.getNumElements() == 0) {
return structType;
}
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
-
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "llvm/ADT/Sequence.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
+// Attributes for ABI
+//===----------------------------------------------------------------------===//
+
+// Pull in the attributes needed for lowering.
+namespace mlir {
+#include "mlir/Dialect/SPIRV/SPIRVLowering.cpp.inc"
+}
+
+StringRef mlir::spirv::getInterfaceVarABIAttrName() {
+ return "spirv.interface_var_abi";
+}
+
+mlir::spirv::InterfaceVarABIAttr
+mlir::spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
+ spirv::StorageClass storageClass,
+ MLIRContext *context) {
+ Type i32Type = IntegerType::get(32, context);
+ return mlir::spirv::InterfaceVarABIAttr::get(
+ IntegerAttr::get(i32Type, descriptorSet),
+ IntegerAttr::get(i32Type, binding),
+ IntegerAttr::get(i32Type, static_cast<int64_t>(storageClass)), context);
+}
+
+StringRef mlir::spirv::getEntryPointABIAttrName() {
+ return "spirv.entry_point_abi";
+}
+
+mlir::spirv::EntryPointABIAttr
+mlir::spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize,
+ MLIRContext *context) {
+ assert(localSize.size() == 3);
+ return mlir::spirv::EntryPointABIAttr::get(
+ DenseElementsAttr::get<int32_t>(
+ VectorType::get(3, IntegerType::get(32, context)), localSize)
+ .cast<DenseIntElementsAttr>(),
+ context);
+}
+
+//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
return IntegerType::get(32, context);
}
-Type convertIndexType(IndexType t) { return convertIndexType(t.getContext()); }
-
-Type basicTypeConversion(Type t) {
+Type typeConversionImpl(Type t) {
// Check if the type is SPIR-V supported. If so return the type.
if (spirv::SPIRVDialect::isValidType(t)) {
return t;
}
if (auto indexType = t.dyn_cast<IndexType>()) {
- return convertIndexType(indexType);
+ return convertIndexType(t.getContext());
}
if (auto memRefType = t.dyn_cast<MemRefType>()) {
auto elementType = memRefType.getElementType();
- if (memRefType.hasStaticShape()) {
+ // TODO(ravishankarm) : Handle dynamic shapes and memref with strides.
+ if (memRefType.hasStaticShape() && memRefType.getAffineMaps().empty()) {
// Convert to a multi-dimensional spv.array if size is known.
for (auto size : reverse(memRefType.getShape())) {
elementType = spirv::ArrayType::get(elementType, size);
}
+ // For now initialize the storage class to StorageBuffer. This will be
+ // updated later based on whats passed in w.r.t to the ABI attributes.
return spirv::PointerType::get(elementType,
spirv::StorageClass::StorageBuffer);
- } else {
- // Vulkan SPIR-V validation rules require runtime array type to be the
- // last member of a struct.
- return spirv::PointerType::get(spirv::RuntimeArrayType::get(elementType),
- spirv::StorageClass::StorageBuffer);
}
}
return Type();
}
-
-Type getLayoutDecoratedType(spirv::StructType type) {
- VulkanLayoutUtils::Size size = 0, alignment = 0;
- return VulkanLayoutUtils::decorateType(type, size, alignment);
-}
-
-/// Generates the type of variable given the type of object.
-static Type getGlobalVarTypeForEntryFnArg(Type t) {
- auto convertedType = basicTypeConversion(t);
- if (auto ptrType = convertedType.dyn_cast<spirv::PointerType>()) {
- if (!ptrType.getPointeeType().isa<spirv::StructType>()) {
- return spirv::PointerType::get(
- getLayoutDecoratedType(
- spirv::StructType::get(ptrType.getPointeeType())),
- ptrType.getStorageClass());
- }
- } else {
- return spirv::PointerType::get(
- getLayoutDecoratedType(spirv::StructType::get(convertedType)),
- spirv::StorageClass::StorageBuffer);
- }
- return convertedType;
-}
} // namespace
-Type SPIRVBasicTypeConverter::convertType(Type t) {
- return basicTypeConversion(t);
-}
-
-Type SPIRVTypeConverter::convertType(Type t) {
- return getGlobalVarTypeForEntryFnArg(t);
-}
+Type SPIRVTypeConverter::convertType(Type t) { return typeConversionImpl(t); }
//===----------------------------------------------------------------------===//
// Builtin Variables
// Entry Function signature Conversion
//===----------------------------------------------------------------------===//
-namespace {
-/// Computes the replacement value for an argument of an entry function. It
-/// allocates a global variable for this argument and adds statements in the
-/// entry block to get a replacement value within function scope.
-Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter,
- size_t origArgNum, Value *origArg) {
- // Create a global variable for this argument.
- auto insertionOp = rewriter.getInsertionBlock()->getParent();
- auto module = insertionOp->getParentOfType<spirv::ModuleOp>();
- if (!module) {
- return nullptr;
- }
- auto funcOp = insertionOp->getParentOfType<FuncOp>();
- spirv::GlobalVariableOp var;
- {
- OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
- rewriter.setInsertionPoint(funcOp.getOperation());
- std::string varName =
- funcOp.getName().str() + "_arg_" + std::to_string(origArgNum);
- var = rewriter.create<spirv::GlobalVariableOp>(
- funcOp.getLoc(),
- TypeAttr::get(getGlobalVarTypeForEntryFnArg(origArg->getType())),
- rewriter.getStringAttr(varName), nullptr);
- var.setAttr(
- spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
- rewriter.getI32IntegerAttr(0));
- var.setAttr(
- spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
- rewriter.getI32IntegerAttr(origArgNum));
- }
- // Insert the addressOf and load instructions, to get back the converted value
- // type.
- auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
- auto indexType = convertIndexType(funcOp.getContext());
- auto zero = rewriter.create<spirv::ConstantOp>(
- funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
- auto accessChain = rewriter.create<spirv::AccessChainOp>(
- funcOp.getLoc(), addressOf.pointer(), zero.constant());
- // If the original argument is a tensor/memref type, the value is not
- // loaded. Instead the pointer value is returned to allow its use in access
- // chain ops.
- auto origArgType = origArg->getType();
- if (origArgType.isa<MemRefType>()) {
- return accessChain;
- }
- return rewriter.create<spirv::LoadOp>(
- funcOp.getLoc(), accessChain.component_ptr(), /*memory_access=*/nullptr,
- /*alignment=*/nullptr);
-}
-
-FuncOp applySignatureConversion(
- FuncOp funcOp, ConversionPatternRewriter &rewriter,
- TypeConverter::SignatureConversion &signatureConverter) {
- // Create a new function with an updated signature.
- auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
- llvm::None, funcOp.getContext()));
-
- // Tell the rewriter to convert the region signature.
- rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
- rewriter.replaceOp(funcOp.getOperation(), llvm::None);
- return newFuncOp;
-}
-
-/// Gets the global variables that need to be specified as interface variable
-/// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
-LogicalResult getInterfaceVariables(FuncOp funcOp,
- SmallVectorImpl<Attribute> &interfaceVars) {
- auto module = funcOp.getParentOfType<spirv::ModuleOp>();
- if (!module) {
- return failure();
- }
- llvm::SetVector<Operation *> interfaceVarSet;
- for (auto &block : funcOp) {
- // TODO(ravishankarm) : This should in reality traverse the entry function
- // call graph and collect all the interfaces. For now, just traverse the
- // instructions in this function.
- for (auto op : block.getOps<spirv::AddressOfOp>()) {
- auto var = module.lookupSymbol<spirv::GlobalVariableOp>(op.variable());
- if (var.type().cast<spirv::PointerType>().getStorageClass() ==
- spirv::StorageClass::StorageBuffer) {
- continue;
- }
- interfaceVarSet.insert(var.getOperation());
- }
- }
- for (auto &var : interfaceVarSet) {
- interfaceVars.push_back(SymbolRefAttr::get(
- cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
- }
- return success();
-}
-} // namespace
-
-LogicalResult mlir::spirv::lowerAsEntryFunction(
- FuncOp funcOp, SPIRVTypeConverter *typeConverter,
- ConversionPatternRewriter &rewriter, FuncOp &newFuncOp) {
+FuncOp mlir::spirv::lowerAsEntryFunction(
+ FuncOp funcOp, SPIRVTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter,
+ ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo,
+ spirv::EntryPointABIAttr entryPointInfo) {
auto fnType = funcOp.getType();
if (fnType.getNumResults()) {
- return funcOp.emitError("SPIR-V lowering only supports functions with no "
- "return values right now");
+ funcOp.emitError("SPIR-V lowering only supports entry functions"
+ "with no return values right now");
+ return nullptr;
+ }
+ if (fnType.getNumInputs() != argABIInfo.size()) {
+ funcOp.emitError(
+ "lowering as entry functions requires ABI info for all arguments");
+ return nullptr;
}
// For entry functions need to make the signature void(void). Compute the
// replacement value for all arguments and replace all uses.
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
{
- OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
- rewriter.setInsertionPointToStart(&funcOp.front());
- for (auto origArg : enumerate(funcOp.getArguments())) {
- auto replacement = createAndLoadGlobalVarForEntryFnArg(
- rewriter, origArg.index(), origArg.value());
- signatureConverter.remapInput(origArg.index(), replacement);
+ for (auto argType : enumerate(funcOp.getType().getInputs())) {
+ auto convertedType = typeConverter.convertType(argType.value());
+ signatureConverter.addInputs(argType.index(), convertedType);
}
}
- newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
- return success();
-}
+ auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ newFuncOp.setType(rewriter.getFunctionType(
+ signatureConverter.getConvertedTypes(), llvm::None));
+ rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+ rewriter.replaceOp(funcOp.getOperation(), llvm::None);
-LogicalResult mlir::spirv::finalizeEntryFunction(FuncOp newFuncOp,
- OpBuilder &builder) {
- // Add the spv.EntryPointOp after collecting all the interface variables
- // needed.
- SmallVector<Attribute, 1> interfaceVars;
- if (failed(getInterfaceVariables(newFuncOp, interfaceVars))) {
- return failure();
+ // Set the attributes for argument and the function.
+ StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
+ for (auto argIndex : llvm::seq<unsigned>(0, newFuncOp.getNumArguments())) {
+ newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
}
- builder.create<spirv::EntryPointOp>(newFuncOp.getLoc(),
- spirv::ExecutionModel::GLCompute,
- newFuncOp, interfaceVars);
- // Specify the spv.ExecutionModeOp.
-
- /// TODO(ravishankarm): Vulkan environment for SPIR-V requires "either a
- /// LocalSize execution mode or an object decorated with the WorkgroupSize
- /// decoration must be specified." Better approach is to use the
- /// WorkgroupSize GlobalVariable with initializer being a specialization
- /// constant. But current support for specialization constant does not allow
- /// for this. So for now use the execution mode. Hard-wiring this to {1, 1,
- /// 1} for now. To be fixed ASAP.
- builder.create<spirv::ExecutionModeOp>(newFuncOp.getLoc(), newFuncOp,
- spirv::ExecutionMode::LocalSize,
- ArrayRef<int32_t>{1, 1, 1});
- return success();
+ newFuncOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
+ return newFuncOp;
}
// spv.globalVariable
//===----------------------------------------------------------------------===//
+void spirv::GlobalVariableOp::build(Builder *builder, OperationState &state,
+ Type type, StringRef name,
+ unsigned descriptorSet, unsigned binding) {
+ build(builder, state, TypeAttr::get(type), builder->getStringAttr(name),
+ nullptr);
+ state.addAttribute(
+ spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
+ builder->getI32IntegerAttr(descriptorSet));
+ state.addAttribute(
+ spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
+ builder->getI32IntegerAttr(binding));
+}
+
static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
OperationState &state) {
// Parse variable name.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Identifier.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
add_llvm_library(MLIRSPIRVTransforms
DecorateSPIRVCompositeTypeLayoutPass.cpp
+ LowerABIAttributesPass.cpp
)
target_link_libraries(MLIRSPIRVTransforms
--- /dev/null
+//===- LowerABIAttributesPass.cpp - Decorate composite type ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to lower attributes that specify the shader ABI
+// for the functions in the generated SPIR-V module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/LayoutUtils.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+/// Checks if the `type` is a scalar or vector type. It is assumed that they are
+/// valid for SPIR-V dialect already.
+static bool isScalarOrVectorType(Type type) {
+ return spirv::SPIRVDialect::isValidScalarType(type) || type.isa<VectorType>();
+}
+
+/// Creates a global variable for an argument based on the ABI info.
+static spirv::GlobalVariableOp
+createGlobalVariableForArg(FuncOp funcOp, OpBuilder &builder, unsigned argNum,
+ spirv::InterfaceVarABIAttr abiInfo) {
+ auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>();
+ if (!spirvModule) {
+ return nullptr;
+ }
+ OpBuilder::InsertionGuard moduleInsertionGuard(builder);
+ builder.setInsertionPoint(funcOp.getOperation());
+ std::string varName =
+ funcOp.getName().str() + "_arg_" + std::to_string(argNum);
+
+ // Get the type of variable. If this is a scalar/vector type and has an ABI
+ // info create a variable of type !spv.ptr<!spv.struct<elementTYpe>>. If not
+ // it must already be a !spv.ptr<!spv.struct<...>>.
+ auto varType = funcOp.getType().getInput(argNum);
+ auto storageClass =
+ static_cast<spirv::StorageClass>(abiInfo.storage_class().getInt());
+ if (isScalarOrVectorType(varType)) {
+ varType =
+ spirv::PointerType::get(spirv::StructType::get(varType), storageClass);
+ } else {
+ auto varPtrType = varType.cast<spirv::PointerType>();
+ varType = spirv::PointerType::get(
+ spirv::StructType::get(varPtrType.getPointeeType()), storageClass);
+ }
+ auto varPtrType = varType.cast<spirv::PointerType>();
+ auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
+
+ // Set the offset information.
+ VulkanLayoutUtils::Size size = 0, alignment = 0;
+ varPointeeType =
+ VulkanLayoutUtils::decorateType(varPointeeType, size, alignment)
+ .cast<spirv::StructType>();
+ varType =
+ spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
+
+ return builder.create<spirv::GlobalVariableOp>(
+ funcOp.getLoc(), varType, varName, abiInfo.descriptor_set().getInt(),
+ abiInfo.binding().getInt());
+}
+
+/// Gets the global variables that need to be specified as interface variable
+/// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
+static LogicalResult
+getInterfaceVariables(FuncOp funcOp,
+ SmallVectorImpl<Attribute> &interfaceVars) {
+ auto module = funcOp.getParentOfType<spirv::ModuleOp>();
+ if (!module) {
+ return failure();
+ }
+ llvm::SetVector<Operation *> interfaceVarSet;
+
+ // TODO(ravishankarm) : This should in reality traverse the entry function
+ // call graph and collect all the interfaces. For now, just traverse the
+ // instructions in this function.
+ funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
+ auto var =
+ module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
+ if (var.type().cast<spirv::PointerType>().getStorageClass() !=
+ spirv::StorageClass::StorageBuffer) {
+ interfaceVarSet.insert(var.getOperation());
+ }
+ });
+ for (auto &var : interfaceVarSet) {
+ interfaceVars.push_back(SymbolRefAttr::get(
+ cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
+ }
+ return success();
+}
+
+/// Lowers the entry point attribute.
+static LogicalResult lowerEntryPointABIAttr(FuncOp funcOp, OpBuilder &builder) {
+ auto entryPointAttrName = spirv::getEntryPointABIAttrName();
+ auto entryPointAttr =
+ funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
+ if (!entryPointAttr) {
+ return failure();
+ }
+
+ OpBuilder::InsertionGuard moduleInsertionGuard(builder);
+ auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>();
+ builder.setInsertionPoint(spirvModule.body().front().getTerminator());
+
+ // Adds the spv.EntryPointOp after collecting all the interface variables
+ // needed.
+ SmallVector<Attribute, 1> interfaceVars;
+ if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
+ return failure();
+ }
+ builder.create<spirv::EntryPointOp>(
+ funcOp.getLoc(), spirv::ExecutionModel::GLCompute, funcOp, interfaceVars);
+ // Specifies the spv.ExecutionModeOp.
+ auto localSizeAttr = entryPointAttr.local_size();
+ SmallVector<int32_t, 3> localSize(localSizeAttr.getValues<int32_t>());
+ builder.create<spirv::ExecutionModeOp>(
+ funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize);
+ funcOp.removeAttr(entryPointAttrName);
+ return success();
+}
+
+namespace {
+/// Pattern rewriter for changing function signature to match the ABI specified
+/// in attributes.
+class FuncOpLowering final : public SPIRVOpLowering<FuncOp> {
+public:
+ using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
+ PatternMatchResult
+ matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pass to implement the ABI information specified as attributes.
+class LowerABIAttributesPass final
+ : public OperationPass<LowerABIAttributesPass, spirv::ModuleOp> {
+private:
+ void runOnOperation() override;
+};
+} // namespace
+
+PatternMatchResult
+FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (!funcOp.getAttrOfType<spirv::EntryPointABIAttr>(
+ spirv::getEntryPointABIAttrName())) {
+ // TODO(ravishankarm) : Non-entry point functions are not handled.
+ return matchFailure();
+ }
+ TypeConverter::SignatureConversion signatureConverter(
+ funcOp.getType().getNumInputs());
+
+ auto attrName = spirv::getInterfaceVarABIAttrName();
+ for (auto argType : llvm::enumerate(funcOp.getType().getInputs())) {
+ auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
+ argType.index(), attrName);
+ if (!abiInfo) {
+ // TODO(ravishankarm) : For non-entry point functions, it should be legal
+ // to pass around scalar/vector values and return a scalar/vector. For now
+ // non-entry point functions are not handled in this ABI lowering and will
+ // produce an error.
+ return matchFailure();
+ }
+ auto var =
+ createGlobalVariableForArg(funcOp, rewriter, argType.index(), abiInfo);
+
+ OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
+ rewriter.setInsertionPointToStart(&funcOp.front());
+ // Inserts spirv::AddressOf and spirv::AccessChain operations.
+ auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
+ auto indexType =
+ typeConverter.convertType(IndexType::get(funcOp.getContext()));
+ auto zero = rewriter.create<spirv::ConstantOp>(
+ funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
+ Value *replacement = rewriter.create<spirv::AccessChainOp>(
+ funcOp.getLoc(), addressOf.pointer(), zero.constant());
+ // Check if the arg is a scalar or vector type. In that case, the value
+ // needs to be loaded into registers.
+ // TODO(ravishankarm) : This is loading value of the scalar into registers
+ // at the start of the function. It is probably better to do the load just
+ // before the use. There might be multiple loads and currently there is no
+ // easy way to replace all uses with a sequence of operations.
+ if (isScalarOrVectorType(argType.value())) {
+ replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), replacement,
+ /*memory_access=*/nullptr,
+ /*alignment=*/nullptr);
+ }
+ signatureConverter.remapInput(argType.index(), replacement);
+ }
+
+ // Creates a new function with the update signature.
+ auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ newFuncOp.setType(rewriter.getFunctionType(
+ signatureConverter.getConvertedTypes(), llvm::None));
+ rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+ rewriter.eraseOp(funcOp.getOperation());
+ return matchSuccess();
+}
+
+void LowerABIAttributesPass::runOnOperation() {
+ // Uses the signature conversion methodology of the dialect conversion
+ // framework to implement the conversion.
+ spirv::ModuleOp module = getOperation();
+ MLIRContext *context = &getContext();
+
+ SPIRVTypeConverter typeConverter;
+ OwningRewritePatternList patterns;
+ patterns.insert<FuncOpLowering>(context, typeConverter);
+
+ ConversionTarget target(*context);
+ target.addLegalDialect<spirv::SPIRVDialect>();
+ auto entryPointAttrName = spirv::getEntryPointABIAttrName();
+ target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+ return op.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName) &&
+ op.getNumResults() == 0 && op.getNumArguments() == 0;
+ });
+ target.addLegalOp<ReturnOp>();
+ if (failed(
+ applyPartialConversion(module, target, patterns, &typeConverter))) {
+ return signalPassFailure();
+ }
+
+ // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
+ // attributes.
+ OpBuilder builder(context);
+ SmallVector<FuncOp, 1> entryPointFns;
+ module.walk([&](FuncOp funcOp) {
+ if (funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
+ entryPointFns.push_back(funcOp);
+ }
+ });
+ for (auto fn : entryPointFns) {
+ if (failed(lowerEntryPointABIAttr(fn, builder))) {
+ return signalPassFailure();
+ }
+ }
+}
+
+std::unique_ptr<OpPassBase<spirv::ModuleOp>>
+mlir::spirv::createLowerABIAttributesPass() {
+ return std::make_unique<LowerABIAttributesPass>();
+}
+
+static PassRegistration<LowerABIAttributesPass>
+ pass("spirv-lower-abi-attrs", "Lower SPIR-V ABI Attributes");
// Handle the case of a 1->0 value mapping.
if (!argInfo) {
+ // If a replacement value was given for this argument, use that to
+ // replace all uses.
+ auto argReplacementValue = mapping.lookupOrDefault(origArg);
+ if (argReplacementValue != origArg) {
+ origArg->replaceAllUsesWith(argReplacementValue);
+ continue;
+ }
// If there are any dangling uses then replace the argument with one
// generated by the type converter. This is necessary as the cast must
// persist in the IR after conversion.
// CHECK-DAG: spv.globalVariable [[NUMWORKGROUPSVAR:@.*]] built_in("NumWorkgroups") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
- // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
- // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
- // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
- // CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
- // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
- // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
- // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
- // CHECK: func [[FN:@.*]]()
+ // CHECK-LABEL: func @load_store_kernel
+ // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer> {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer> {spirv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG3:%.*]]: i32 {spirv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG4:%.*]]: i32 {spirv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG5:%.*]]: i32 {spirv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG6:%.*]]: i32 {spirv.interface_var_abi = {binding = 6 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index)
attributes {gpu.kernel} {
- // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
- // CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
- // CHECK: [[ARG0:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
- // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
- // CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
- // CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
- // CHECK: [[ADDRESSARG2:%.*]] = spv._address_of [[VAR2]]
- // CHECK: [[CONST2:%.*]] = spv.constant 0 : i32
- // CHECK: [[ARG2:%.*]] = spv.AccessChain [[ADDRESSARG2]]{{\[}}[[CONST2]]
- // CHECK: [[ADDRESSARG3:%.*]] = spv._address_of [[VAR3]]
- // CHECK: [[CONST3:%.*]] = spv.constant 0 : i32
- // CHECK: [[ARG3PTR:%.*]] = spv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
- // CHECK: [[ARG3:%.*]] = spv.Load "StorageBuffer" [[ARG3PTR]]
- // CHECK: [[ADDRESSARG4:%.*]] = spv._address_of [[VAR4]]
- // CHECK: [[CONST4:%.*]] = spv.constant 0 : i32
- // CHECK: [[ARG4PTR:%.*]] = spv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
- // CHECK: [[ARG4:%.*]] = spv.Load "StorageBuffer" [[ARG4PTR]]
- // CHECK: [[ADDRESSARG5:%.*]] = spv._address_of [[VAR5]]
- // CHECK: [[CONST5:%.*]] = spv.constant 0 : i32
- // CHECK: [[ARG5PTR:%.*]] = spv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
- // CHECK: [[ARG5:%.*]] = spv.Load "StorageBuffer" [[ARG5PTR]]
- // CHECK: [[ADDRESSARG6:%.*]] = spv._address_of [[VAR6]]
- // CHECK: [[CONST6:%.*]] = spv.constant 0 : i32
- // CHECK: [[ARG6PTR:%.*]] = spv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
- // CHECK: [[ARG6:%.*]] = spv.Load "StorageBuffer" [[ARG6PTR]]
// CHECK: [[ADDRESSWORKGROUPID:%.*]] = spv._address_of [[WORKGROUPIDVAR]]
// CHECK: [[WORKGROUPID:%.*]] = spv.Load "Input" [[ADDRESSWORKGROUPID]]
// CHECK: [[WORKGROUPIDX:%.*]] = spv.CompositeExtract [[WORKGROUPID]]{{\[}}0 : i32{{\]}}
module @kernels attributes {gpu.kernel_module} {
// CHECK: spv.module "Logical" "GLSL450" {
- // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
- // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
- // CHECK: func [[FN:@.*]]()
+ // CHECK-LABEL: func @kernel_1
+ // CHECK-SAME: {{%.*}}: f32 {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+ // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.array<12 x f32>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32, 1>)
attributes { gpu.kernel } {
- // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
- // CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
- // CHECK: [[ARG0PTR:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
- // CHECK: [[ARG0:%.*]] = spv.Load "StorageBuffer" [[ARG0PTR]]
- // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
- // CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
- // CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
- // CHECK-NEXT: spv.Return
- // CHECK: spv.EntryPoint "GLCompute" [[FN]]
- // CHECK: spv.ExecutionMode [[FN]] "LocalSize"
+ // CHECK: spv.Return
return
}
}
--- /dev/null
+// RUN: mlir-opt -spirv-lower-abi-attrs -verify-diagnostics %s -o - | FileCheck %s
+
+// CHECK-LABEL: spv.module
+spv.module "Logical" "GLSL450" {
+ // CHECK-DAG: spv.globalVariable [[WORKGROUPSIZE:@.*]] built_in("WorkgroupSize")
+ spv.globalVariable @__builtin_var_WorkgroupSize__ built_in("WorkgroupSize") : !spv.ptr<vector<3xi32>, Input>
+ // CHECK-DAG: spv.globalVariable [[NUMWORKGROUPS:@.*]] built_in("NumWorkgroups")
+ spv.globalVariable @__builtin_var_NumWorkgroups__ built_in("NumWorkgroups") : !spv.ptr<vector<3xi32>, Input>
+ // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONID:@.*]] built_in("LocalInvocationId")
+ spv.globalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+ // CHECK-DAG: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId")
+ spv.globalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
+ // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
+ // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
+ // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
+ // CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+ // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+ // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+ // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+ // CHECK: func [[FN:@.*]]()
+ func @load_store_kernel(%arg0: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>
+ {spirv.interface_var_abi = {binding = 0 : i32,
+ descriptor_set = 0 : i32,
+ storage_class = 12 : i32}},
+ %arg1: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>
+ {spirv.interface_var_abi = {binding = 1 : i32,
+ descriptor_set = 0 : i32,
+ storage_class = 12 : i32}},
+ %arg2: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>
+ {spirv.interface_var_abi = {binding = 2 : i32,
+ descriptor_set = 0 : i32,
+ storage_class = 12 : i32}},
+ %arg3: i32
+ {spirv.interface_var_abi = {binding = 3 : i32,
+ descriptor_set = 0 : i32,
+ storage_class = 12 : i32}},
+ %arg4: i32
+ {spirv.interface_var_abi = {binding = 4 : i32,
+ descriptor_set = 0 : i32,
+ storage_class = 12 : i32}},
+ %arg5: i32
+ {spirv.interface_var_abi = {binding = 5 : i32,
+ descriptor_set = 0 : i32,
+ storage_class = 12 : i32}},
+ %arg6: i32
+ {spirv.interface_var_abi = {binding = 6 : i32,
+ descriptor_set = 0 : i32,
+ storage_class = 12 : i32}})
+ attributes {spirv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+ // CHECK: [[ADDRESSARG6:%.*]] = spv._address_of [[VAR6]]
+ // CHECK: [[CONST6:%.*]] = spv.constant 0 : i32
+ // CHECK: [[ARG6PTR:%.*]] = spv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
+ // CHECK: {{%.*}} = spv.Load "StorageBuffer" [[ARG6PTR]]
+ // CHECK: [[ADDRESSARG5:%.*]] = spv._address_of [[VAR5]]
+ // CHECK: [[CONST5:%.*]] = spv.constant 0 : i32
+ // CHECK: [[ARG5PTR:%.*]] = spv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
+ // CHECK: {{%.*}} = spv.Load "StorageBuffer" [[ARG5PTR]]
+ // CHECK: [[ADDRESSARG4:%.*]] = spv._address_of [[VAR4]]
+ // CHECK: [[CONST4:%.*]] = spv.constant 0 : i32
+ // CHECK: [[ARG4PTR:%.*]] = spv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
+ // CHECK: [[ARG4:%.*]] = spv.Load "StorageBuffer" [[ARG4PTR]]
+ // CHECK: [[ADDRESSARG3:%.*]] = spv._address_of [[VAR3]]
+ // CHECK: [[CONST3:%.*]] = spv.constant 0 : i32
+ // CHECK: [[ARG3PTR:%.*]] = spv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
+ // CHECK: [[ARG3:%.*]] = spv.Load "StorageBuffer" [[ARG3PTR]]
+ // CHECK: [[ADDRESSARG2:%.*]] = spv._address_of [[VAR2]]
+ // CHECK: [[CONST2:%.*]] = spv.constant 0 : i32
+ // CHECK: [[ARG2:%.*]] = spv.AccessChain [[ADDRESSARG2]]{{\[}}[[CONST2]]
+ // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
+ // CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
+ // CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
+ // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
+ // CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
+ // CHECK: [[ARG0:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
+ %0 = spv._address_of @__builtin_var_WorkgroupId__ : !spv.ptr<vector<3xi32>, Input>
+ %1 = spv.Load "Input" %0 : vector<3xi32>
+ %2 = spv.CompositeExtract %1[0 : i32] : vector<3xi32>
+ %3 = spv._address_of @__builtin_var_WorkgroupId__ : !spv.ptr<vector<3xi32>, Input>
+ %4 = spv.Load "Input" %3 : vector<3xi32>
+ %5 = spv.CompositeExtract %4[1 : i32] : vector<3xi32>
+ %6 = spv._address_of @__builtin_var_WorkgroupId__ : !spv.ptr<vector<3xi32>, Input>
+ %7 = spv.Load "Input" %6 : vector<3xi32>
+ %8 = spv.CompositeExtract %7[2 : i32] : vector<3xi32>
+ %9 = spv._address_of @__builtin_var_LocalInvocationId__ : !spv.ptr<vector<3xi32>, Input>
+ %10 = spv.Load "Input" %9 : vector<3xi32>
+ %11 = spv.CompositeExtract %10[0 : i32] : vector<3xi32>
+ %12 = spv._address_of @__builtin_var_LocalInvocationId__ : !spv.ptr<vector<3xi32>, Input>
+ %13 = spv.Load "Input" %12 : vector<3xi32>
+ %14 = spv.CompositeExtract %13[1 : i32] : vector<3xi32>
+ %15 = spv._address_of @__builtin_var_LocalInvocationId__ : !spv.ptr<vector<3xi32>, Input>
+ %16 = spv.Load "Input" %15 : vector<3xi32>
+ %17 = spv.CompositeExtract %16[2 : i32] : vector<3xi32>
+ %18 = spv._address_of @__builtin_var_NumWorkgroups__ : !spv.ptr<vector<3xi32>, Input>
+ %19 = spv.Load "Input" %18 : vector<3xi32>
+ %20 = spv.CompositeExtract %19[0 : i32] : vector<3xi32>
+ %21 = spv._address_of @__builtin_var_NumWorkgroups__ : !spv.ptr<vector<3xi32>, Input>
+ %22 = spv.Load "Input" %21 : vector<3xi32>
+ %23 = spv.CompositeExtract %22[1 : i32] : vector<3xi32>
+ %24 = spv._address_of @__builtin_var_NumWorkgroups__ : !spv.ptr<vector<3xi32>, Input>
+ %25 = spv.Load "Input" %24 : vector<3xi32>
+ %26 = spv.CompositeExtract %25[2 : i32] : vector<3xi32>
+ %27 = spv._address_of @__builtin_var_WorkgroupSize__ : !spv.ptr<vector<3xi32>, Input>
+ %28 = spv.Load "Input" %27 : vector<3xi32>
+ %29 = spv.CompositeExtract %28[0 : i32] : vector<3xi32>
+ %30 = spv._address_of @__builtin_var_WorkgroupSize__ : !spv.ptr<vector<3xi32>, Input>
+ %31 = spv.Load "Input" %30 : vector<3xi32>
+ %32 = spv.CompositeExtract %31[1 : i32] : vector<3xi32>
+ %33 = spv._address_of @__builtin_var_WorkgroupSize__ : !spv.ptr<vector<3xi32>, Input>
+ %34 = spv.Load "Input" %33 : vector<3xi32>
+ %35 = spv.CompositeExtract %34[2 : i32] : vector<3xi32>
+ // CHECK: spv.IAdd [[ARG3]]
+ %36 = spv.IAdd %arg3, %2 : i32
+ // CHECK: spv.IAdd [[ARG4]]
+ %37 = spv.IAdd %arg4, %11 : i32
+ // CHECK: spv.AccessChain [[ARG0]]
+ %38 = spv.AccessChain %arg0[%36, %37] : !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>
+ %39 = spv.Load "StorageBuffer" %38 : f32
+ // CHECK: spv.AccessChain [[ARG1]]
+ %40 = spv.AccessChain %arg1[%36, %37] : !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>
+ %41 = spv.Load "StorageBuffer" %40 : f32
+ %42 = spv.FAdd %39, %41 : f32
+ // CHECK: spv.AccessChain [[ARG2]]
+ %43 = spv.AccessChain %arg2[%36, %37] : !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>
+ spv.Store "StorageBuffer" %43, %42 : f32
+ spv.Return
+ }
+ // CHECK: spv.EntryPoint "GLCompute" [[FN]], [[WORKGROUPID]], [[LOCALINVOCATIONID]], [[NUMWORKGROUPS]], [[WORKGROUPSIZE]]
+ // CHECK-NEXT: spv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
+} attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}
--- /dev/null
+// RUN: mlir-opt -spirv-lower-abi-attrs -verify-diagnostics %s -o - | FileCheck %s
+
+// CHECK-LABEL: spv.module
+spv.module "Logical" "GLSL450" {
+ // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
+ // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+ // CHECK: func [[FN:@.*]]()
+ func @kernel_1(%arg0: f32
+ {spirv.interface_var_abi = {binding = 0 : i32,
+ descriptor_set = 0 : i32,
+ storage_class = 12 : i32}},
+ %arg1: !spv.ptr<!spv.array<12 x f32>, StorageBuffer>
+ {spirv.interface_var_abi = {binding = 1 : i32,
+ descriptor_set = 0 : i32,
+ storage_class = 12 : i32}})
+ attributes {spirv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+ // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
+ // CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
+ // CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
+ // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
+ // CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
+ // CHECK: [[ARG0PTR:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
+ // CHECK: [[ARG0:%.*]] = spv.Load "StorageBuffer" [[ARG0PTR]]
+ // CHECK: spv.Return
+ spv.Return
+ }
+ // CHECK: spv.EntryPoint "GLCompute" [[FN]]
+ // CHECK: spv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
+} attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}