[mlir][spirv] Properly support SPIR-V conversion target
authorLei Zhang <antiagainst@google.com>
Tue, 14 Jan 2020 23:23:25 +0000 (18:23 -0500)
committerLei Zhang <antiagainst@google.com>
Wed, 15 Jan 2020 00:18:42 +0000 (19:18 -0500)
This commit defines a new SPIR-V dialect attribute for specifying
a SPIR-V target environment. It is a dictionary attribute containing
the SPIR-V version, supported extension list, and allowed capability
list. A SPIRVConversionTarget subclass is created to take in the
target environment and sets proper dynmaically legal ops by querying
the op availability interface of SPIR-V ops to make sure they are
available in the specified target environment. All existing conversions
targeting SPIR-V is changed to use this SPIRVConversionTarget. It
probes whether the input IR has a `spv.target_env` attribute,
otherwise, it uses the default target environment: SPIR-V 1.0 with
Shader capability and no extra extensions.

Differential Revision: https://reviews.llvm.org/D72256

13 files changed:
mlir/docs/Dialects/SPIR-V.md
mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/test/Dialect/SPIRV/TestAvailability.cpp
mlir/test/Dialect/SPIRV/target-and-abi.mlir
mlir/test/Dialect/SPIRV/target-env.mlir [new file with mode: 0644]

