[mlir][spirv] Support implied extensions and capabilities
authorLei Zhang <antiagainst@google.com>
Thu, 26 Dec 2019 17:15:26 +0000 (12:15 -0500)
committerLei Zhang <antiagainst@google.com>
Fri, 17 Jan 2020 13:01:57 +0000 (08:01 -0500)
In SPIR-V, when a new version is introduced, it is possible some
existing extensions will be incorporated into it so that it becomes
implicitly declared if targeting the new version. This affects
conversion target specification because we need to take this into
account when allowing what extensions to use.

For a capability, it may also implies some other capabilities,
for example, the `Shader` capability implies `Matrix` the capability.
This should also be taken into consideration when preparing the
conversion target: when we specify an capability is allowed, all
its recursively implied capabilities are also allowed.

This commit adds utility functions to query implied extensions for
a given version and implied capabilities for a given capability
and updated SPIRVConversionTarget to use them.

This commit also fixes a bug in availability spec. When a symbol
(op or enum case) can be enabled by an extension, we should drop
it's minimal version requirement. Being enabled by an extension
naturally means the symbol can be used by *any* SPIR-V version
as long as the extension is supported. The grammar still encodes
the 'version' field for such cases, but it should be interpreted
as a different way: rather than meaning a minimal version
requirement, it says the symbol becomes core at that specific
version.

Differential Revision: https://reviews.llvm.org/D72765

12 files changed:
mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/include/mlir/TableGen/Attribute.h
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/TableGen/Attribute.cpp
mlir/test/Dialect/SPIRV/TestAvailability.cpp
mlir/test/Dialect/SPIRV/availability.mlir
mlir/test/Dialect/SPIRV/target-env.mlir
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
mlir/utils/spirv/gen_spirv_dialect.py

index 8d49b4759e4be62af3e36053c810402e21ff323d..fb1cf9f720892c592266ab9d4bd83339726331d0 100644 (file)
@@ -8,6 +8,7 @@ add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
 set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
 mlir_tablegen(SPIRVEnumAvailability.h.inc -gen-spirv-enum-avail-decls)
 mlir_tablegen(SPIRVEnumAvailability.cpp.inc -gen-spirv-enum-avail-defs)
+mlir_tablegen(SPIRVCapabilityImplication.inc -gen-spirv-capability-implication)
 add_public_tablegen_target(MLIRSPIRVEnumAvailabilityIncGen)
 
 set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
index 17c296ef8817d0f32707dcdd1d0f9b2be946ec60..2a9124bfda750c318dcc7dd5e9dbc5ac321ae0ca 100644 (file)
@@ -367,25 +367,21 @@ def SPV_C_SubgroupVoteKHR                           : I32EnumAttrCase<"SubgroupV
 }
 def SPV_C_StorageBuffer16BitAccess                  : I32EnumAttrCase<"StorageBuffer16BitAccess", 4433> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_16bit_storage]>
   ];
 }
 def SPV_C_StoragePushConstant16                     : I32EnumAttrCase<"StoragePushConstant16", 4435> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_16bit_storage]>
   ];
 }
 def SPV_C_StorageInputOutput16                      : I32EnumAttrCase<"StorageInputOutput16", 4436> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_16bit_storage]>
   ];
 }
 def SPV_C_DeviceGroup                               : I32EnumAttrCase<"DeviceGroup", 4437> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_device_group]>
   ];
 }
