[spirv] Add values for enum cases and generate the enum utilities
authorLei Zhang <antiagainst@google.com>
Mon, 10 Jun 2019 22:12:04 +0000 (15:12 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 11 Jun 2019 17:13:20 +0000 (10:13 -0700)
PiperOrigin-RevId: 252494957

mlir/include/mlir/SPIRV/SPIRVBase.td
mlir/include/mlir/SPIRV/SPIRVTypes.h
mlir/lib/SPIRV/SPIRVTypes.cpp
mlir/utils/spirv/gen_spirv_dialect.py

index 00e7f35..698c248 100644 (file)
@@ -68,24 +68,30 @@ class SPV_ScalarOrVectorOf<Type type> :
 
 // Begin enum section. Generated from SPIR-V spec; DO NOT MODIFY!
 
-def SPV_AM_Logical                    : EnumAttrCase<"Logical">;
-def SPV_AM_Physical32                 : EnumAttrCase<"Physical32">;
-def SPV_AM_Physical64                 : EnumAttrCase<"Physical64">;
-def SPV_AM_PhysicalStorageBuffer64EXT : EnumAttrCase<"PhysicalStorageBuffer64EXT">;
+def SPV_AM_Logical                    : EnumAttrCase<"Logical", 0>;
+def SPV_AM_Physical32                 : EnumAttrCase<"Physical32", 1>;
+def SPV_AM_Physical64                 : EnumAttrCase<"Physical64", 2>;
+def SPV_AM_PhysicalStorageBuffer64EXT : EnumAttrCase<"PhysicalStorageBuffer64EXT", 5348>;
 def SPV_AddressingModelAttr :
     EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [
       SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64,
       SPV_AM_PhysicalStorageBuffer64EXT
-    ]>;
+    ]> {
+  let cppNamespace = "::mlir::spirv";
+  let underlyingType = "uint32_t";
+}
 
-def SPV_MM_Simple    : EnumAttrCase<"Simple">;
-def SPV_MM_GLSL450   : EnumAttrCase<"GLSL450">;
-def SPV_MM_OpenCL    : EnumAttrCase<"OpenCL">;
-def SPV_MM_VulkanKHR : EnumAttrCase<"VulkanKHR">;
+def SPV_MM_Simple    : EnumAttrCase<"Simple", 0>;
+def SPV_MM_GLSL450   : EnumAttrCase<"GLSL450", 1>;
+def SPV_MM_OpenCL    : EnumAttrCase<"OpenCL", 2>;
+def SPV_MM_VulkanKHR : EnumAttrCase<"VulkanKHR", 3>;
 def SPV_MemoryModelAttr :
     EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [
       SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_VulkanKHR
-    ]>;
+    ]> {
+  let cppNamespace = "::mlir::spirv";
+  let underlyingType = "uint32_t";
+}
 
 // End enum section. Generated from SPIR-V spec; DO NOT MODIFY!
 
index 2e2c819..3753d04 100644 (file)
@@ -24,6 +24,9 @@
 
 #include "mlir/IR/Types.h"
 
+// Pull in all enum type definitions and utility function declarations
+#include "mlir/SPIRV/SPIRVEnums.h.inc"
+
 namespace mlir {
 namespace spirv {
 
index 1e24675..8edcc96 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "mlir/SPIRV/SPIRVTypes.h"
+#include "llvm/ADT/StringSwitch.h"
 
 using namespace mlir;
 using namespace mlir::spirv;
 
+// Pull in all enum utility function definitions
+#include "mlir/SPIRV/SPIRVEnums.cpp.inc"
+
 //===----------------------------------------------------------------------===//
 // ArrayType
 //===----------------------------------------------------------------------===//
index 1ad9312..f815ae3 100755 (executable)
@@ -96,10 +96,15 @@ def gen_operand_kind_enum_attr(operand_kind):
 
   # Generate the definition for each enum case
   fmt_str = 'def SPV_{acronym}_{symbol} {colon:>{offset}} '\
-            'EnumAttrCase<"{symbol}">;'
-  case_defs = [fmt_str.format(acronym=kind_acronym, symbol=case[0],
-                              colon=':', offset=(max_len + 1 - len(case[0])))
-               for case in kind_cases]
+            'EnumAttrCase<"{symbol}", {value}>;'
+  case_defs = [
+      fmt_str.format(
+          acronym=kind_acronym,
+          symbol=case[0],
+          value=case[1],
+          colon=':',
+          offset=(max_len + 1 - len(case[0]))) for case in kind_cases
+  ]
   case_defs = '\n'.join(case_defs)
 
   # Generate the list of enum case names
@@ -115,7 +120,9 @@ def gen_operand_kind_enum_attr(operand_kind):
 
   # Generate the enum attribute definition
   enum_attr = 'def SPV_{name}Attr :\n    '\
-      'EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n    ]>;'.format(
+      'EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n    ]> {{\n'\
+      '  let cppNamespace = "::mlir::spirv";\n'\
+      '  let underlyingType = "uint32_t";\n}}'.format(
           name=kind_name, cases=case_names)
   return kind_name, case_defs + '\n' + enum_attr