From bb1e6ff161c9a438243392ca2cbebac781293658 Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Tue, 5 Dec 2017 22:31:02 -0800 Subject: [PATCH] spirv: Add a prepass to set types on vtn_values This autogenerated pass will automatically find and set the type field on all vtn_values. This way we always have the type and can use it for validation and other checks. Reviewed-by: Ian Romanick --- src/compiler/Makefile.nir.am | 4 + src/compiler/Makefile.sources | 3 +- src/compiler/nir/meson.build | 11 ++- src/compiler/spirv/spirv_to_nir.c | 6 +- src/compiler/spirv/vtn_gather_types_c.py | 124 +++++++++++++++++++++++++++++++ src/compiler/spirv/vtn_private.h | 4 + 6 files changed, 149 insertions(+), 3 deletions(-) create mode 100644 src/compiler/spirv/vtn_gather_types_c.py diff --git a/src/compiler/Makefile.nir.am b/src/compiler/Makefile.nir.am index 1533ee5..dd38c45 100644 --- a/src/compiler/Makefile.nir.am +++ b/src/compiler/Makefile.nir.am @@ -56,6 +56,10 @@ spirv/spirv_info.c: spirv/spirv_info_c.py spirv/spirv.core.grammar.json $(MKDIR_GEN) $(PYTHON_GEN) $(srcdir)/spirv/spirv_info_c.py $(srcdir)/spirv/spirv.core.grammar.json $@ || ($(RM) $@; false) +spirv/vtn_gather_types.c: spirv/vtn_gather_types_c.py spirv/spirv.core.grammar.json + $(MKDIR_GEN) + $(PYTHON_GEN) $(srcdir)/spirv/vtn_gather_types_c.py $(srcdir)/spirv/spirv.core.grammar.json $@ || ($(RM) $@; false) + noinst_PROGRAMS += spirv2nir spirv2nir_SOURCES = \ diff --git a/src/compiler/Makefile.sources b/src/compiler/Makefile.sources index 588f96e..d3f746f 100644 --- a/src/compiler/Makefile.sources +++ b/src/compiler/Makefile.sources @@ -290,7 +290,8 @@ NIR_FILES = \ nir/nir_worklist.h SPIRV_GENERATED_FILES = \ - spirv/spirv_info.c + spirv/spirv_info.c \ + spirv/vtn_gather_types.c SPIRV_FILES = \ spirv/GLSL.std.450.h \ diff --git a/src/compiler/nir/meson.build b/src/compiler/nir/meson.build index b61a077..5dd21e6 100644 --- a/src/compiler/nir/meson.build +++ b/src/compiler/nir/meson.build @@ -72,6 +72,14 @@ spirv_info_c = custom_target( command : [prog_python2, '@INPUT0@', '@INPUT1@', '@OUTPUT@'], ) +vtn_gather_types_c = custom_target( + 'vtn_gather_types.c', + input : files('../spirv/vtn_gather_types_c.py', + '../spirv/spirv.core.grammar.json'), + output : 'vtn_gather_types.c', + command : [prog_python2, '@INPUT0@', '@INPUT1@', '@OUTPUT@'], +) + files_libnir = files( 'nir.c', 'nir.h', @@ -189,7 +197,8 @@ files_libnir = files( libnir = static_library( 'nir', [files_libnir, spirv_info_c, nir_opt_algebraic_c, nir_opcodes_c, - nir_opcodes_h, nir_constant_expressions_c, nir_builder_opcodes_h], + nir_opcodes_h, nir_constant_expressions_c, nir_builder_opcodes_h, + vtn_gather_types_c], include_directories : [inc_common, inc_compiler, include_directories('../spirv')], c_args : [c_vis_args, c_msvc_compat_args, no_override_init_args], link_with : libcompiler, diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index a3090cf..7cbf0c9 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -1265,7 +1265,6 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count) { struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_constant); - val->type = vtn_value(b, w[1], vtn_value_type_type)->type; val->constant = rzalloc(b, nir_constant); switch (opcode) { case SpvOpConstantTrue: @@ -3287,6 +3286,8 @@ static bool vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count) { + vtn_set_instruction_result_type(b, opcode, w, count); + switch (opcode) { case SpvOpSource: case SpvOpSourceContinued: @@ -3677,6 +3678,9 @@ spirv_to_nir(const uint32_t *words, size_t word_count, words = vtn_foreach_instruction(b, words, word_end, vtn_handle_variable_or_type_instruction); + /* Set types on all vtn_values */ + vtn_foreach_instruction(b, words, word_end, vtn_set_instruction_result_type); + vtn_build_cfg(b, words, word_end); assert(b->entry_point->value_type == vtn_value_type_function); diff --git a/src/compiler/spirv/vtn_gather_types_c.py b/src/compiler/spirv/vtn_gather_types_c.py new file mode 100644 index 0000000..7b42e95 --- /dev/null +++ b/src/compiler/spirv/vtn_gather_types_c.py @@ -0,0 +1,124 @@ +COPYRIGHT = """\ +/* + * Copyright (C) 2017 Intel Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice (including the next + * paragraph) shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ +""" + +import argparse +import json +from sys import stdout +from mako.template import Template + +def find_result_types(spirv): + for inst in spirv['instructions']: + name = inst['opname'] + + if 'operands' not in inst: + continue + + res_arg_idx = -1 + res_type_arg_idx = -1 + for idx, arg in enumerate(inst['operands']): + if arg['kind'] == 'IdResult': + res_arg_idx = idx + elif arg['kind'] == 'IdResultType': + res_type_arg_idx = idx + + if res_type_arg_idx >= 0: + assert res_arg_idx >= 0 + elif res_arg_idx >= 0: + untyped_insts = [ + 'OpString', + 'OpExtInstImport', + 'OpDecorationGroup', + 'OpLabel', + ] + assert name.startswith('OpType') or name in untyped_insts + + if res_arg_idx >= 0 or res_type_arg_idx >= 0: + yield (name, res_arg_idx, res_type_arg_idx) + +TEMPLATE = Template(COPYRIGHT + """\ + +/* DO NOT EDIT - This file is generated automatically by the + * vtn_gather_types_c.py script + */ + +#include "vtn_private.h" + +struct type_args { + int res_idx; + int res_type_idx; +}; + +static struct type_args +result_type_args_for_opcode(SpvOp opcode) +{ + switch (opcode) { +% for opcode in opcodes: + case Spv${opcode[0]}: return (struct type_args){ ${opcode[1]}, ${opcode[2]} }; +% endfor + default: return (struct type_args){ -1, -1 }; + } +} + +bool +vtn_set_instruction_result_type(struct vtn_builder *b, SpvOp opcode, + const uint32_t *w, unsigned count) +{ + struct type_args args = result_type_args_for_opcode(opcode); + + if (args.res_idx >= 0 && args.res_type_idx >= 0) { + struct vtn_value *val = vtn_untyped_value(b, w[1 + args.res_idx]); + val->type = vtn_value(b, w[1 + args.res_type_idx], + vtn_value_type_type)->type; + } + + return true; +} + +""") + +if __name__ == "__main__": + p = argparse.ArgumentParser() + p.add_argument("json") + p.add_argument("out") + args = p.parse_args() + + spirv_info = json.JSONDecoder().decode(open(args.json, "r").read()) + + opcodes = list(find_result_types(spirv_info)) + + try: + with open(args.out, 'w') as f: + f.write(TEMPLATE.render(opcodes=opcodes)) + except Exception: + # In the even there's an error this imports some helpers from mako + # to print a useful stack trace and prints it, then exits with + # status 1, if python is run with debug; otherwise it just raises + # the exception + if __debug__: + import sys + from mako import exceptions + sys.stderr.write(exceptions.text_error_template().render() + '\n') + sys.exit(1) + raise diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index 6d4ad3c..a0a4f3a 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -620,6 +620,10 @@ vtn_value(struct vtn_builder *b, uint32_t value_id, return val; } +bool +vtn_set_instruction_result_type(struct vtn_builder *b, SpvOp opcode, + const uint32_t *w, unsigned count); + struct vtn_ssa_value *vtn_ssa_value(struct vtn_builder *b, uint32_t value_id); struct vtn_ssa_value *vtn_create_ssa_value(struct vtn_builder *b, -- 2.7.4