@@ -401,43 +397,36 @@ def SPV_C_SampleMaskPostDepthCoverage               : I32EnumAttrCase<"SampleMas
 }
 def SPV_C_StorageBuffer8BitAccess                   : I32EnumAttrCase<"StorageBuffer8BitAccess", 4448> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_5>,
     Extension<[SPV_KHR_8bit_storage]>
   ];
 }
 def SPV_C_StoragePushConstant8                      : I32EnumAttrCase<"StoragePushConstant8", 4450> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_5>,
     Extension<[SPV_KHR_8bit_storage]>
   ];
 }
 def SPV_C_DenormPreserve                            : I32EnumAttrCase<"DenormPreserve", 4464> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_4>,
     Extension<[SPV_KHR_float_controls]>
   ];
 }
 def SPV_C_DenormFlushToZero                         : I32EnumAttrCase<"DenormFlushToZero", 4465> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_4>,
     Extension<[SPV_KHR_float_controls]>
   ];
 }
 def SPV_C_SignedZeroInfNanPreserve                  : I32EnumAttrCase<"SignedZeroInfNanPreserve", 4466> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_4>,
     Extension<[SPV_KHR_float_controls]>
   ];
 }
 def SPV_C_RoundingModeRTE                           : I32EnumAttrCase<"RoundingModeRTE", 4467> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_4>,
     Extension<[SPV_KHR_float_controls]>
   ];
 }
 def SPV_C_RoundingModeRTZ                           : I32EnumAttrCase<"RoundingModeRTZ", 4468> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_4>,
     Extension<[SPV_KHR_float_controls]>
   ];
 }
@@ -595,14 +584,12 @@ def SPV_C_GroupNonUniformQuad                       : I32EnumAttrCase<"GroupNonU
 def SPV_C_StorageUniform16                          : I32EnumAttrCase<"StorageUniform16", 4434> {
   list<I32EnumAttrCase> implies = [SPV_C_StorageBuffer16BitAccess];
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_16bit_storage]>
   ];
 }
 def SPV_C_UniformAndStorageBuffer8BitAccess         : I32EnumAttrCase<"UniformAndStorageBuffer8BitAccess", 4449> {
   list<I32EnumAttrCase> implies = [SPV_C_StorageBuffer8BitAccess];
   list<Availability> availability = [
-    MinVersion<SPV_V_1_5>,
     Extension<[SPV_KHR_8bit_storage]>
   ];
 }
@@ -708,21 +695,18 @@ def SPV_C_PipeStorage                               : I32EnumAttrCase<"PipeStora
 def SPV_C_DrawParameters                            : I32EnumAttrCase<"DrawParameters", 4427> {
   list<I32EnumAttrCase> implies = [SPV_C_Shader];
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_shader_draw_parameters]>
   ];
 }
 def SPV_C_MultiView                                 : I32EnumAttrCase<"MultiView", 4439> {
   list<I32EnumAttrCase> implies = [SPV_C_Shader];
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_multiview]>
   ];
 }
 def SPV_C_VariablePointersStorageBuffer             : I32EnumAttrCase<"VariablePointersStorageBuffer", 4441> {
   list<I32EnumAttrCase> implies = [SPV_C_Shader];
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_variable_pointers]>
   ];
 }
@@ -807,7 +791,6 @@ def SPV_C_RayTracingNV                              : I32EnumAttrCase<"RayTracin
 def SPV_C_PhysicalStorageBufferAddresses            : I32EnumAttrCase<"PhysicalStorageBufferAddresses", 5347> {
   list<I32EnumAttrCase> implies = [SPV_C_Shader];
   list<Availability> availability = [
-    MinVersion<SPV_V_1_5>,
     Extension<[SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer]>
   ];
 }
@@ -874,7 +857,6 @@ def SPV_C_MultiViewport                             : I32EnumAttrCase<"MultiView
 def SPV_C_VariablePointers                          : I32EnumAttrCase<"VariablePointers", 4442> {
   list<I32EnumAttrCase> implies = [SPV_C_VariablePointersStorageBuffer];
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_variable_pointers]>
   ];
 }
@@ -1041,7 +1023,6 @@ def SPV_AM_Physical64              : I32EnumAttrCase<"Physical64", 2> {
 }
 def SPV_AM_PhysicalStorageBuffer64 : I32EnumAttrCase<"PhysicalStorageBuffer64", 5348> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_5>,
     Extension<[SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer]>,
     Capability<[SPV_C_PhysicalStorageBufferAddresses]>
   ];
