}
```
+## Target environment
+
+SPIR-V aims to support multiple execution environments as specified by client
+APIs. These execution environments affect the availability of certain SPIR-V
+features. For example, a [Vulkan 1.1][VulkanSpirv] implementation must support
+the 1.0, 1.1, 1.2, and 1.3 versions of SPIR-V and the 1.0 version of the SPIR-V
+extended instructions for GLSL. Further Vulkan extensions may enable more SPIR-V
+instructions.
+
+SPIR-V compilation should also take into consideration of the execution
+environment, so we generate SPIR-V modules valid for the target environment.
+This is conveyed by the `spv.target_env` attribute. It is a triple of
+
+* `version`: a 32-bit integer indicating the target SPIR-V version.
+* `extensions`: a string array attribute containing allowed extensions.
+* `capabilities`: a 32-bit integer array attribute containing allowed
+ capabilities.
+
+Dialect conversion framework will utilize the information in `spv.target_env`
+to properly filter out patterns and ops not available in the target execution
+environment.
+
## Shader interface (ABI)
SPIR-V itself is just expressing computation happening on GPU device. SPIR-V
additional rules are imposed by [Vulkan execution environment][VulkanSpirv]. The
lowering described below implements both these requirements.)
+### `SPIRVConversionTarget`
+
+The `mlir::spirv::SPIRVConversionTarget` class derives from the
+`mlir::ConversionTarget` class and serves as a utility to define a conversion
+target satisfying a given [`spv.target_env`](#target-environment). It registers
+proper hooks to check the dynamic legality of SPIR-V ops. Users can further
+register other legality constraints into the returned `SPIRVConversionTarget`.
-### SPIRVTypeConverter
+### `SPIRVTypeConverter`
-The `mlir::spirv::SPIRVTypeConverter` derives from
-`mlir::TypeConverter` and provides type conversion for standard
-types to SPIR-V types:
+The `mlir::SPIRVTypeConverter` derives from `mlir::TypeConverter` and provides
+type conversion for standard types to SPIR-V types:
* [Standard Integer][MlirIntegerType] -> Standard Integer
* [Standard Float][MlirFloatType] -> Standard Float
(TODO: Allow for configuring the integer width to use for `index` types in the
SPIR-V dialect)
-### SPIRVOpLowering
+### `SPIRVOpLowering`
-`mlir::spirv::SPIRVOpLowering` is a base class that can be used to define the
-patterns used for implementing the lowering. For now this only provides derived
-classes access to an instance of `mlir::spirv::SPIRVTypeLowering` class.
+`mlir::SPIRVOpLowering` is a base class that can be used to define the patterns
+used for implementing the lowering. For now this only provides derived classes
+access to an instance of `mlir::SPIRVTypeLowering` class.
### Utility functions for lowering
#ifndef MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
#define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallSet.h"
namespace mlir {
};
namespace spirv {
-enum class BuiltIn : uint32_t;
+class SPIRVConversionTarget : public ConversionTarget {
+public:
+ /// Creates a SPIR-V conversion target for the given target environment.
+ static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetEnv,
+ MLIRContext *context);
+
+private:
+ SPIRVConversionTarget(TargetEnvAttr targetEnv, MLIRContext *context);
+
+ // Be explicit that instance of this class cannot be copied or moved: there
+ // are lambdas capturing fields of the instance.
+ SPIRVConversionTarget(const SPIRVConversionTarget &) = delete;
+ SPIRVConversionTarget(SPIRVConversionTarget &&) = delete;
+ SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete;
+ SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete;
+
+ /// Returns true if the given `op` is legal to use under the current target
+ /// environment.
+ bool isLegalOp(Operation *op);
+
+ Version givenVersion; /// SPIR-V version to target
+ llvm::SmallSet<Extension, 4> givenExtensions; /// Allowed extensions
+ llvm::SmallSet<Capability, 8> givenCapabilities; /// Allowed capabilities
+};
/// Returns a value that represents a builtin variable value within the SPIR-V
/// module.
namespace spirv {
enum class StorageClass : uint32_t;
-/// Attribute name for specifying argument ABI information.
+/// Returns the attribute name for specifying argument ABI information.
StringRef getInterfaceVarABIAttrName();
-/// Get the InterfaceVarABIAttr given its fields.
+/// Gets the InterfaceVarABIAttr given its fields.
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet,
unsigned binding,
StorageClass storageClass,
MLIRContext *context);
-/// Attribute name for specifying entry point information.
+/// Returns the attribute name for specifying entry point information.
StringRef getEntryPointABIAttrName();
-/// Get the EntryPointABIAttr given its fields.
+/// Gets the EntryPointABIAttr given its fields.
EntryPointABIAttr getEntryPointABIAttr(ArrayRef<int32_t> localSize,
MLIRContext *context);
+
+/// Returns the attribute name for specifying SPIR-V target environment.
+StringRef getTargetEnvAttrName();
+
+/// Returns the default target environment: SPIR-V 1.0 with Shader capability
+/// and no extra extensions.
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
+
+/// Queries the target environment from the given `op` or returns the default
+/// target environment (SPIR-V 1.0 with Shader capability and no extra
+/// extensions) if not provided.
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
} // namespace spirv
} // namespace mlir
-//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===//
+//===- TargetAndABI.td - SPIR-V Target and ABI definitions -*- tablegen -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
//
//===----------------------------------------------------------------------===//
//
-// This is the base file for supporting lowering to SPIR-V dialect. This
-// file defines SPIR-V attributes used for specifying the shader
-// interface or ABI. This is because SPIR-V module is expected to work in
-// an execution environment as specified by a client API. A SPIR-V module
-// needs to "link" correctly with the execution environment regarding the
-// resources that are used in the SPIR-V module and get populated with
-// data via the client API. The shader interface (or ABI) is passed into
-// SPIR-V lowering path via attributes defined in this file. A
-// compilation flow targeting SPIR-V is expected to attach such
+// This is the base file for supporting lowering to SPIR-V dialect. This file
+// defines SPIR-V attributes used for specifying the shader interface or ABI.
+// This is because SPIR-V module is expected to work in an execution environment
+// as specified by a client API. A SPIR-V module needs to "link" correctly with
+// the execution environment regarding the resources that are used in the SPIR-V
+// module and get populated with data via the client API. The shader interface
+// (or ABI) is passed into SPIR-V lowering path via attributes defined in this
+// file. A compilation flow targeting SPIR-V is expected to attach such
// attributes to resources and other suitable places.
//
//===----------------------------------------------------------------------===//
-#ifndef SPIRV_LOWERING
-#define SPIRV_LOWERING
+#ifndef SPIRV_TARGET_AND_ABI
+#define SPIRV_TARGET_AND_ABI
include "mlir/Dialect/SPIRV/SPIRVBase.td"
// For arguments that eventually map to spv.globalVariable for the
// shader interface, this attribute specifies the information regarding
-// the global variable :
+// the global variable:
// 1) Descriptor Set.
// 2) Binding number.
// 3) Storage class.
-def SPV_InterfaceVarABIAttr:
- StructAttr<"InterfaceVarABIAttr", SPV_Dialect,
- [StructFieldAttr<"descriptor_set", I32Attr>,
- StructFieldAttr<"binding", I32Attr>,
- StructFieldAttr<"storage_class", SPV_StorageClassAttr>]>;
+def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPV_Dialect, [
+ StructFieldAttr<"descriptor_set", I32Attr>,
+ StructFieldAttr<"binding", I32Attr>,
+ StructFieldAttr<"storage_class", SPV_StorageClassAttr>
+]>;
// For entry functions, this attribute specifies information related to entry
// points in the generated SPIR-V module:
// 1) WorkGroup Size.
-def SPV_EntryPointABIAttr:
- StructAttr<"EntryPointABIAttr", SPV_Dialect,
- [StructFieldAttr<"local_size", I32ElementsAttr>]>;
+def SPV_EntryPointABIAttr : StructAttr<"EntryPointABIAttr", SPV_Dialect, [
+ StructFieldAttr<"local_size", I32ElementsAttr>
+]>;
-#endif // SPIRV_LOWERING
+def SPV_ExtensionArrayAttr : TypedArrayAttrBase<
+ SPV_ExtensionAttr, "SPIR-V extension array attribute">;
+
+def SPV_CapabilityArrayAttr : TypedArrayAttrBase<
+ SPV_CapabilityAttr, "SPIR-V capability array attribute">;
+
+// For the generated SPIR-V module, this attribute specifies the target version,
+// allowed extensions and capabilities.
+def SPV_TargetEnvAttr : StructAttr<"TargetEnvAttr", SPV_Dialect, [
+ StructFieldAttr<"version", SPV_VersionAttr>,
+ StructFieldAttr<"extensions", SPV_ExtensionArrayAttr>,
+ StructFieldAttr<"capabilities", SPV_CapabilityArrayAttr>
+]>;
+
+#endif // SPIRV_TARGET_AND_ABI
} // namespace
void GPUToSPIRVPass::runOnModule() {
- auto context = &getContext();
- auto module = getModule();
+ MLIRContext *context = &getContext();
+ ModuleOp module = getModule();
SmallVector<Operation *, 1> kernelModules;
OpBuilder builder(context);
populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
- ConversionTarget target(*context);
- target.addLegalDialect<spirv::SPIRVDialect>();
- target.addDynamicallyLegalOp<FuncOp>(
+ std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
+ spirv::lookupTargetEnvOrDefault(module), context);
+ target->addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
- if (failed(applyFullConversion(kernelModules, target, patterns,
+ if (failed(applyFullConversion(kernelModules, *target, patterns,
&typeConverter))) {
return signalPassFailure();
}
}
void ConvertStandardToSPIRVPass::runOnModule() {
- OwningRewritePatternList patterns;
- auto context = &getContext();
- auto module = getModule();
+ MLIRContext *context = &getContext();
+ ModuleOp module = getModule();
SPIRVTypeConverter typeConverter;
+ OwningRewritePatternList patterns;
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
patterns.insert<FuncOpConversion>(context, typeConverter);
- ConversionTarget target(*(module.getContext()));
- target.addLegalDialect<spirv::SPIRVDialect>();
- target.addDynamicallyLegalOp<FuncOp>(
+
+ std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
+ spirv::lookupTargetEnvOrDefault(module), context);
+ target->addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
- if (failed(applyPartialConversion(module, target, patterns))) {
+ if (failed(applyPartialConversion(module, *target, patterns))) {
return signalPassFailure();
}
}
StringRef symbol = attribute.first.strref();
Attribute attr = attribute.second;
- if (symbol != spirv::getEntryPointABIAttrName())
+ // TODO(antiagainst): figure out a way to generate the description from the
+ // StructAttr definition.
+ if (symbol == spirv::getEntryPointABIAttrName()) {
+ if (!attr.isa<spirv::EntryPointABIAttr>())
+ return op->emitError("'")
+ << symbol
+ << "' attribute must be a dictionary attribute containing one "
+ "32-bit integer elements attribute: 'local_size'";
+ } else if (symbol == spirv::getTargetEnvAttrName()) {
+ if (!attr.isa<spirv::TargetEnvAttr>())
+ return op->emitError("'")
+ << symbol
+ << "' must be a dictionary attribute containing one 32-bit "
+ "integer attribute 'version', one string array attribute "
+ "'extensions', and one 32-bit integer array attribute "
+ "'capabilities'";
+ } else {
return op->emitError("found unsupported '")
<< symbol << "' attribute on operation";
-
- if (!spirv::EntryPointABIAttr::classof(attr))
- return op->emitError("'")
- << symbol
- << "' attribute must be a dictionary attribute containing one "
- "integer elements attribute: 'local_size'";
+ }
return success();
}
<< symbol << "' attribute on region "
<< (forArg ? "argument" : "result");
- if (!spirv::InterfaceVarABIAttr::classof(attr))
+ if (!attr.isa<spirv::InterfaceVarABIAttr>())
return emitError(loc, "'")
<< symbol
<< "' attribute must be a dictionary attribute containing three "
- "integer attributes: 'descriptor_set', 'binding', and "
+ "32-bit integer attributes: 'descriptor_set', 'binding', and "
"'storage_class'";
return success();
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "llvm/ADT/Sequence.h"
+#include "llvm/Support/Debug.h"
+
+#include <functional>
+
+#define DEBUG_TYPE "mlir-spirv-lowering"
using namespace mlir;
funcOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
return success();
}
+
+//===----------------------------------------------------------------------===//
+// SPIR-V ConversionTarget
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<spirv::SPIRVConversionTarget>
+spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetEnv,
+ MLIRContext *context) {
+ std::unique_ptr<SPIRVConversionTarget> target(
+ // std::make_unique does not work here because the constructor is private.
+ new SPIRVConversionTarget(targetEnv, context));
+ SPIRVConversionTarget *targetPtr = target.get();
+ target->addDynamicallyLegalDialect<SPIRVDialect>(
+ Optional<ConversionTarget::DynamicLegalityCallbackFn>(
+ // We need to capture the raw pointer here because it is stable:
+ // target will be destroyed once this function is returned.
+ [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); }));
+ return target;
+}
+
+spirv::SPIRVConversionTarget::SPIRVConversionTarget(
+ spirv::TargetEnvAttr targetEnv, MLIRContext *context)
+ : ConversionTarget(*context),
+ givenVersion(static_cast<spirv::Version>(targetEnv.version().getInt())) {
+ for (Attribute extAttr : targetEnv.extensions())
+ givenExtensions.insert(
+ *spirv::symbolizeExtension(extAttr.cast<StringAttr>().getValue()));
+
+ for (Attribute capAttr : targetEnv.capabilities())
+ givenCapabilities.insert(
+ static_cast<spirv::Capability>(capAttr.cast<IntegerAttr>().getInt()));
+}
+
+bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
+ // Make sure this op is available at the given version. Ops not implementing
+ // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
+ // SPIR-V versions.
+ if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
+ if (minVersion.getMinVersion() > givenVersion) {
+ LLVM_DEBUG(llvm::dbgs()
+ << op->getName() << " illegal: requiring min version "
+ << spirv::stringifyVersion(minVersion.getMinVersion())
+ << "\n");
+ return false;
+ }
+ if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
+ if (maxVersion.getMaxVersion() < givenVersion) {
+ LLVM_DEBUG(llvm::dbgs()
+ << op->getName() << " illegal: requiring max version "
+ << spirv::stringifyVersion(maxVersion.getMaxVersion())
+ << "\n");
+ return false;
+ }
+
+ // Make sure this op's required extensions are allowed to use. For each op,
+ // we return a vector of vector for its extension requirements following
+ // ((Extension::A OR Extenion::B) AND (Extension::C OR Extension::D))
+ // convention. Ops not implementing QueryExtensionInterface do not require
+ // extensions to be available.
+ if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
+ auto exts = extensions.getExtensions();
+ for (const auto &ors : exts)
+ if (llvm::all_of(ors, [this](spirv::Extension ext) {
+ return this->givenExtensions.count(ext) == 0;
+ })) {
+ LLVM_DEBUG(llvm::dbgs() << op->getName()
+ << " illegal: missing required extension\n");
+ return false;
+ }
+ }
+
+ // Make sure this op's required extensions are allowed to use. For each op,
+ // we return a vector of vector for its capability requirements following
+ // ((Capability::A OR Extenion::B) AND (Capability::C OR Capability::D))
+ // convention. Ops not implementing QueryExtensionInterface do not require
+ // extensions to be available.
+ if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
+ auto caps = capabilities.getCapabilities();
+ for (const auto &ors : caps)
+ if (llvm::all_of(ors, [this](spirv::Capability cap) {
+ return this->givenCapabilities.count(cap) == 0;
+ })) {
+ LLVM_DEBUG(llvm::dbgs() << op->getName()
+ << " illegal: missing required capability\n");
+ return false;
+ }
+ }
+
+ return true;
+};
-//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===//
+//===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
using namespace mlir;
#include "mlir/Dialect/SPIRV/TargetAndABI.cpp.inc"
}
-StringRef mlir::spirv::getInterfaceVarABIAttrName() {
+StringRef spirv::getInterfaceVarABIAttrName() {
return "spv.interface_var_abi";
}
-mlir::spirv::InterfaceVarABIAttr
-mlir::spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
- spirv::StorageClass storageClass,
- MLIRContext *context) {
+spirv::InterfaceVarABIAttr
+spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
+ spirv::StorageClass storageClass,
+ MLIRContext *context) {
Type i32Type = IntegerType::get(32, context);
- return mlir::spirv::InterfaceVarABIAttr::get(
+ return spirv::InterfaceVarABIAttr::get(
IntegerAttr::get(i32Type, descriptorSet),
IntegerAttr::get(i32Type, binding),
IntegerAttr::get(i32Type, static_cast<int64_t>(storageClass)), context);
}
-StringRef mlir::spirv::getEntryPointABIAttrName() {
- return "spv.entry_point_abi";
-}
+StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; }
-mlir::spirv::EntryPointABIAttr
-mlir::spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize,
- MLIRContext *context) {
+spirv::EntryPointABIAttr
+spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize, MLIRContext *context) {
assert(localSize.size() == 3);
- return mlir::spirv::EntryPointABIAttr::get(
+ return spirv::EntryPointABIAttr::get(
DenseElementsAttr::get<int32_t>(
VectorType::get(3, IntegerType::get(32, context)), localSize)
.cast<DenseIntElementsAttr>(),
context);
}
+
+StringRef spirv::getTargetEnvAttrName() { return "spv.target_env"; }
+
+spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
+ Builder builder(context);
+ return spirv::TargetEnvAttr::get(
+ builder.getI32IntegerAttr(static_cast<uint32_t>(spirv::Version::V_1_0)),
+ builder.getI32ArrayAttr({}),
+ builder.getI32ArrayAttr(
+ {static_cast<uint32_t>(spirv::Capability::Shader)}),
+ context);
+}
+
+spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
+ if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
+ spirv::getTargetEnvAttrName()))
+ return attr;
+ return getDefaultTargetEnv(op->getContext());
+}
OwningRewritePatternList patterns;
patterns.insert<FuncOpLowering>(context, typeConverter);
- ConversionTarget target(*context);
- target.addLegalDialect<spirv::SPIRVDialect>();
+ std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
+ spirv::lookupTargetEnvOrDefault(module), context);
auto entryPointAttrName = spirv::getEntryPointABIAttrName();
- target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+ target->addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return op.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName) &&
op.getNumResults() == 0 && op.getNumArguments() == 0;
});
- target.addLegalOp<ReturnOp>();
+ target->addLegalOp<ReturnOp>();
if (failed(
- applyPartialConversion(module, target, patterns, &typeConverter))) {
+ applyPartialConversion(module, *target, patterns, &typeConverter))) {
return signalPassFailure();
}
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Function.h"
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// Printing op availability pass
+//===----------------------------------------------------------------------===//
+
namespace {
/// A pass for testing SPIR-V op availability.
-struct TestAvailability : public FunctionPass<TestAvailability> {
+struct PrintOpAvailability : public FunctionPass<PrintOpAvailability> {
void runOnFunction() override;
};
} // end anonymous namespace
-void TestAvailability::runOnFunction() {
+void PrintOpAvailability::runOnFunction() {
auto f = getFunction();
llvm::outs() << f.getName() << "\n";
});
}
-static PassRegistration<TestAvailability> pass("test-spirv-op-availability",
- "Test SPIR-V op availability");
+static PassRegistration<PrintOpAvailability>
+ printOpAvailabilityPass("test-spirv-op-availability",
+ "Test SPIR-V op availability");
+
+//===----------------------------------------------------------------------===//
+// Converting target environment pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pass for testing SPIR-V op availability.
+struct ConvertToTargetEnv : public FunctionPass<ConvertToTargetEnv> {
+ void runOnFunction() override;
+};
+
+struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
+ ConvertToAtomCmpExchangeWeak(MLIRContext *context);
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+};
+
+struct ConvertToGroupNonUniformBallot : public RewritePattern {
+ ConvertToGroupNonUniformBallot(MLIRContext *context);
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+};
+
+struct ConvertToSubgroupBallot : public RewritePattern {
+ ConvertToSubgroupBallot(MLIRContext *context);
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+};
+} // end anonymous namespace
+
+void ConvertToTargetEnv::runOnFunction() {
+ MLIRContext *context = &getContext();
+ FuncOp fn = getFunction();
+
+ auto targetEnv = fn.getOperation()
+ ->getAttr(spirv::getTargetEnvAttrName())
+ .cast<spirv::TargetEnvAttr>();
+ auto target = spirv::SPIRVConversionTarget::get(targetEnv, context);
+
+ OwningRewritePatternList patterns;
+ patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToGroupNonUniformBallot,
+ ConvertToSubgroupBallot>(context);
+
+ if (failed(applyPartialConversion(fn, *target, patterns)))
+ return signalPassFailure();
+}
+
+ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
+ : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op",
+ {"spv.AtomicCompareExchangeWeak"}, 1, context) {}
+
+PatternMatchResult
+ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ Value ptr = op->getOperand(0);
+ Value value = op->getOperand(1);
+ Value comparator = op->getOperand(2);
+
+ // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in
+ // memory semantics to additionally require AtomicStorage capability.
+ rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
+ op, value.getType(), ptr, spirv::Scope::Workgroup,
+ spirv::MemorySemantics::AcquireRelease |
+ spirv::MemorySemantics::AtomicCounterMemory,
+ spirv::MemorySemantics::Acquire, value, comparator);
+ return matchSuccess();
+}
+
+ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
+ MLIRContext *context)
+ : RewritePattern("test.convert_to_group_non_uniform_ballot_op",
+ {"spv.GroupNonUniformBallot"}, 1, context) {}
+
+PatternMatchResult ConvertToGroupNonUniformBallot::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ Value predicate = op->getOperand(0);
+
+ rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
+ op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
+ return matchSuccess();
+}
+
+ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
+ : RewritePattern("test.convert_to_subgroup_ballot_op",
+ {"spv.SubgroupBallotKHR"}, 1, context) {}
+
+PatternMatchResult
+ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ Value predicate = op->getOperand(0);
+
+ rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
+ op, op->getResult(0).getType(), predicate);
+ return matchSuccess();
+}
+
+static PassRegistration<ConvertToTargetEnv>
+ convertToTargetEnvPass("test-spirv-target-env",
+ "Test SPIR-V target environment");
// spv.entry_point_abi
//===----------------------------------------------------------------------===//
-// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one integer elements attribute: 'local_size'}}
+// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one 32-bit integer elements attribute: 'local_size'}}
func @spv_entry_point() attributes {
spv.entry_point_abi = 64
} { return }
// -----
-// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one integer elements attribute: 'local_size'}}
+// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one 32-bit integer elements attribute: 'local_size'}}
func @spv_entry_point() attributes {
spv.entry_point_abi = {local_size = 64}
} { return }
// spv.interface_var_abi
//===----------------------------------------------------------------------===//
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
func @interface_var(
%arg0 : f32 {spv.interface_var_abi = 64}
) { return }
// -----
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
func @interface_var(
%arg0 : f32 {spv.interface_var_abi = {binding = 0: i32}}
) { return }
// -----
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
func @interface_var() -> (f32 {spv.interface_var_abi = 64})
{
%0 = constant 10.0 : f32
// -----
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
func @interface_var() -> (f32 {spv.interface_var_abi = {binding = 0: i32}})
{
%0 = constant 10.0 : f32
%0 = constant 10.0 : f32
return %0: f32
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.target_env
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}}
+func @target_env_wrong_type() attributes {
+ spv.target_env = 64
+} { return }
+
+// -----
+
+// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}}
+func @target_env_missing_fields() attributes {
+ spv.target_env = {version = 0: i32}
+} { return }
+
+// -----
+
+// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}}
+func @target_env_wrong_extension_type() attributes {
+ spv.target_env = {version = 0: i32, extensions = [32: i32], capabilities = [1: i32]}
+} { return }
+
+// -----
+
+// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}}
+func @target_env_wrong_extension() attributes {
+ spv.target_env = {version = 0: i32, extensions = ["SPV_Something"], capabilities = [1: i32]}
+} { return }
+
+// -----
+
+func @target_env() attributes {
+ // CHECK: spv.target_env = {capabilities = [1 : i32], extensions = ["SPV_KHR_storage_buffer_storage_class"], version = 0 : i32}
+ spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_storage_buffer_storage_class"], capabilities = [1: i32]}
+} { return }
+
+// -----
+
+// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}}
+func @target_env_extra_fields() attributes {
+ spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_storage_buffer_storage_class"], capabilities = [1: i32], extra = 32}
+} { return }
--- /dev/null
+// RUN: mlir-opt -disable-pass-threading -test-spirv-target-env %s | FileCheck %s
+
+// Note: The following tests check that a spv.target_env can properly control
+// the conversion target and filter unavailable ops during the conversion.
+// We don't care about the op argument consistency too much; so certain enum
+// values for enum attributes may not make much sense for the test op.
+
+// spv.AtomicCompareExchangeWeak is available from SPIR-V 1.0 to 1.3 under
+// Kernel capability.
+// spv.AtomicCompareExchangeWeak has two memory semantics enum attribute,
+// whose value, if containing AtomicCounterMemory bit, additionally requires
+// AtomicStorage capability.
+
+// spv.GroupNonUniformBallot is available starting from SPIR-V 1.3 under
+// GroupNonUniform capability.
+
+// spv.SubgroupBallotKHR is available under in all SPIR-V versions under
+// SubgroupBallotKHR capability and SPV_KHR_shader_ballot extension.
+
+// Enum case symbol (value) map:
+// Version: 1.0 (0), 1.1 (1), 1.2 (2), 1.3 (3), 1.4 (4)
+// Capability: Kernel (6), AtomicStorage (21), GroupNonUniformBallot (64),
+// SubgroupBallotKHR (4423)
+
+//===----------------------------------------------------------------------===//
+// MaxVersion
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cmp_exchange_weak_suitable_version_capabilities
+func @cmp_exchange_weak_suitable_version_capabilities(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 attributes {
+ spv.target_env = {version = 1: i32, extensions = [], capabilities = [6: i32, 21: i32]}
+} {
+ // CHECK: spv.AtomicCompareExchangeWeak "Workgroup" "AcquireRelease|AtomicCounterMemory" "Acquire"
+ %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr<i32, Workgroup>, i32, i32) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @cmp_exchange_weak_unsupported_version
+func @cmp_exchange_weak_unsupported_version(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 attributes {
+ spv.target_env = {version = 4: i32, extensions = [], capabilities = [6: i32, 21: i32]}
+} {
+ // CHECK: test.convert_to_atomic_compare_exchange_weak_op
+ %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr<i32, Workgroup>, i32, i32) -> (i32)
+ return %0: i32
+}
+
+//===----------------------------------------------------------------------===//
+// MinVersion
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_ballot_suitable_version
+func @group_non_uniform_ballot_suitable_version(%predicate: i1) -> vector<4xi32> attributes {
+ spv.target_env = {version = 4: i32, extensions = [], capabilities = [64: i32]}
+} {
+ // CHECK: spv.GroupNonUniformBallot "Workgroup"
+ %0 = "test.convert_to_group_non_uniform_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
+ return %0: vector<4xi32>
+}
+
+// CHECK-LABEL: @group_non_uniform_ballot_unsupported_version
+func @group_non_uniform_ballot_unsupported_version(%predicate: i1) -> vector<4xi32> attributes {
+ spv.target_env = {version = 1: i32, extensions = [], capabilities = [64: i32]}
+} {
+ // CHECK: test.convert_to_group_non_uniform_ballot_op
+ %0 = "test.convert_to_group_non_uniform_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
+ return %0: vector<4xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// Capability
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cmp_exchange_weak_missing_capability_kernel
+func @cmp_exchange_weak_missing_capability_kernel(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 attributes {
+ spv.target_env = {version = 3: i32, extensions = [], capabilities = [21: i32]}
+} {
+ // CHECK: test.convert_to_atomic_compare_exchange_weak_op
+ %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr<i32, Workgroup>, i32, i32) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @cmp_exchange_weak_missing_capability_atomic_storage
+func @cmp_exchange_weak_missing_capability_atomic_storage(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 attributes {
+ spv.target_env = {version = 3: i32, extensions = [], capabilities = [6: i32]}
+} {
+ // CHECK: test.convert_to_atomic_compare_exchange_weak_op
+ %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr<i32, Workgroup>, i32, i32) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @subgroup_ballot_missing_capability
+func @subgroup_ballot_missing_capability(%predicate: i1) -> vector<4xi32> attributes {
+ spv.target_env = {version = 4: i32, extensions = ["SPV_KHR_shader_ballot"], capabilities = []}
+} {
+ // CHECK: test.convert_to_subgroup_ballot_op
+ %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
+ return %0: vector<4xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// Extension
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @subgroup_ballot_suitable_extension
+func @subgroup_ballot_suitable_extension(%predicate: i1) -> vector<4xi32> attributes {
+ spv.target_env = {version = 4: i32, extensions = ["SPV_KHR_shader_ballot"], capabilities = [4423: i32]}
+} {
+ // CHECK: spv.SubgroupBallotKHR
+ %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
+ return %0: vector<4xi32>
+}
+
+// CHECK-LABEL: @subgroup_ballot_missing_extension
+func @subgroup_ballot_missing_extension(%predicate: i1) -> vector<4xi32> attributes {
+ spv.target_env = {version = 4: i32, extensions = [], capabilities = [4423: i32]}
+} {
+ // CHECK: test.convert_to_subgroup_ballot_op
+ %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
+ return %0: vector<4xi32>
+}