From 8a53050d7d3fd800c211e20fe9d9f83249ca098c Mon Sep 17 00:00:00 2001 From: Alyssa Rosenzweig Date: Fri, 3 Mar 2023 00:29:49 -0500 Subject: [PATCH] agx: Implement extract_[ui]16 Instead of lowering to bitwise ops. Yet another way of subdividing in NIR. Probably insignificant but makes it easy to check that the pass ordering from the previous pass is right. It does let us get much better codegen for unpacksnorm2x16, whatever that's worth. No shader-db changes. Signed-off-by: Alyssa Rosenzweig Part-of: --- src/asahi/compiler/agx_compile.h | 1 - src/asahi/compiler/agx_nir_algebraic.py | 13 +++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/asahi/compiler/agx_compile.h b/src/asahi/compiler/agx_compile.h index 5702006..0ee8f43 100644 --- a/src/asahi/compiler/agx_compile.h +++ b/src/asahi/compiler/agx_compile.h @@ -206,7 +206,6 @@ static const nir_shader_compiler_options agx_nir_options = { .lower_pack_half_2x16 = true, .lower_unpack_half_2x16 = true, .lower_extract_byte = true, - .lower_extract_word = true, .lower_insert_byte = true, .lower_insert_word = true, .lower_cs_local_index_to_id = true, diff --git a/src/asahi/compiler/agx_nir_algebraic.py b/src/asahi/compiler/agx_nir_algebraic.py index cf60f13..deef09d 100644 --- a/src/asahi/compiler/agx_nir_algebraic.py +++ b/src/asahi/compiler/agx_nir_algebraic.py @@ -21,12 +21,21 @@ for s in [8, 16, 32, 64]: lower_sm5_shift += [((shift, f'a@{s}', b), (shift, a, ('iand', b, s - 1)))] -lower_half_pack = [ +lower_pack = [ (('pack_half_2x16_split', a, b), ('pack_32_2x16_split', ('f2f16', a), ('f2f16', b))), (('unpack_half_2x16_split_x', a), ('f2f32', ('unpack_32_2x16_split_x', a))), (('unpack_half_2x16_split_y', a), ('f2f32', ('unpack_32_2x16_split_y', a))), + + (('extract_u16', 'a@32', 0), ('u2u32', ('unpack_32_2x16_split_x', a))), + (('extract_u16', 'a@32', 1), ('u2u32', ('unpack_32_2x16_split_y', a))), + (('extract_i16', 'a@32', 0), ('i2i32', ('unpack_32_2x16_split_x', a))), + (('extract_i16', 'a@32', 1), ('i2i32', ('unpack_32_2x16_split_y', a))), + + # For optimizing extract->convert sequences for unpack/pack norm + (('u2f32', ('u2u32', a)), ('u2f32', a)), + (('i2f32', ('i2i32', a)), ('i2f32', a)), ] def main(): @@ -42,7 +51,7 @@ def run(): print('#include "agx_nir.h"') print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late", - lower_sm5_shift + lower_half_pack).render()) + lower_sm5_shift + lower_pack).render()) if __name__ == '__main__': -- 2.7.4