index ed02b1609f2b5bfbfc8ba3739ed7dd2a7c502daf..502f100cfaf4c929db79c4d4b089d24901c66129 100644 (file)
@@ -725,6 +725,28 @@ func @foo() -> () {
 }
 ```
 
+## Target environment
+
+SPIR-V aims to support multiple execution environments as specified by client
+APIs. These execution environments affect the availability of certain SPIR-V
+features. For example, a [Vulkan 1.1][VulkanSpirv] implementation must support
+the 1.0, 1.1, 1.2, and 1.3 versions of SPIR-V and the 1.0 version of the SPIR-V
+extended instructions for GLSL. Further Vulkan extensions may enable more SPIR-V
+instructions.
+
+SPIR-V compilation should also take into consideration of the execution
+environment, so we generate SPIR-V modules valid for the target environment.
+This is conveyed by the `spv.target_env` attribute. It is a triple of
+
+*   `version`: a 32-bit integer indicating the target SPIR-V version.
+*   `extensions`: a string array attribute containing allowed extensions.
+*   `capabilities`: a 32-bit integer array attribute containing allowed
+    capabilities.
+
+Dialect conversion framework will utilize the information in `spv.target_env`
+to properly filter out patterns and ops not available in the target execution
+environment.
+
 ## Shader interface (ABI)
 
 SPIR-V itself is just expressing computation happening on GPU device. SPIR-V
@@ -852,12 +874,18 @@ classes are provided.
 additional rules are imposed by [Vulkan execution environment][VulkanSpirv]. The
 lowering described below implements both these requirements.)
 
+### `SPIRVConversionTarget`
+
+The `mlir::spirv::SPIRVConversionTarget` class derives from the
+`mlir::ConversionTarget` class and serves as a utility to define a conversion
+target satisfying a given [`spv.target_env`](#target-environment). It registers
+proper hooks to check the dynamic legality of SPIR-V ops. Users can further
+register other legality constraints into the returned `SPIRVConversionTarget`.
 
-### SPIRVTypeConverter
+### `SPIRVTypeConverter`
 
-The `mlir::spirv::SPIRVTypeConverter` derives from
-`mlir::TypeConverter` and provides type conversion for standard
-types to SPIR-V types:
+The `mlir::SPIRVTypeConverter` derives from `mlir::TypeConverter` and provides
+type conversion for standard types to SPIR-V types:
 
 *   [Standard Integer][MlirIntegerType] -> Standard Integer
 *   [Standard Float][MlirFloatType] -> Standard Float
@@ -874,11 +902,11 @@ supported in SPIR-V. Currently the `index` type is converted to `i32`.
 (TODO: Allow for configuring the integer width to use for `index` types in the
 SPIR-V dialect)
 
-### SPIRVOpLowering
+### `SPIRVOpLowering`
 
-`mlir::spirv::SPIRVOpLowering` is a base class that can be used to define the
-patterns used for implementing the lowering. For now this only provides derived
-classes access to an instance of `mlir::spirv::SPIRVTypeLowering` class.
+`mlir::SPIRVOpLowering` is a base class that can be used to define the patterns
+used for implementing the lowering. For now this only provides derived classes
+access to an instance of `mlir::SPIRVTypeLowering` class.
 
 ### Utility functions for lowering
 
index ea13bc0dbb333cb173fde786bbce544e7a4bcdaf..97000a908f997934a9242f10445000c923c47751 100644 (file)
 #ifndef MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
 #define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
 
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallSet.h"
 
 namespace mlir {
 
@@ -48,7 +50,30 @@ protected:
 };
 
 namespace spirv {
-enum class BuiltIn : uint32_t;
+class SPIRVConversionTarget : public ConversionTarget {
+public:
+  /// Creates a SPIR-V conversion target for the given target environment.
+  static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetEnv,
+                                                    MLIRContext *context);
+
+private:
+  SPIRVConversionTarget(TargetEnvAttr targetEnv, MLIRContext *context);
+
+  // Be explicit that instance of this class cannot be copied or moved: there
+  // are lambdas capturing fields of the instance.
+  SPIRVConversionTarget(const SPIRVConversionTarget &) = delete;
+  SPIRVConversionTarget(SPIRVConversionTarget &&) = delete;
+  SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete;
+  SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete;
+
+  /// Returns true if the given `op` is legal to use under the current target
+  /// environment.
+  bool isLegalOp(Operation *op);
+
+  Version givenVersion;                            /// SPIR-V version to target
+  llvm::SmallSet<Extension, 4> givenExtensions;    /// Allowed extensions
+  llvm::SmallSet<Capability, 8> givenCapabilities; /// Allowed capabilities
+};
 
 /// Returns a value that represents a builtin variable value within the SPIR-V
 /// module.
index 9ac5c69b24a5e2c8bd3bd40c80133547487e0d37..72baf4aa6ed247fed1e953ccfa3795e883b0f46d 100644 (file)
@@ -27,21 +27,33 @@ class Value;
 namespace spirv {
 enum class StorageClass : uint32_t;
 
-/// Attribute name for specifying argument ABI information.
+/// Returns the attribute name for specifying argument ABI information.
 StringRef getInterfaceVarABIAttrName();
 
-/// Get the InterfaceVarABIAttr given its fields.
+/// Gets the InterfaceVarABIAttr given its fields.
 InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet,
                                            unsigned binding,
                                            StorageClass storageClass,
                                            MLIRContext *context);
 
-/// Attribute name for specifying entry point information.
+/// Returns the attribute name for specifying entry point information.
 StringRef getEntryPointABIAttrName();
 
-/// Get the EntryPointABIAttr given its fields.
+/// Gets the EntryPointABIAttr given its fields.
 EntryPointABIAttr getEntryPointABIAttr(ArrayRef<int32_t> localSize,
                                        MLIRContext *context);
+
+/// Returns the attribute name for specifying SPIR-V target environment.
+StringRef getTargetEnvAttrName();
+
+/// Returns the default target environment: SPIR-V 1.0 with Shader capability
+/// and no extra extensions.
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
+
+/// Queries the target environment from the given `op` or returns the default
+/// target environment (SPIR-V 1.0 with Shader capability and no extra
+/// extensions) if not provided.
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
 } // namespace spirv
 } // namespace mlir
 
index 91a8ff68bbf86229156aaa5cc1417f1db3e668fa..0496d76353d09883bac2b49950208a40a4b2d8b2 100644 (file)
@@ -1,4 +1,4 @@
-//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===//
+//===- TargetAndABI.td - SPIR-V Target and ABI definitions -*- tablegen -*-===//
 //
 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,41 +6,54 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This is the base file for supporting lowering to SPIR-V dialect. This
-// file defines SPIR-V attributes used for specifying the shader
-// interface or ABI. This is because SPIR-V module is expected to work in
-// an execution environment as specified by a client API. A SPIR-V module
-// needs to "link" correctly with the execution environment regarding the
-// resources that are used in the SPIR-V module and get populated with
-// data via the client API. The shader interface (or ABI) is passed into
-// SPIR-V lowering path via attributes defined in this file. A
-// compilation flow targeting SPIR-V is expected to attach such
+// This is the base file for supporting lowering to SPIR-V dialect. This file
+// defines SPIR-V attributes used for specifying the shader interface or ABI.
+// This is because SPIR-V module is expected to work in an execution environment
+// as specified by a client API. A SPIR-V module needs to "link" correctly with
+// the execution environment regarding the resources that are used in the SPIR-V
+// module and get populated with data via the client API. The shader interface
+// (or ABI) is passed into SPIR-V lowering path via attributes defined in this
+// file. A compilation flow targeting SPIR-V is expected to attach such
 // attributes to resources and other suitable places.
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef SPIRV_LOWERING
-#define SPIRV_LOWERING
+#ifndef SPIRV_TARGET_AND_ABI
+#define SPIRV_TARGET_AND_ABI
 
 include "mlir/Dialect/SPIRV/SPIRVBase.td"
 
 // For arguments that eventually map to spv.globalVariable for the
 // shader interface, this attribute specifies the information regarding
-// the global variable :
+// the global variable:
 // 1) Descriptor Set.
 // 2) Binding number.
 // 3) Storage class.
-def SPV_InterfaceVarABIAttr:
-    StructAttr<"InterfaceVarABIAttr", SPV_Dialect,
-               [StructFieldAttr<"descriptor_set", I32Attr>,
-                StructFieldAttr<"binding", I32Attr>,
-                StructFieldAttr<"storage_class", SPV_StorageClassAttr>]>;
+def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPV_Dialect, [
+    StructFieldAttr<"descriptor_set", I32Attr>,
+    StructFieldAttr<"binding", I32Attr>,
+    StructFieldAttr<"storage_class", SPV_StorageClassAttr>
+]>;
 
 // For entry functions, this attribute specifies information related to entry
 // points in the generated SPIR-V module:
 // 1) WorkGroup Size.
-def SPV_EntryPointABIAttr:
-    StructAttr<"EntryPointABIAttr", SPV_Dialect,
-               [StructFieldAttr<"local_size", I32ElementsAttr>]>;
+def SPV_EntryPointABIAttr : StructAttr<"EntryPointABIAttr", SPV_Dialect, [
+    StructFieldAttr<"local_size", I32ElementsAttr>
+]>;
 
-#endif // SPIRV_LOWERING
+def SPV_ExtensionArrayAttr : TypedArrayAttrBase<
+    SPV_ExtensionAttr, "SPIR-V extension array attribute">;
+
+def SPV_CapabilityArrayAttr : TypedArrayAttrBase<
+    SPV_CapabilityAttr, "SPIR-V capability array attribute">;
+
+// For the generated SPIR-V module, this attribute specifies the target version,
+// allowed extensions and capabilities.
+def SPV_TargetEnvAttr : StructAttr<"TargetEnvAttr", SPV_Dialect, [
+    StructFieldAttr<"version", SPV_VersionAttr>,
+    StructFieldAttr<"extensions", SPV_ExtensionArrayAttr>,
+    StructFieldAttr<"capabilities", SPV_CapabilityArrayAttr>
+]>;
+
+#endif // SPIRV_TARGET_AND_ABI
index bc8273ec2a9778b64b779dc0b6e11014feedcbaf..4dda8bdc2b39f865c323db127bf3bf53d7d0179a 100644 (file)
@@ -55,8 +55,8 @@ private:
 } // namespace
 
 void GPUToSPIRVPass::runOnModule() {
-  auto context = &getContext();
-  auto module = getModule();
+  MLIRContext *context = &getContext();
+  ModuleOp module = getModule();
 
   SmallVector<Operation *, 1> kernelModules;
   OpBuilder builder(context);
@@ -73,12 +73,12 @@ void GPUToSPIRVPass::runOnModule() {
   populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
 
-  ConversionTarget target(*context);
-  target.addLegalDialect<spirv::SPIRVDialect>();
-  target.addDynamicallyLegalOp<FuncOp>(
+  std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
+      spirv::lookupTargetEnvOrDefault(module), context);
+  target->addDynamicallyLegalOp<FuncOp>(
       [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
 
-  if (failed(applyFullConversion(kernelModules, target, patterns,
+  if (failed(applyFullConversion(kernelModules, *target, patterns,
                                  &typeConverter))) {
     return signalPassFailure();
   }
index 52456b6e46d04d650ca36f6718e442b897c79740..6c7a453fe48559926c4deb1a3c22ce211c989abb 100644 (file)
@@ -64,19 +64,20 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
 }
 
 void ConvertStandardToSPIRVPass::runOnModule() {
-  OwningRewritePatternList patterns;
-  auto context = &getContext();
-  auto module = getModule();
+  MLIRContext *context = &getContext();
+  ModuleOp module = getModule();
 
   SPIRVTypeConverter typeConverter;
+  OwningRewritePatternList patterns;
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
   patterns.insert<FuncOpConversion>(context, typeConverter);
-  ConversionTarget target(*(module.getContext()));
-  target.addLegalDialect<spirv::SPIRVDialect>();
-  target.addDynamicallyLegalOp<FuncOp>(
+
+  std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
+      spirv::lookupTargetEnvOrDefault(module), context);
+  target->addDynamicallyLegalOp<FuncOp>(
       [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
 
-  if (failed(applyPartialConversion(module, target, patterns))) {
+  if (failed(applyPartialConversion(module, *target, patterns))) {
     return signalPassFailure();
   }
 }
index 0e543df29b64d11e43f17c22b7c717fc5a30d1fa..42c05bdd9080f63b8724af6c981cbf60dca81307 100644 (file)
@@ -650,15 +650,26 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
   StringRef symbol = attribute.first.strref();
   Attribute attr = attribute.second;
 
-  if (symbol != spirv::getEntryPointABIAttrName())
+  // TODO(antiagainst): figure out a way to generate the description from the
+  // StructAttr definition.
+  if (symbol == spirv::getEntryPointABIAttrName()) {
+    if (!attr.isa<spirv::EntryPointABIAttr>())
+      return op->emitError("'")
+             << symbol
+             << "' attribute must be a dictionary attribute containing one "
+                "32-bit integer elements attribute: 'local_size'";
+  } else if (symbol == spirv::getTargetEnvAttrName()) {
+    if (!attr.isa<spirv::TargetEnvAttr>())
+      return op->emitError("'")
+             << symbol
+             << "' must be a dictionary attribute containing one 32-bit "
+                "integer attribute 'version', one string array attribute "
+                "'extensions', and one 32-bit integer array attribute "
+                "'capabilities'";
+  } else {
     return op->emitError("found unsupported '")
            << symbol << "' attribute on operation";
-
-  if (!spirv::EntryPointABIAttr::classof(attr))
-    return op->emitError("'")
-           << symbol
-           << "' attribute must be a dictionary attribute containing one "
-              "integer elements attribute: 'local_size'";
+  }
 
   return success();
 }
@@ -675,11 +686,11 @@ verifyRegionAttribute(Location loc, NamedAttribute attribute, bool forArg) {
            << symbol << "' attribute on region "
            << (forArg ? "argument" : "result");
 
-  if (!spirv::InterfaceVarABIAttr::classof(attr))
+  if (!attr.isa<spirv::InterfaceVarABIAttr>())
     return emitError(loc, "'")
            << symbol
            << "' attribute must be a dictionary attribute containing three "
-              "integer attributes: 'descriptor_set', 'binding', and "
+              "32-bit integer attributes: 'descriptor_set', 'binding', and "
               "'storage_class'";
 
   return success();
index 2a635351485a43da740e4b7e05d298b572c89f34..adc610349b36f0e7b4bcfe4774d65f8cc2207561 100644 (file)
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/Support/Debug.h"
+
+#include <functional>
+
+#define DEBUG_TYPE "mlir-spirv-lowering"
 
 using namespace mlir;
 
@@ -214,3 +219,93 @@ mlir::spirv::setABIAttrs(FuncOp funcOp, spirv::EntryPointABIAttr entryPointInfo,
   funcOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// SPIR-V ConversionTarget
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<spirv::SPIRVConversionTarget>
+spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetEnv,
+                                  MLIRContext *context) {
+  std::unique_ptr<SPIRVConversionTarget> target(
+      // std::make_unique does not work here because the constructor is private.
+      new SPIRVConversionTarget(targetEnv, context));
+  SPIRVConversionTarget *targetPtr = target.get();
+  target->addDynamicallyLegalDialect<SPIRVDialect>(
+      Optional<ConversionTarget::DynamicLegalityCallbackFn>(
+          // We need to capture the raw pointer here because it is stable:
+          // target will be destroyed once this function is returned.
+          [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); }));
+  return target;
+}
+
+spirv::SPIRVConversionTarget::SPIRVConversionTarget(
+    spirv::TargetEnvAttr targetEnv, MLIRContext *context)
+    : ConversionTarget(*context),
+      givenVersion(static_cast<spirv::Version>(targetEnv.version().getInt())) {
+  for (Attribute extAttr : targetEnv.extensions())
+    givenExtensions.insert(
+        *spirv::symbolizeExtension(extAttr.cast<StringAttr>().getValue()));
+
+  for (Attribute capAttr : targetEnv.capabilities())
+    givenCapabilities.insert(
+        static_cast<spirv::Capability>(capAttr.cast<IntegerAttr>().getInt()));
+}
+
+bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
+  // Make sure this op is available at the given version. Ops not implementing
+  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
+  // SPIR-V versions.
+  if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
+    if (minVersion.getMinVersion() > givenVersion) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << op->getName() << " illegal: requiring min version "
+                 << spirv::stringifyVersion(minVersion.getMinVersion())
+                 << "\n");
+      return false;
+    }
+  if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
+    if (maxVersion.getMaxVersion() < givenVersion) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << op->getName() << " illegal: requiring max version "
+                 << spirv::stringifyVersion(maxVersion.getMaxVersion())
+                 << "\n");
+      return false;
+    }
+
+  // Make sure this op's required extensions are allowed to use. For each op,
+  // we return a vector of vector for its extension requirements following
+  // ((Extension::A OR Extenion::B) AND (Extension::C OR Extension::D))
+  // convention. Ops not implementing QueryExtensionInterface do not require
+  // extensions to be available.
+  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
+    auto exts = extensions.getExtensions();
+    for (const auto &ors : exts)
+      if (llvm::all_of(ors, [this](spirv::Extension ext) {
+            return this->givenExtensions.count(ext) == 0;
+          })) {
+        LLVM_DEBUG(llvm::dbgs() << op->getName()
+                                << " illegal: missing required extension\n");
+        return false;
+      }
+  }
+
+  // Make sure this op's required extensions are allowed to use. For each op,
+  // we return a vector of vector for its capability requirements following
+  // ((Capability::A OR Extenion::B) AND (Capability::C OR Capability::D))
+  // convention. Ops not implementing QueryExtensionInterface do not require
+  // extensions to be available.
+  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
+    auto caps = capabilities.getCapabilities();
+    for (const auto &ors : caps)
+      if (llvm::all_of(ors, [this](spirv::Capability cap) {
+            return this->givenCapabilities.count(cap) == 0;
+          })) {
+        LLVM_DEBUG(llvm::dbgs() << op->getName()
+                                << " illegal: missing required capability\n");
+        return false;
+      }
+  }
+
+  return true;
+};
index 5aa8282f3e9b50153b6aa618473f785a362b1f5a..db866d63a19ab165d70d21b98d425a9ca6a0b901 100644 (file)
@@ -1,4 +1,4 @@
-//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===//
+//===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
 //
 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
 
 using namespace mlir;
@@ -16,32 +17,48 @@ namespace mlir {
 #include "mlir/Dialect/SPIRV/TargetAndABI.cpp.inc"
 }
 
-StringRef mlir::spirv::getInterfaceVarABIAttrName() {
+StringRef spirv::getInterfaceVarABIAttrName() {
   return "spv.interface_var_abi";
 }
 
-mlir::spirv::InterfaceVarABIAttr
-mlir::spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
-                                    spirv::StorageClass storageClass,
-                                    MLIRContext *context) {
+spirv::InterfaceVarABIAttr
+spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
+                              spirv::StorageClass storageClass,
+                              MLIRContext *context) {
   Type i32Type = IntegerType::get(32, context);
-  return mlir::spirv::InterfaceVarABIAttr::get(
+  return spirv::InterfaceVarABIAttr::get(
       IntegerAttr::get(i32Type, descriptorSet),
       IntegerAttr::get(i32Type, binding),
       IntegerAttr::get(i32Type, static_cast<int64_t>(storageClass)), context);
 }
 
-StringRef mlir::spirv::getEntryPointABIAttrName() {
-  return "spv.entry_point_abi";
-}
+StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; }
 
-mlir::spirv::EntryPointABIAttr
-mlir::spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize,
-                                  MLIRContext *context) {
+spirv::EntryPointABIAttr
+spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize, MLIRContext *context) {
   assert(localSize.size() == 3);
-  return mlir::spirv::EntryPointABIAttr::get(
+  return spirv::EntryPointABIAttr::get(
       DenseElementsAttr::get<int32_t>(
           VectorType::get(3, IntegerType::get(32, context)), localSize)
           .cast<DenseIntElementsAttr>(),
       context);
 }
+
+StringRef spirv::getTargetEnvAttrName() { return "spv.target_env"; }
+
+spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
+  Builder builder(context);
+  return spirv::TargetEnvAttr::get(
+      builder.getI32IntegerAttr(static_cast<uint32_t>(spirv::Version::V_1_0)),
+      builder.getI32ArrayAttr({}),
+      builder.getI32ArrayAttr(
+          {static_cast<uint32_t>(spirv::Capability::Shader)}),
+      context);
+}
+
+spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
+  if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
+          spirv::getTargetEnvAttrName()))
+    return attr;
+  return getDefaultTargetEnv(op->getContext());
+}
index 1892a6de3bd10a0bd57f6eca5c28b1510801d9e1..01c67a2f9a5406c9f49de9f49fe50ac04b61f581 100644 (file)
@@ -217,16 +217,16 @@ void LowerABIAttributesPass::runOnOperation() {
   OwningRewritePatternList patterns;
   patterns.insert<FuncOpLowering>(context, typeConverter);
 
-  ConversionTarget target(*context);
-  target.addLegalDialect<spirv::SPIRVDialect>();
+  std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
+      spirv::lookupTargetEnvOrDefault(module), context);
   auto entryPointAttrName = spirv::getEntryPointABIAttrName();
-  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+  target->addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
     return op.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName) &&
            op.getNumResults() == 0 && op.getNumArguments() == 0;
   });
-  target.addLegalOp<ReturnOp>();
+  target->addLegalOp<ReturnOp>();
   if (failed(
-          applyPartialConversion(module, target, patterns, &typeConverter))) {
+          applyPartialConversion(module, *target, patterns, &typeConverter))) {
     return signalPassFailure();
   }
 
index 0c31a8ee3a65a300b7b8eb3bed60d8deebc445dd..6398ab38877d1fcda12560a05f32aeac230ada40 100644 (file)
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Function.h"
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Printing op availability pass
+//===----------------------------------------------------------------------===//
+
 namespace {
 /// A pass for testing SPIR-V op availability.
-struct TestAvailability : public FunctionPass<TestAvailability> {
+struct PrintOpAvailability : public FunctionPass<PrintOpAvailability> {
   void runOnFunction() override;
 };
 } // end anonymous namespace
 
-void TestAvailability::runOnFunction() {
+void PrintOpAvailability::runOnFunction() {
   auto f = getFunction();
   llvm::outs() << f.getName() << "\n";
 
@@ -70,5 +75,105 @@ void TestAvailability::runOnFunction() {
   });
 }
 
-static PassRegistration<TestAvailability> pass("test-spirv-op-availability",
-                                               "Test SPIR-V op availability");
+static PassRegistration<PrintOpAvailability>
+    printOpAvailabilityPass("test-spirv-op-availability",
+                            "Test SPIR-V op availability");
+
+//===----------------------------------------------------------------------===//
+// Converting target environment pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pass for testing SPIR-V op availability.
+struct ConvertToTargetEnv : public FunctionPass<ConvertToTargetEnv> {
+  void runOnFunction() override;
+};
+
+struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
+  ConvertToAtomCmpExchangeWeak(MLIRContext *context);
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override;
+};
+
+struct ConvertToGroupNonUniformBallot : public RewritePattern {
+  ConvertToGroupNonUniformBallot(MLIRContext *context);
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override;
+};
+
+struct ConvertToSubgroupBallot : public RewritePattern {
+  ConvertToSubgroupBallot(MLIRContext *context);
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override;
+};
+} // end anonymous namespace
+
+void ConvertToTargetEnv::runOnFunction() {
+  MLIRContext *context = &getContext();
+  FuncOp fn = getFunction();
+
+  auto targetEnv = fn.getOperation()
+                       ->getAttr(spirv::getTargetEnvAttrName())
+                       .cast<spirv::TargetEnvAttr>();
+  auto target = spirv::SPIRVConversionTarget::get(targetEnv, context);
+
+  OwningRewritePatternList patterns;
+  patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToGroupNonUniformBallot,
+                  ConvertToSubgroupBallot>(context);
+
+  if (failed(applyPartialConversion(fn, *target, patterns)))
+    return signalPassFailure();
+}
+
+ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
+    : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op",
+                     {"spv.AtomicCompareExchangeWeak"}, 1, context) {}
+
+PatternMatchResult
+ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
+                                              PatternRewriter &rewriter) const {
+  Value ptr = op->getOperand(0);
+  Value value = op->getOperand(1);
+  Value comparator = op->getOperand(2);
+
+  // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in
+  // memory semantics to additionally require AtomicStorage capability.
+  rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
+      op, value.getType(), ptr, spirv::Scope::Workgroup,
+      spirv::MemorySemantics::AcquireRelease |
+          spirv::MemorySemantics::AtomicCounterMemory,
+      spirv::MemorySemantics::Acquire, value, comparator);
+  return matchSuccess();
+}
+
+ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
+    MLIRContext *context)
+    : RewritePattern("test.convert_to_group_non_uniform_ballot_op",
+                     {"spv.GroupNonUniformBallot"}, 1, context) {}
+
+PatternMatchResult ConvertToGroupNonUniformBallot::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  Value predicate = op->getOperand(0);
+
+  rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
+      op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
+  return matchSuccess();
+}
+
+ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
+    : RewritePattern("test.convert_to_subgroup_ballot_op",
+                     {"spv.SubgroupBallotKHR"}, 1, context) {}
+
+PatternMatchResult
+ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
+                                         PatternRewriter &rewriter) const {
+  Value predicate = op->getOperand(0);
+
+  rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
+      op, op->getResult(0).getType(), predicate);
+  return matchSuccess();
+}
+
+static PassRegistration<ConvertToTargetEnv>
+    convertToTargetEnvPass("test-spirv-target-env",
+                           "Test SPIR-V target environment");
index 19bfe9e8f5a36b1d89dc6d3e5f8aabff7b2dbf4e..3966cc0ccea6a8d406b14413ad902a652483069e 100644 (file)
@@ -26,14 +26,14 @@ func @unknown_attr_on_region() -> (i32 {spv.something}) {
 // spv.entry_point_abi
 //===----------------------------------------------------------------------===//
 
-// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one integer elements attribute: 'local_size'}}
+// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one 32-bit integer elements attribute: 'local_size'}}
 func @spv_entry_point() attributes {
   spv.entry_point_abi = 64
 } { return }
 
 // -----
 
-// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one integer elements attribute: 'local_size'}}
+// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one 32-bit integer elements attribute: 'local_size'}}
 func @spv_entry_point() attributes {
   spv.entry_point_abi = {local_size = 64}
 } { return }
@@ -51,14 +51,14 @@ func @spv_entry_point() attributes {
 // spv.interface_var_abi
 //===----------------------------------------------------------------------===//
 
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
 func @interface_var(
   %arg0 : f32 {spv.interface_var_abi = 64}
 ) { return }
 
 // -----
 
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
 func @interface_var(
   %arg0 : f32 {spv.interface_var_abi = {binding = 0: i32}}
 ) { return }
@@ -74,7 +74,7 @@ func @interface_var(
 
 // -----
 
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
 func @interface_var() -> (f32 {spv.interface_var_abi = 64})
 {
   %0 = constant 10.0 : f32
@@ -83,7 +83,7 @@ func @interface_var() -> (f32 {spv.interface_var_abi = 64})
 
 // -----
 
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
 func @interface_var() -> (f32 {spv.interface_var_abi = {binding = 0: i32}})
 {
   %0 = constant 10.0 : f32
@@ -99,3 +99,49 @@ func @interface_var() -> (f32 {spv.interface_var_abi = {
   %0 = constant 10.0 : f32
   return %0: f32
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.target_env
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}}
+func @target_env_wrong_type() attributes {
+  spv.target_env = 64
+} { return }
+
+// -----
+
+// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}}
+func @target_env_missing_fields() attributes {
+  spv.target_env = {version = 0: i32}
+} { return }
+
+// -----
+
+// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}}
+func @target_env_wrong_extension_type() attributes {
+  spv.target_env = {version = 0: i32, extensions = [32: i32], capabilities = [1: i32]}
+} { return }
+
+// -----
+
+// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}}
+func @target_env_wrong_extension() attributes {
+  spv.target_env = {version = 0: i32, extensions = ["SPV_Something"], capabilities = [1: i32]}
+} { return }
+
+// -----
+
+func @target_env() attributes {
+  // CHECK: spv.target_env = {capabilities = [1 : i32], extensions = ["SPV_KHR_storage_buffer_storage_class"], version = 0 : i32}
+  spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_storage_buffer_storage_class"], capabilities = [1: i32]}
+} { return }
+
+// -----
+
+// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}}
+func @target_env_extra_fields() attributes {
+  spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_storage_buffer_storage_class"], capabilities = [1: i32], extra = 32}
+} { return }
diff --git a/mlir/test/Dialect/SPIRV/target-env.mlir b/mlir/test/Dialect/SPIRV/target-env.mlir
new file mode 100644 (file)
index 0000000..92238b4
--- /dev/null
@@ -0,0 +1,120 @@
+// RUN: mlir-opt -disable-pass-threading -test-spirv-target-env %s | FileCheck %s
+
+// Note: The following tests check that a spv.target_env can properly control
+// the conversion target and filter unavailable ops during the conversion.
+// We don't care about the op argument consistency too much; so certain enum
+// values for enum attributes may not make much sense for the test op.
+
+// spv.AtomicCompareExchangeWeak is available from SPIR-V 1.0 to 1.3 under
+// Kernel capability.
+// spv.AtomicCompareExchangeWeak has two memory semantics enum attribute,
+// whose value, if containing AtomicCounterMemory bit, additionally requires
+// AtomicStorage capability.
+
+// spv.GroupNonUniformBallot is available starting from SPIR-V 1.3 under
+// GroupNonUniform capability.
+
+// spv.SubgroupBallotKHR is available under in all SPIR-V versions under
+// SubgroupBallotKHR capability and SPV_KHR_shader_ballot extension.
+
+// Enum case symbol (value) map:
+// Version: 1.0 (0), 1.1 (1), 1.2 (2), 1.3 (3), 1.4 (4)
+// Capability: Kernel (6), AtomicStorage (21), GroupNonUniformBallot (64),
+//             SubgroupBallotKHR (4423)
+
+//===----------------------------------------------------------------------===//
+// MaxVersion
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cmp_exchange_weak_suitable_version_capabilities
+func @cmp_exchange_weak_suitable_version_capabilities(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 attributes {
+  spv.target_env = {version = 1: i32, extensions = [], capabilities = [6: i32, 21: i32]}
+} {
+  // CHECK: spv.AtomicCompareExchangeWeak "Workgroup" "AcquireRelease|AtomicCounterMemory" "Acquire"
+  %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr<i32, Workgroup>, i32, i32) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @cmp_exchange_weak_unsupported_version
+func @cmp_exchange_weak_unsupported_version(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 attributes {
+  spv.target_env = {version = 4: i32, extensions = [], capabilities = [6: i32, 21: i32]}
+} {
+  // CHECK: test.convert_to_atomic_compare_exchange_weak_op
+  %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr<i32, Workgroup>, i32, i32) -> (i32)
+  return %0: i32
+}
+
+//===----------------------------------------------------------------------===//
+// MinVersion
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_ballot_suitable_version
+func @group_non_uniform_ballot_suitable_version(%predicate: i1) -> vector<4xi32> attributes {
+  spv.target_env = {version = 4: i32, extensions = [], capabilities = [64: i32]}
+} {
+  // CHECK: spv.GroupNonUniformBallot "Workgroup"
+  %0 = "test.convert_to_group_non_uniform_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
+  return %0: vector<4xi32>
+}
+
+// CHECK-LABEL: @group_non_uniform_ballot_unsupported_version
+func @group_non_uniform_ballot_unsupported_version(%predicate: i1) -> vector<4xi32> attributes {
+  spv.target_env = {version = 1: i32, extensions = [], capabilities = [64: i32]}
+} {
+  // CHECK: test.convert_to_group_non_uniform_ballot_op
+  %0 = "test.convert_to_group_non_uniform_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
+  return %0: vector<4xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// Capability
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cmp_exchange_weak_missing_capability_kernel
+func @cmp_exchange_weak_missing_capability_kernel(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 attributes {
+  spv.target_env = {version = 3: i32, extensions = [], capabilities = [21: i32]}
+} {
+  // CHECK: test.convert_to_atomic_compare_exchange_weak_op
+  %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr<i32, Workgroup>, i32, i32) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @cmp_exchange_weak_missing_capability_atomic_storage
+func @cmp_exchange_weak_missing_capability_atomic_storage(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 attributes {
+  spv.target_env = {version = 3: i32, extensions = [], capabilities = [6: i32]}
+} {
+  // CHECK: test.convert_to_atomic_compare_exchange_weak_op
+  %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr<i32, Workgroup>, i32, i32) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @subgroup_ballot_missing_capability
+func @subgroup_ballot_missing_capability(%predicate: i1) -> vector<4xi32> attributes {
+  spv.target_env = {version = 4: i32, extensions = ["SPV_KHR_shader_ballot"], capabilities = []}
+} {
+  // CHECK: test.convert_to_subgroup_ballot_op
+  %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
+  return %0: vector<4xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// Extension
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @subgroup_ballot_suitable_extension
+func @subgroup_ballot_suitable_extension(%predicate: i1) -> vector<4xi32> attributes {
+  spv.target_env = {version = 4: i32, extensions = ["SPV_KHR_shader_ballot"], capabilities = [4423: i32]}
+} {
+  // CHECK: spv.SubgroupBallotKHR
+  %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
+  return %0: vector<4xi32>
+}
+
+// CHECK-LABEL: @subgroup_ballot_missing_extension
+func @subgroup_ballot_missing_extension(%predicate: i1) -> vector<4xi32> attributes {
+  spv.target_env = {version = 4: i32, extensions = [], capabilities = [4423: i32]}
+} {
+  // CHECK: test.convert_to_subgroup_ballot_op
+  %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
+  return %0: vector<4xi32>
+}