From 1f475e316cc9991004fbfb8096bd0b10398515f3 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 18 Nov 2019 12:47:54 -0800 Subject: [PATCH] Fix gen_spirv_dialect.py regarding 1D/2D/3D Dim symbol name PiperOrigin-RevId: 281131561 --- mlir/utils/spirv/gen_spirv_dialect.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index 1e1af82..723a409 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -152,6 +152,15 @@ def gen_operand_kind_enum_attr(operand_kind): if 'enumerants' not in operand_kind: return '', '' + # Returns a symbol for the given case in the given kind. This function + # handles Dim specially to avoid having numbers as the start of symbols, + # which does not play well with C++ and the MLIR parser. + def get_case_symbol(kind_name, case_name): + if kind_name == 'Dim': + if case_name == '1D' or case_name == '2D' or case_name == '3D': + return 'Dim{}'.format(case_name) + return case_name + kind_name = operand_kind['kind'] is_bit_enum = operand_kind['category'] == 'BitEnum' kind_category = 'Bit' if is_bit_enum else 'I32' @@ -162,13 +171,14 @@ def gen_operand_kind_enum_attr(operand_kind): max_len = max([len(symbol) for (symbol, _) in kind_cases]) # Generate the definition for each enum case - fmt_str = 'def SPV_{acronym}_{symbol} {colon:>{offset}} '\ + fmt_str = 'def SPV_{acronym}_{case} {colon:>{offset}} '\ '{category}EnumAttrCase<"{symbol}", {value}>;' case_defs = [ fmt_str.format( category=kind_category, acronym=kind_acronym, - symbol=case[0], + case=case[0], + symbol=get_case_symbol(kind_name, case[0]), value=case[1], colon=':', offset=(max_len + 1 - len(case[0]))) for case in kind_cases -- 2.7.4