@@ -1266,35 +1247,30 @@ def SPV_BI_SubgroupLtMask              : I32EnumAttrCase<"SubgroupLtMask", 4420>
 }
 def SPV_BI_BaseVertex                  : I32EnumAttrCase<"BaseVertex", 4424> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_shader_draw_parameters]>,
     Capability<[SPV_C_DrawParameters]>
   ];
 }
 def SPV_BI_BaseInstance                : I32EnumAttrCase<"BaseInstance", 4425> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_shader_draw_parameters]>,
     Capability<[SPV_C_DrawParameters]>
   ];
 }
 def SPV_BI_DrawIndex                   : I32EnumAttrCase<"DrawIndex", 4426> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_shader_draw_parameters, SPV_NV_mesh_shader]>,
     Capability<[SPV_C_DrawParameters, SPV_C_MeshShadingNV]>
   ];
 }
 def SPV_BI_DeviceIndex                 : I32EnumAttrCase<"DeviceIndex", 4438> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_device_group]>,
     Capability<[SPV_C_DeviceGroup]>
   ];
 }
 def SPV_BI_ViewIndex                   : I32EnumAttrCase<"ViewIndex", 4440> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_multiview]>,
     Capability<[SPV_C_MultiView]>
   ];
@@ -1803,13 +1779,11 @@ def SPV_D_MaxByteOffsetId             : I32EnumAttrCase<"MaxByteOffsetId", 47> {
 }
 def SPV_D_NoSignedWrap                : I32EnumAttrCase<"NoSignedWrap", 4469> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_4>,
     Extension<[SPV_KHR_no_integer_wrap_decoration]>
   ];
 }
 def SPV_D_NoUnsignedWrap              : I32EnumAttrCase<"NoUnsignedWrap", 4470> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_4>,
     Extension<[SPV_KHR_no_integer_wrap_decoration]>
   ];
 }
@@ -1873,14 +1847,12 @@ def SPV_D_NonUniform                  : I32EnumAttrCase<"NonUniform", 5300> {
 }
 def SPV_D_RestrictPointer             : I32EnumAttrCase<"RestrictPointer", 5355> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_5>,
     Extension<[SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer]>,
     Capability<[SPV_C_PhysicalStorageBufferAddresses]>
   ];
 }
 def SPV_D_AliasedPointer              : I32EnumAttrCase<"AliasedPointer", 5356> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_5>,
     Extension<[SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer]>,
     Capability<[SPV_C_PhysicalStorageBufferAddresses]>
   ];
@@ -2161,35 +2133,30 @@ def SPV_EM_PostDepthCoverage                : I32EnumAttrCase<"PostDepthCoverage
 }
 def SPV_EM_DenormPreserve                   : I32EnumAttrCase<"DenormPreserve", 4459> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_4>,
     Extension<[SPV_KHR_float_controls]>,
     Capability<[SPV_C_DenormPreserve]>
   ];
 }
 def SPV_EM_DenormFlushToZero                : I32EnumAttrCase<"DenormFlushToZero", 4460> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_4>,
     Extension<[SPV_KHR_float_controls]>,
     Capability<[SPV_C_DenormFlushToZero]>
   ];
 }
 def SPV_EM_SignedZeroInfNanPreserve         : I32EnumAttrCase<"SignedZeroInfNanPreserve", 4461> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_4>,
     Extension<[SPV_KHR_float_controls]>,
     Capability<[SPV_C_SignedZeroInfNanPreserve]>
   ];
 }
 def SPV_EM_RoundingModeRTE                  : I32EnumAttrCase<"RoundingModeRTE", 4462> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_4>,
     Extension<[SPV_KHR_float_controls]>,
     Capability<[SPV_C_RoundingModeRTE]>
   ];
 }
 def SPV_EM_RoundingModeRTZ                  : I32EnumAttrCase<"RoundingModeRTZ", 4463> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_4>,
     Extension<[SPV_KHR_float_controls]>,
     Capability<[SPV_C_RoundingModeRTZ]>
   ];
@@ -2705,7 +2672,6 @@ def SPV_MM_OpenCL  : I32EnumAttrCase<"OpenCL", 2> {
 }
 def SPV_MM_Vulkan  : I32EnumAttrCase<"Vulkan", 3> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_5>,
     Extension<[SPV_KHR_vulkan_memory_model]>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
