.lower_pack_half_2x16 = true,
.lower_unpack_half_2x16 = true,
.lower_extract_byte = true,
- .lower_extract_word = true,
.lower_insert_byte = true,
.lower_insert_word = true,
.lower_cs_local_index_to_id = true,
lower_sm5_shift += [((shift, f'a@{s}', b),
(shift, a, ('iand', b, s - 1)))]
-lower_half_pack = [
+lower_pack = [
(('pack_half_2x16_split', a, b),
('pack_32_2x16_split', ('f2f16', a), ('f2f16', b))),
(('unpack_half_2x16_split_x', a), ('f2f32', ('unpack_32_2x16_split_x', a))),
(('unpack_half_2x16_split_y', a), ('f2f32', ('unpack_32_2x16_split_y', a))),
+
+ (('extract_u16', 'a@32', 0), ('u2u32', ('unpack_32_2x16_split_x', a))),
+ (('extract_u16', 'a@32', 1), ('u2u32', ('unpack_32_2x16_split_y', a))),
+ (('extract_i16', 'a@32', 0), ('i2i32', ('unpack_32_2x16_split_x', a))),
+ (('extract_i16', 'a@32', 1), ('i2i32', ('unpack_32_2x16_split_y', a))),
+
+ # For optimizing extract->convert sequences for unpack/pack norm
+ (('u2f32', ('u2u32', a)), ('u2f32', a)),
+ (('i2f32', ('i2i32', a)), ('i2f32', a)),
]
def main():
print('#include "agx_nir.h"')
print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
- lower_sm5_shift + lower_half_pack).render())
+ lower_sm5_shift + lower_pack).render())
if __name__ == '__main__':