agx: Implement extract_[ui]16
authorAlyssa Rosenzweig <alyssa@rosenzweig.io>
Fri, 3 Mar 2023 05:29:49 +0000 (00:29 -0500)
committerMarge Bot <emma+marge@anholt.net>
Sat, 11 Mar 2023 14:15:50 +0000 (14:15 +0000)
Instead of lowering to bitwise ops. Yet another way of subdividing in NIR.
Probably insignificant but makes it easy to check that the pass ordering from the
previous pass is right. It does let us get much better codegen for
unpacksnorm2x16, whatever that's worth.

No shader-db changes.

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21674>

src/asahi/compiler/agx_compile.h
src/asahi/compiler/agx_nir_algebraic.py

index 5702006..0ee8f43 100644 (file)
@@ -206,7 +206,6 @@ static const nir_shader_compiler_options agx_nir_options = {
    .lower_pack_half_2x16 = true,
    .lower_unpack_half_2x16 = true,
    .lower_extract_byte = true,
-   .lower_extract_word = true,
    .lower_insert_byte = true,
    .lower_insert_word = true,
    .lower_cs_local_index_to_id = true,
index cf60f13..deef09d 100644 (file)
@@ -21,12 +21,21 @@ for s in [8, 16, 32, 64]:
         lower_sm5_shift += [((shift, f'a@{s}', b),
                              (shift, a, ('iand', b, s - 1)))]
 
-lower_half_pack = [
+lower_pack = [
     (('pack_half_2x16_split', a, b),
      ('pack_32_2x16_split', ('f2f16', a), ('f2f16', b))),
 
     (('unpack_half_2x16_split_x', a), ('f2f32', ('unpack_32_2x16_split_x', a))),
     (('unpack_half_2x16_split_y', a), ('f2f32', ('unpack_32_2x16_split_y', a))),
+
+    (('extract_u16', 'a@32', 0), ('u2u32', ('unpack_32_2x16_split_x', a))),
+    (('extract_u16', 'a@32', 1), ('u2u32', ('unpack_32_2x16_split_y', a))),
+    (('extract_i16', 'a@32', 0), ('i2i32', ('unpack_32_2x16_split_x', a))),
+    (('extract_i16', 'a@32', 1), ('i2i32', ('unpack_32_2x16_split_y', a))),
+
+    # For optimizing extract->convert sequences for unpack/pack norm
+    (('u2f32', ('u2u32', a)), ('u2f32', a)),
+    (('i2f32', ('i2i32', a)), ('i2f32', a)),
 ]
 
 def main():
@@ -42,7 +51,7 @@ def run():
     print('#include "agx_nir.h"')
 
     print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
-                                      lower_sm5_shift + lower_half_pack).render())
+                                      lower_sm5_shift + lower_pack).render())
 
 
 if __name__ == '__main__':