@@ -2755,7 +2721,6 @@ def SPV_MS_MakeVisible            : BitEnumAttrCase<"MakeVisible", 0x4000> {
 }
 def SPV_MS_Volatile               : BitEnumAttrCase<"Volatile", 0x8000> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_5>,
     Extension<[SPV_KHR_vulkan_memory_model]>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
@@ -2835,7 +2800,6 @@ def SPV_SC_AtomicCounter          : I32EnumAttrCase<"AtomicCounter", 10> {
 def SPV_SC_Image                  : I32EnumAttrCase<"Image", 11>;
 def SPV_SC_StorageBuffer          : I32EnumAttrCase<"StorageBuffer", 12> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_3>,
     Extension<[SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
     Capability<[SPV_C_Shader]>
   ];
@@ -2878,7 +2842,6 @@ def SPV_SC_ShaderRecordBufferNV   : I32EnumAttrCase<"ShaderRecordBufferNV", 5343
 }
 def SPV_SC_PhysicalStorageBuffer  : I32EnumAttrCase<"PhysicalStorageBuffer", 5349> {
   list<Availability> availability = [
-    MinVersion<SPV_V_1_5>,
     Extension<[SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer]>,
     Capability<[SPV_C_PhysicalStorageBufferAddresses]>
   ];
index 19cd9ee0f564090dd5f7b3a3ffa923ab68675eb4..01088697c8790a2aee84c16a3a8f7896c4cc170e 100644 (file)
@@ -17,6 +17,8 @@
 #include "mlir/IR/TypeSupport.h"
 #include "mlir/IR/Types.h"
 
+#include <tuple>
+
 // Forward declare enum classes related to op availability. Their definitions
 // are in the TableGen'erated SPIRVEnums.h.inc and can be referenced by other
 // declarations in SPIRVEnums.h.inc.
@@ -33,10 +35,22 @@ enum class Capability : uint32_t;
 // Pull in all enum type availability query function declarations
 #include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.h.inc"
 
-#include <tuple>
-
 namespace mlir {
 namespace spirv {
+/// Returns the implied extensions for the given version. These extensions are
+/// incorporated into the current version so they are implicitly declared when
+/// targeting the given version.
+ArrayRef<Extension> getImpliedExtensions(Version version);
+
+/// Returns the directly implied capabilities for the given capability. These
+/// capabilities are implicitly declared by the given capability.
+ArrayRef<Capability> getDirectImpliedCapabilities(Capability cap);
+/// Returns the recursively implied capabilities for the given capability. These
+/// capabilities are implicitly declared by the given capability. Compared to
+/// the above function, this function collects implied capabilities recursively:
+/// if an implicitly declared capability implicitly declares a third one, the
+/// third one will also be returned.
+SmallVector<Capability, 0> getRecursiveImpliedCapabilities(Capability cap);
 
 namespace detail {
 struct ArrayTypeStorage;
index 9f1d8b392ae75b4de3780c59a25cdf80c171bbfe..7d7d9fe039c1a58868a5aa87107b363669e89429 100644 (file)
@@ -125,6 +125,7 @@ private:
 // StringAttr and IntegerAttr.
 class EnumAttrCase : public Attribute {
 public:
+  explicit EnumAttrCase(const llvm::Record *record);
   explicit EnumAttrCase(const llvm::DefInit *init);
 
   // Returns true if this EnumAttrCase is backed by a StringAttr.
index adc610349b36f0e7b4bcfe4774d65f8cc2207561..2f36c3370d74520fc2db17bdf2bc3120bec722be 100644 (file)
@@ -247,9 +247,19 @@ spirv::SPIRVConversionTarget::SPIRVConversionTarget(
     givenExtensions.insert(
         *spirv::symbolizeExtension(extAttr.cast<StringAttr>().getValue()));
 
-  for (Attribute capAttr : targetEnv.capabilities())
-    givenCapabilities.insert(
-        static_cast<spirv::Capability>(capAttr.cast<IntegerAttr>().getInt()));
+  // Add extensions implied by the current version.
+  for (spirv::Extension ext : spirv::getImpliedExtensions(givenVersion))
+    givenExtensions.insert(ext);
+
+  for (Attribute capAttr : targetEnv.capabilities()) {
+    auto cap =
+        static_cast<spirv::Capability>(capAttr.cast<IntegerAttr>().getInt());
+    givenCapabilities.insert(cap);
+
+    // Add capabilities implied by the current capability.
+    for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
+      givenCapabilities.insert(c);
+  }
 }
 
 bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
index c01f7161c40ec6f6d30f7ec5c58a64d72c794363..3e1e306374d42798cdd1c55b0b192788a37c169e 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 
@@ -25,6 +26,73 @@ using namespace mlir::spirv;
 // Pull in all enum type availability query function definitions
 #include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// Availability relationship
+//===----------------------------------------------------------------------===//
+
+ArrayRef<Extension> spirv::getImpliedExtensions(Version version) {
+  // Note: the following lists are from "Appendix A: Changes" of the spec.
+
+#define V_1_3_IMPLIED_EXTS                                                     \
+  Extension::SPV_KHR_shader_draw_parameters, Extension::SPV_KHR_16bit_storage, \
+      Extension::SPV_KHR_device_group, Extension::SPV_KHR_multiview,           \
+      Extension::SPV_KHR_storage_buffer_storage_class,                         \
+      Extension::SPV_KHR_variable_pointers
+
+#define V_1_4_IMPLIED_EXTS                                                     \
+  Extension::SPV_KHR_no_integer_wrap_decoration,                               \
+      Extension::SPV_GOOGLE_decorate_string,                                   \
+      Extension::SPV_GOOGLE_hlsl_functionality1,                               \
+      Extension::SPV_KHR_float_controls
+
+#define V_1_5_IMPLIED_EXTS                                                     \
+  Extension::SPV_KHR_8bit_storage, Extension::SPV_EXT_descriptor_indexing,     \
+      Extension::SPV_EXT_shader_viewport_index_layer,                          \
+      Extension::SPV_EXT_physical_storage_buffer,                              \
+      Extension::SPV_KHR_physical_storage_buffer,                              \
+      Extension::SPV_KHR_vulkan_memory_model
+
+  switch (version) {
+  default:
+    return {};
+  case Version::V_1_3: {
+    static Extension exts[] = {V_1_3_IMPLIED_EXTS};
+    return exts;
+  }
+  case Version::V_1_4: {
+    static Extension exts[] = {V_1_3_IMPLIED_EXTS, V_1_4_IMPLIED_EXTS};
+    return exts;
+  }
+  case Version::V_1_5: {
+    static Extension exts[] = {V_1_3_IMPLIED_EXTS, V_1_4_IMPLIED_EXTS,
+                               V_1_5_IMPLIED_EXTS};
+    return exts;
+  }
+  }
+
+#undef V_1_5_IMPLIED_EXTS
+#undef V_1_4_IMPLIED_EXTS
+#undef V_1_3_IMPLIED_EXTS
+}
+
+// Pull in utility function definition for implied capabilities
+#include "mlir/Dialect/SPIRV/SPIRVCapabilityImplication.inc"
+
+SmallVector<Capability, 0>
+spirv::getRecursiveImpliedCapabilities(Capability cap) {
+  ArrayRef<Capability> directCaps = getDirectImpliedCapabilities(cap);
+  llvm::SetVector<Capability, SmallVector<Capability, 0>> allCaps(
+      directCaps.begin(), directCaps.end());
+
+  // TODO(antiagainst): This is insufficient; find a better way to handle this
+  // (e.g., using static lists) if this turns out to be a bottleneck.
+  for (unsigned i = 0; i < allCaps.size(); ++i)
+    for (Capability c : getDirectImpliedCapabilities(allCaps[i]))
+      allCaps.insert(c);
+
+  return allCaps.takeVector();
+}
+
 //===----------------------------------------------------------------------===//
 // ArrayType
 //===----------------------------------------------------------------------===//
index 958e1620a450c864b967abe63e95908e295e7bd2..659f7d4c5beac2763047e713d9257f07ae8068be 100644 (file)
@@ -137,12 +137,15 @@ StringRef tblgen::ConstantAttr::getConstantValue() const {
   return def->getValueAsString("value");
 }
 
-tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
-    : Attribute(init) {
+tblgen::EnumAttrCase::EnumAttrCase(const llvm::Record *record)
+    : Attribute(record) {
   assert(isSubClassOf("EnumAttrCaseInfo") &&
          "must be subclass of TableGen 'EnumAttrInfo' class");
 }
 
+tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
+    : EnumAttrCase(init->getDef()) {}
+
 bool tblgen::EnumAttrCase::isStrCase() const {
   return isSubClassOf("StrEnumAttrCase");
 }
index 6398ab38877d1fcda12560a05f32aeac230ada40..373397bdc2cd936e181bb002f3c1f7d3094e6dc5 100644 (file)
@@ -95,12 +95,24 @@ struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
                                      PatternRewriter &rewriter) const override;
 };
 
+struct ConvertToBitReverse : public RewritePattern {
+  ConvertToBitReverse(MLIRContext *context);
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override;
+};
+
 struct ConvertToGroupNonUniformBallot : public RewritePattern {
   ConvertToGroupNonUniformBallot(MLIRContext *context);
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override;
 };
 
+struct ConvertToModule : public RewritePattern {
+  ConvertToModule(MLIRContext *context);
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override;
+};
+
 struct ConvertToSubgroupBallot : public RewritePattern {
   ConvertToSubgroupBallot(MLIRContext *context);
   PatternMatchResult matchAndRewrite(Operation *op,
@@ -118,7 +130,8 @@ void ConvertToTargetEnv::runOnFunction() {
   auto target = spirv::SPIRVConversionTarget::get(targetEnv, context);
 
   OwningRewritePatternList patterns;
-  patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToGroupNonUniformBallot,
+  patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
+                  ConvertToGroupNonUniformBallot, ConvertToModule,
                   ConvertToSubgroupBallot>(context);
 
   if (failed(applyPartialConversion(fn, *target, patterns)))
@@ -146,6 +159,20 @@ ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
   return matchSuccess();
 }
 
+ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
+    : RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1,
+                     context) {}
+
+PatternMatchResult
+ConvertToBitReverse::matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const {
+  Value predicate = op->getOperand(0);
+
+  rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
+      op, op->getResult(0).getType(), predicate);
+  return matchSuccess();
+}
+
 ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
     MLIRContext *context)
     : RewritePattern("test.convert_to_group_non_uniform_ballot_op",
@@ -160,6 +187,18 @@ PatternMatchResult ConvertToGroupNonUniformBallot::matchAndRewrite(
   return matchSuccess();
 }
 
+ConvertToModule::ConvertToModule(MLIRContext *context)
+    : RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {}
+
+PatternMatchResult
+ConvertToModule::matchAndRewrite(Operation *op,
+                                 PatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
+      op, spirv::AddressingModel::PhysicalStorageBuffer64,
+      spirv::MemoryModel::Vulkan);
+  return matchSuccess();
+}
+
 ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
     : RewritePattern("test.convert_to_subgroup_ballot_op",
                      {"spv.SubgroupBallotKHR"}, 1, context) {}
index df440cfaaacdda1fc8e1d848b5c74530a1146930..381754c74609e2ef55a5af9594286c7a1216676b 100644 (file)
@@ -42,7 +42,7 @@ func @module_logical_glsl450() {
 
 // CHECK-LABEL: module_physical_storage_buffer64_vulkan
 func @module_physical_storage_buffer64_vulkan() {
-  // CHECK: spv.module min version: V_1_5
+  // CHECK: spv.module min version: V_1_0
   // CHECK: spv.module max version: V_1_5
   // CHECK: spv.module extensions: [ [SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer] [SPV_KHR_vulkan_memory_model] ]
   // CHECK: spv.module capabilities: [ [PhysicalStorageBufferAddresses] [VulkanMemoryModel] ]
index 92238b4e6744b9a099be4463f4ded6cefb2d7f66..3c6ca1807780b12691328b21580b9fd31e1b0bb3 100644 (file)
 // whose value, if containing AtomicCounterMemory bit, additionally requires
 // AtomicStorage capability.
 
+// spv.BitReverse is available in all SPIR-V versiosn under Shader capability.
+
 // spv.GroupNonUniformBallot is available starting from SPIR-V 1.3 under
 // GroupNonUniform capability.
 
 // spv.SubgroupBallotKHR is available under in all SPIR-V versions under
 // SubgroupBallotKHR capability and SPV_KHR_shader_ballot extension.
 
+// The GeometryPointSize capability implies the Geometry capability, which
+// implies the Shader capability.
+
+// PhysicalStorageBuffer64 addressing model is available via extension
+// SPV_EXT_physical_storage_buffer or SPV_KHR_physical_storage_buffer;
+// both extensions are incorporated into SPIR-V 1.5.
+
+// Vulkan memory model is available via extension SPV_KHR_vulkan_memory_model,
+// which extensions are incorporated into SPIR-V 1.5.
+
 // Enum case symbol (value) map:
-// Version: 1.0 (0), 1.1 (1), 1.2 (2), 1.3 (3), 1.4 (4)
-// Capability: Kernel (6), AtomicStorage (21), GroupNonUniformBallot (64),
-//             SubgroupBallotKHR (4423)
+// Version: 1.0 (0), 1.1 (1), 1.2 (2), 1.3 (3), 1.4 (4), 1.5 (5)
+// Capability: Shader (1), Geometry (2), Kernel (6), AtomicStorage (21),
+//             GeometryPointSize (24), GroupNonUniformBallot (64),
+//             SubgroupBallotKHR (4423), VulkanMemoryModel (5345),
+//             PhysicalStorageBufferAddresses (5347)
 
 //===----------------------------------------------------------------------===//
 // MaxVersion
@@ -97,6 +111,24 @@ func @subgroup_ballot_missing_capability(%predicate: i1) -> vector<4xi32> attrib
   return %0: vector<4xi32>
 }
 
+// CHECK-LABEL: @bit_reverse_directly_implied_capability
+func @bit_reverse_directly_implied_capability(%operand: i32) -> i32 attributes {
+  spv.target_env = {version = 0: i32, extensions = [], capabilities = [2: i32]}
+} {
+  // CHECK: spv.BitReverse
+  %0 = "test.convert_to_bit_reverse_op"(%operand): (i32) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @bit_reverse_recursively_implied_capability
+func @bit_reverse_recursively_implied_capability(%operand: i32) -> i32 attributes {
+  spv.target_env = {version = 0: i32, extensions = [], capabilities = [24: i32]}
+} {
+  // CHECK: spv.BitReverse
+  %0 = "test.convert_to_bit_reverse_op"(%operand): (i32) -> (i32)
+  return %0: i32
+}
+
 //===----------------------------------------------------------------------===//
 // Extension
 //===----------------------------------------------------------------------===//
@@ -118,3 +150,49 @@ func @subgroup_ballot_missing_extension(%predicate: i1) -> vector<4xi32> attribu
   %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
   return %0: vector<4xi32>
 }
+
+// CHECK-LABEL: @module_suitable_extension1
+func @module_suitable_extension1() attributes {
+  spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_vulkan_memory_model", "SPV_EXT_physical_storage_buffer"], capabilities = [5345: i32, 5347: i32]}
+} {
+  // CHECK: spv.module "PhysicalStorageBuffer64" "Vulkan"
+  "test.convert_to_module_op"() : () ->()
+  return
+}
+
+// CHECK-LABEL: @module_suitable_extension2
+func @module_suitable_extension2() attributes {
+  spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_vulkan_memory_model", "SPV_KHR_physical_storage_buffer"], capabilities = [5345: i32, 5347: i32]}
+} {
+  // CHECK: spv.module "PhysicalStorageBuffer64" "Vulkan"
+  "test.convert_to_module_op"() : () -> ()
+  return
+}
+
+// CHECK-LABEL: @module_missing_extension_mm
+func @module_missing_extension_mm() attributes {
+  spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_physical_storage_buffer"], capabilities = [5345: i32, 5347: i32]}
+} {
+  // CHECK: test.convert_to_module_op
+  "test.convert_to_module_op"() : () -> ()
+  return
+}
+
+// CHECK-LABEL: @module_missing_extension_am
+func @module_missing_extension_am() attributes {
+  spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_vulkan_memory_model"], capabilities = [5345: i32, 5347: i32]}
+} {
+  // CHECK: test.convert_to_module_op
+  "test.convert_to_module_op"() : () -> ()
+  return
+}
+
+// CHECK-LABEL: @module_implied_extension
+func @module_implied_extension() attributes {
+  // Version 1.5 implies SPV_KHR_vulkan_memory_model and SPV_KHR_physical_storage_buffer.
+  spv.target_env = {version = 5: i32, extensions = [], capabilities = [5345: i32, 5347: i32]}
+} {
+  // CHECK: spv.module "PhysicalStorageBuffer64" "Vulkan"
+  "test.convert_to_module_op"() : () -> ()
+  return
+}
index caede03d8229b527a52f59f4675746ae1564b693..7313c9ce34a7a4532e9f40f57b167f3ae635a5c8 100644 (file)
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Support/STLExtras.h"
 #include "mlir/Support/StringExtras.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Format.h"
@@ -1283,3 +1284,48 @@ static mlir::GenRegistration
                           [](const RecordKeeper &records, raw_ostream &os) {
                             return emitAvailabilityImpl(records, os);
                           });
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Capability Implication AutoGen
+//===----------------------------------------------------------------------===//
+
+static bool emitCapabilityImplication(const RecordKeeper &recordKeeper,
+                                      raw_ostream &os) {
+  llvm::emitSourceFileHeader("SPIR-V Capability Implication", os);
+
+  EnumAttr enumAttr(recordKeeper.getDef("SPV_CapabilityAttr"));
+
+  os << "ArrayRef<Capability> "
+        "spirv::getDirectImpliedCapabilities(Capability cap) {\n"
+     << "  switch (cap) {\n"
+     << "  default: return {};\n";
+  for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) {
+    const Record &def = enumerant.getDef();
+    if (!def.getValue("implies"))
+      continue;
+
+    os << "  case Capability::" << enumerant.getSymbol()
+       << ": {static Capability implies[] = {";
+    std::vector<Record *> impliedCapsDefs = def.getValueAsListOfDefs("implies");
+    mlir::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) {
+      os << "Capability::" << EnumAttrCase(capDef).getSymbol();
+    });
+    os << "}; return implies; }\n";
+  }
+  os << "  }\n";
+  os << "}\n";
+
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Capability Implication Hook Registration
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration
+    genCapabilityImplication("gen-spirv-capability-implication",
+                             "Generate utilty function to return implied "
+                             "capabilities for a given capability",
+                             [](const RecordKeeper &records, raw_ostream &os) {
+                               return emitCapabilityImplication(records, os);
+                             });
index ef9be45798ba174e00c8118200bc48ed2b6cdc2e..096aa37910031a240e69cdb7eb773a808a74729b 100755 (executable)
@@ -266,6 +266,13 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
   exts = enum_case.get('extensions', [])
   if exts:
     exts = 'Extension<[{}]>'.format(', '.join(sorted(set(exts))))
+    # We need to strip the minimal version requirement if this symbol is
+    # available via an extension, which means *any* SPIR-V version can support
+    # it as long as the extension is provided. The grammar's 'version' field
+    # under such case should be interpreted as this symbol is introduced as
+    # a core symbol since the given version, rather than a minimal version
+    # requirement.
+    min_version = 'MinVersion<SPV_V_1_0>' if for_op else ''
   # TODO(antiagainst): delete this once ODS can support dialect-specific content
   # and we can use omission to mean no requirements.
   if for_op and not exts: