This gives our shifts SM5 behaviour at the cost of a little extra ALU. That way,
we match NIR's shifts.
This fixes unsoundness of GLSL expressions like "a << (b & 31)", where the &
would mistakenly get optimized away.
Closes: #8181
Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reported-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21673>
#include "agx_builder.h"
#include "agx_compiler.h"
#include "agx_internal_formats.h"
+#include "agx_nir.h"
/* Alignment for shader programs. I'm not sure what the optimal value is. */
#define AGX_CODE_ALIGN 0x100
NIR_PASS_V(nir, nir_opt_peephole_select, 64, false, true);
NIR_PASS_V(nir, nir_opt_algebraic_late);
+ NIR_PASS_V(nir, agx_nir_lower_algebraic_late);
NIR_PASS_V(nir, nir_opt_constant_folding);
/* Must run after uses are fixed but before a last round of copyprop + DCE */
--- /dev/null
+/*
+ * Copyright 2023 Alyssa Rosenzweig
+ * SPDX-License-Identifier: MIT
+ */
+#ifndef AGX_NIR_H
+#define AGX_NIR_H
+
+#include <stdbool.h>
+
+/* Only a pointer to the shader is used below, so a forward declaration is
+ * enough for this header. */
+struct nir_shader;
+
+/* Late algebraic lowering pass. The definition is generated by
+ * agx_nir_algebraic.py, whose rules rewrite ishl/ishr/ushr so that the shift
+ * count is ANDed with (bit size - 1).
+ * NOTE(review): the bool result presumably reports whether the shader was
+ * changed -- confirm against the nir_algebraic generator. */
+bool agx_nir_lower_algebraic_late(struct nir_shader *shader);
+
+#endif
--- /dev/null
+# Copyright 2022 Alyssa Rosenzweig
+# Copyright 2021 Collabora, Ltd.
+# Copyright 2016 Intel Corporation
+# SPDX-License-Identifier: MIT
+
+import argparse
+import sys
+import math
+
+# Operand names used inside the rule tuples below.
+# NOTE(review): c is not referenced by any rule visible in this file.
+a = 'a'
+b = 'b'
+c = 'c'
+
+# Rule list handed to nir_algebraic.AlgebraicPass in run(). Each entry is a
+# pair of tuples: the first is a shift of 'a@<size>' by b, the second is the
+# same shift of a by (b AND size - 1).
+lower_sm5_shift = []
+
+# Our shifts differ from SM5 for the upper bits. Mask to match the NIR
+# behaviour. Because this happens as a late lowering, NIR won't optimize the
+# masking back out (that happens in the main nir_opt_algebraic).
+# One rule is added per (size, shift opcode) combination: 4 sizes x 3 opcodes.
+for s in [8, 16, 32, 64]:
+ for shift in ["ishl", "ishr", "ushr"]:
+ lower_sm5_shift += [((shift, f'a@{s}', b),
+ (shift, a, ('iand', b, s - 1)))]
+
+# Script entry point: parse the command line, make the directory given by
+# -p/--import-path importable, then emit the generated pass via run().
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-p', '--import-path', required=True)
+ args = parser.parse_args()
+ # Put the import path first on sys.path so that run() can import
+ # nir_algebraic from it.
+ sys.path.insert(0, args.import_path)
+ run()
+
+# Print the generated source to stdout: an include of agx_nir.h followed by
+# the rendered AlgebraicPass named "agx_nir_lower_algebraic_late", built from
+# the lower_sm5_shift rule list.
+def run():
+ # Imported here rather than at module level because main() only adds the
+ # import path to sys.path just before calling run().
+ import nir_algebraic # pylint: disable=import-error
+
+ print('#include "agx_nir.h"')
+
+ print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
+ lower_sm5_shift).render())
+
+
+if __name__ == '__main__':
+ main()
'agx_validate.c',
)
+# Generate agx_nir_algebraic.c by running agx_nir_algebraic.py. The script
+# prints the source to stdout, which is captured into the output file; '-p'
+# passes dir_compiler_nir as the script's --import-path.
+agx_nir_algebraic_c = custom_target(
+ 'agx_nir_algebraic.c',
+ input : 'agx_nir_algebraic.py',
+ output : 'agx_nir_algebraic.c',
+ command : [
+ prog_python, '@INPUT@', '-p', dir_compiler_nir,
+ ],
+ capture : true,
+ depend_files : nir_algebraic_depends,
+)
+
agx_opcodes_h = custom_target(
'agx_opcodes.h',
input : ['agx_opcodes.h.py'],
libasahi_compiler = static_library(
'asahi_compiler',
- [libasahi_agx_files, agx_opcodes_c],
+ [libasahi_agx_files, agx_opcodes_c, agx_nir_algebraic_c],
include_directories : [inc_include, inc_src, inc_mesa, inc_gallium, inc_gallium_aux, inc_mapi],
dependencies: [idep_nir, idep_agx_opcodes_h, idep_agx_builder_h, idep_agx_pack],
c_args : [no_override_init_args],