Fix gen_spirv_dialect.py regarding 1D/2D/3D Dim symbol name
authorLei Zhang <antiagainst@google.com>
Mon, 18 Nov 2019 20:47:54 +0000 (12:47 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 18 Nov 2019 20:48:24 +0000 (12:48 -0800)
PiperOrigin-RevId: 281131561

mlir/utils/spirv/gen_spirv_dialect.py

index 1e1af82..723a409 100755 (executable)
@@ -152,6 +152,15 @@ def gen_operand_kind_enum_attr(operand_kind):
   if 'enumerants' not in operand_kind:
     return '', ''
 
+  # Returns a symbol for the given case in the given kind. This function
+  # handles Dim specially to avoid having numbers as the start of symbols,
+  # which does not play well with C++ and the MLIR parser.
+  def get_case_symbol(kind_name, case_name):
+    if kind_name == 'Dim':
+      if case_name == '1D' or case_name == '2D' or case_name == '3D':
+        return 'Dim{}'.format(case_name)
+    return case_name
+
   kind_name = operand_kind['kind']
   is_bit_enum = operand_kind['category'] == 'BitEnum'
   kind_category = 'Bit' if is_bit_enum else 'I32'
@@ -162,13 +171,14 @@ def gen_operand_kind_enum_attr(operand_kind):
   max_len = max([len(symbol) for (symbol, _) in kind_cases])
 
   # Generate the definition for each enum case
-  fmt_str = 'def SPV_{acronym}_{symbol} {colon:>{offset}} '\
+  fmt_str = 'def SPV_{acronym}_{case} {colon:>{offset}} '\
             '{category}EnumAttrCase<"{symbol}", {value}>;'
   case_defs = [
       fmt_str.format(
           category=kind_category,
           acronym=kind_acronym,
-          symbol=case[0],
+          case=case[0],
+          symbol=get_case_symbol(kind_name, case[0]),
           value=case[1],
           colon=':',
           offset=(max_len + 1 - len(case[0]))) for case in kind_cases