From e99ba0b6d3fb93ba6bef4b3d0d6567cbeda7b367 Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Tue, 14 Feb 2023 21:38:41 +0000 Subject: [PATCH] nir/range_analysis: use perform_analysis() in nir_analyze_range() Signed-off-by: Rhys Perry Reviewed-by: Georg Lehmann Part-of: --- src/compiler/nir/nir_range_analysis.c | 274 +++++++++++++++++++++------------- 1 file changed, 170 insertions(+), 104 deletions(-) diff --git a/src/compiler/nir/nir_range_analysis.c b/src/compiler/nir/nir_range_analysis.c index aecd601..532335c 100644 --- a/src/compiler/nir/nir_range_analysis.c +++ b/src/compiler/nir/nir_range_analysis.c @@ -126,18 +126,15 @@ is_not_zero(enum ssa_ranges r) return r == gt_zero || r == lt_zero || r == ne_zero; } -static void * +static uint32_t pack_data(const struct ssa_result_range r) { - return (void *)(uintptr_t)(r.range | r.is_integral << 8 | r.is_finite << 9 | - r.is_a_number << 10); + return r.range | r.is_integral << 8 | r.is_finite << 9 | r.is_a_number << 10; } static struct ssa_result_range -unpack_data(const void *p) +unpack_data(uint32_t v) { - const uintptr_t v = (uintptr_t) p; - return (struct ssa_result_range){ .range = v & 0xff, .is_integral = (v & 0x00100) != 0, @@ -146,31 +143,6 @@ unpack_data(const void *p) }; } -static void * -pack_key(const struct nir_alu_instr *instr, nir_alu_type type) -{ - uintptr_t type_encoding; - uintptr_t ptr = (uintptr_t) instr; - - /* The low 2 bits have to be zero or this whole scheme falls apart. */ - assert((ptr & 0x3) == 0); - - /* NIR is typeless in the sense that sequences of bits have whatever - * meaning is attached to them by the instruction that consumes them. - * However, the number of bits must match between producer and consumer. - * As a result, the number of bits does not need to be encoded here. - */ - switch (nir_alu_type_get_base_type(type)) { - case nir_type_int: type_encoding = 0; break; - case nir_type_uint: type_encoding = 1; break; - case nir_type_bool: type_encoding = 2; break; - case nir_type_float: type_encoding = 3; break; - default: unreachable("Invalid base type."); - } - - return (void *)(ptr | type_encoding); -} - static nir_alu_type nir_alu_src_type(const nir_alu_instr *instr, unsigned src) { @@ -319,7 +291,7 @@ analyze_constant(const struct nir_alu_instr *instr, unsigned src, } /** - * Short-hand name for use in the tables in analyze_expression. If this name + * Short-hand name for use in the tables in process_fp_query. If this name * becomes a problem on some compiler, we can change it to _. */ #define _______ unknown @@ -502,6 +474,53 @@ union_ranges(enum ssa_ranges a, enum ssa_ranges b) #define ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(t) #endif /* !defined(NDEBUG) */ +struct fp_query { + struct analysis_query head; + const nir_alu_instr *instr; + unsigned src; + nir_alu_type use_type; +}; + +static void +push_fp_query(struct analysis_state *state, const nir_alu_instr *alu, unsigned src, nir_alu_type type) +{ + struct fp_query *pushed_q = push_analysis_query(state, sizeof(struct fp_query)); + pushed_q->instr = alu; + pushed_q->src = src; + pushed_q->use_type = type == nir_type_invalid ? nir_alu_src_type(alu, src) : type; +} + +static uintptr_t +get_fp_key(struct analysis_query *q) +{ + struct fp_query *fp_q = (struct fp_query *)q; + const nir_src *src = &fp_q->instr->src[fp_q->src].src; + + if (!src->is_ssa || src->ssa->parent_instr->type != nir_instr_type_alu) + return 0; + + uintptr_t type_encoding; + uintptr_t ptr = (uintptr_t)nir_instr_as_alu(src->ssa->parent_instr); + + /* The low 2 bits have to be zero or this whole scheme falls apart. */ + assert((ptr & 0x3) == 0); + + /* NIR is typeless in the sense that sequences of bits have whatever + * meaning is attached to them by the instruction that consumes them. + * However, the number of bits must match between producer and consumer. + * As a result, the number of bits does not need to be encoded here. + */ + switch (nir_alu_type_get_base_type(fp_q->use_type)) { + case nir_type_int: type_encoding = 0; break; + case nir_type_uint: type_encoding = 1; break; + case nir_type_bool: type_encoding = 2; break; + case nir_type_float: type_encoding = 3; break; + default: unreachable("Invalid base type."); + } + + return ptr | type_encoding; +} + /** * Analyze an expression to determine the range of its result * @@ -511,21 +530,32 @@ union_ranges(enum ssa_ranges a, enum ssa_ranges b) * This function implements this grammar as a recursive-descent parser. Some * (but not all) of the grammar is listed in-line in the function. */ -static struct ssa_result_range -analyze_expression(const nir_alu_instr *instr, unsigned src, - struct hash_table *ht, nir_alu_type use_type) +static void +process_fp_query(struct analysis_state *state, struct analysis_query *aq, uint32_t *result, + const uint32_t *src_res) { /* Ensure that the _Pragma("GCC unroll 7") above are correct. */ STATIC_ASSERT(last_range + 1 == 7); - if (!instr->src[src].src.is_ssa) - return (struct ssa_result_range){unknown, false, false, false}; + struct fp_query q = *(struct fp_query *)aq; + const nir_alu_instr *instr = q.instr; + unsigned src = q.src; + nir_alu_type use_type = q.use_type; - if (nir_src_is_const(instr->src[src].src)) - return analyze_constant(instr, src, use_type); + if (!instr->src[src].src.is_ssa) { + *result = pack_data((struct ssa_result_range){unknown, false, false, false}); + return; + } - if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu) - return (struct ssa_result_range){unknown, false, false, false}; + if (nir_src_is_const(instr->src[src].src)) { + *result = pack_data(analyze_constant(instr, src, use_type)); + return; + } + + if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu) { + *result = pack_data((struct ssa_result_range){unknown, false, false, false}); + return; + } const struct nir_alu_instr *const alu = nir_instr_as_alu(instr->src[src].src.ssa->parent_instr); @@ -544,13 +574,62 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, if (use_base_type != src_base_type && (use_base_type == nir_type_float || src_base_type == nir_type_float)) { - return (struct ssa_result_range){unknown, false, false, false}; + *result = pack_data((struct ssa_result_range){unknown, false, false, false}); + return; } } - struct hash_entry *he = _mesa_hash_table_search(ht, pack_key(alu, use_type)); - if (he != NULL) - return unpack_data(he->data); + if (!aq->pushed_queries) { + switch (alu->op) { + case nir_op_bcsel: + push_fp_query(state, alu, 1, use_type); + push_fp_query(state, alu, 2, use_type); + return; + case nir_op_mov: + push_fp_query(state, alu, 0, use_type); + return; + case nir_op_i2f32: + case nir_op_u2f32: + case nir_op_fabs: + case nir_op_fexp2: + case nir_op_frcp: + case nir_op_fneg: + case nir_op_fsat: + case nir_op_fsign: + case nir_op_ffloor: + case nir_op_fceil: + case nir_op_ftrunc: + case nir_op_fdot2: + case nir_op_fdot3: + case nir_op_fdot4: + case nir_op_fdot8: + case nir_op_fdot16: + case nir_op_fdot2_replicated: + case nir_op_fdot3_replicated: + case nir_op_fdot4_replicated: + case nir_op_fdot8_replicated: + case nir_op_fdot16_replicated: + push_fp_query(state, alu, 0, nir_type_invalid); + return; + case nir_op_fadd: + case nir_op_fmax: + case nir_op_fmin: + case nir_op_fmul: + case nir_op_fmulz: + case nir_op_fpow: + push_fp_query(state, alu, 0, nir_type_invalid); + push_fp_query(state, alu, 1, nir_type_invalid); + return; + case nir_op_ffma: + case nir_op_flrp: + push_fp_query(state, alu, 0, nir_type_invalid); + push_fp_query(state, alu, 1, nir_type_invalid); + push_fp_query(state, alu, 2, nir_type_invalid); + return; + default: + break; + } + } struct ssa_result_range r = {unknown, false, false, false}; @@ -666,10 +745,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, break; case nir_op_bcsel: { - const struct ssa_result_range left = - analyze_expression(alu, 1, ht, use_type); - const struct ssa_result_range right = - analyze_expression(alu, 2, ht, use_type); + const struct ssa_result_range left = unpack_data(src_res[0]); + const struct ssa_result_range right = unpack_data(src_res[1]); r.is_integral = left.is_integral && right.is_integral; @@ -694,7 +771,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, case nir_op_i2f32: case nir_op_u2f32: - r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + r = unpack_data(src_res[0]); r.is_integral = true; r.is_a_number = true; @@ -706,7 +783,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, break; case nir_op_fabs: - r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + r = unpack_data(src_res[0]); switch (r.range) { case unknown: @@ -728,10 +805,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, break; case nir_op_fadd: { - const struct ssa_result_range left = - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); - const struct ssa_result_range right = - analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); + const struct ssa_result_range left = unpack_data(src_res[0]); + const struct ssa_result_range right = unpack_data(src_res[1]); r.is_integral = left.is_integral && right.is_integral; r.range = fadd_table[left.range][right.range]; @@ -755,7 +830,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, ge_zero, ge_zero, ge_zero, gt_zero, gt_zero, ge_zero, gt_zero }; - r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + r = unpack_data(src_res[0]); ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(table); ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(table); @@ -770,10 +845,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_fmax: { - const struct ssa_result_range left = - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); - const struct ssa_result_range right = - analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); + const struct ssa_result_range left = unpack_data(src_res[0]); + const struct ssa_result_range right = unpack_data(src_res[1]); r.is_integral = left.is_integral && right.is_integral; @@ -856,10 +929,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_fmin: { - const struct ssa_result_range left = - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); - const struct ssa_result_range right = - analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); + const struct ssa_result_range left = unpack_data(src_res[0]); + const struct ssa_result_range right = unpack_data(src_res[1]); r.is_integral = left.is_integral && right.is_integral; @@ -943,10 +1014,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, case nir_op_fmul: case nir_op_fmulz: { - const struct ssa_result_range left = - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); - const struct ssa_result_range right = - analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); + const struct ssa_result_range left = unpack_data(src_res[0]); + const struct ssa_result_range right = unpack_data(src_res[1]); r.is_integral = left.is_integral && right.is_integral; @@ -981,7 +1050,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, case nir_op_frcp: r = (struct ssa_result_range){ - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range, + unpack_data(src_res[0]).range, false, false, /* Various cases can result in NaN, so assume the worst. */ false /* " " " " " " " " " " */ @@ -989,18 +1058,16 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, break; case nir_op_mov: - r = analyze_expression(alu, 0, ht, use_type); + r = unpack_data(src_res[0]); break; case nir_op_fneg: - r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); - + r = unpack_data(src_res[0]); r.range = fneg_table[r.range]; break; case nir_op_fsat: { - const struct ssa_result_range left = - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + const struct ssa_result_range left = unpack_data(src_res[0]); /* fsat(NaN) = 0. */ r.is_a_number = true; @@ -1035,7 +1102,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, case nir_op_fsign: r = (struct ssa_result_range){ - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range, + unpack_data(src_res[0]).range, true, true, /* fsign is -1, 0, or 1, even for NaN, so it must be a number. */ true /* fsign is -1, 0, or 1, even for NaN, so it must be finite. */ @@ -1048,8 +1115,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, break; case nir_op_ffloor: { - const struct ssa_result_range left = - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + const struct ssa_result_range left = unpack_data(src_res[0]); r.is_integral = true; @@ -1070,8 +1136,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_fceil: { - const struct ssa_result_range left = - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + const struct ssa_result_range left = unpack_data(src_res[0]); r.is_integral = true; @@ -1092,8 +1157,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_ftrunc: { - const struct ssa_result_range left = - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + const struct ssa_result_range left = unpack_data(src_res[0]); r.is_integral = true; @@ -1139,8 +1203,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, case nir_op_fdot4_replicated: case nir_op_fdot8_replicated: case nir_op_fdot16_replicated: { - const struct ssa_result_range left = - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); + const struct ssa_result_range left = unpack_data(src_res[0]); /* If the two sources are the same SSA value, then the result is either * NaN or some number >= 0. If one source is the negation of the other, @@ -1211,10 +1274,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, /* eq_zero */ { ge_zero, gt_zero, gt_zero, eq_zero, ge_zero, ge_zero, gt_zero }, }; - const struct ssa_result_range left = - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); - const struct ssa_result_range right = - analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); + const struct ssa_result_range left = unpack_data(src_res[0]); + const struct ssa_result_range right = unpack_data(src_res[1]); ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(table); ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(table); @@ -1230,12 +1291,9 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_ffma: { - const struct ssa_result_range first = - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); - const struct ssa_result_range second = - analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); - const struct ssa_result_range third = - analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2)); + const struct ssa_result_range first = unpack_data(src_res[0]); + const struct ssa_result_range second = unpack_data(src_res[1]); + const struct ssa_result_range third = unpack_data(src_res[2]); r.is_integral = first.is_integral && second.is_integral && third.is_integral; @@ -1261,12 +1319,9 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, } case nir_op_flrp: { - const struct ssa_result_range first = - analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)); - const struct ssa_result_range second = - analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1)); - const struct ssa_result_range third = - analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2)); + const struct ssa_result_range first = unpack_data(src_res[0]); + const struct ssa_result_range second = unpack_data(src_res[1]); + const struct ssa_result_range third = unpack_data(src_res[2]); r.is_integral = first.is_integral && second.is_integral && third.is_integral; @@ -1296,18 +1351,29 @@ analyze_expression(const nir_alu_instr *instr, unsigned src, /* Just like isfinite(), the is_finite flag implies the value is a number. */ assert((int) r.is_finite <= (int) r.is_a_number); - _mesa_hash_table_insert(ht, pack_key(alu, use_type), pack_data(r)); - return r; + *result = pack_data(r); } #undef _______ struct ssa_result_range nir_analyze_range(struct hash_table *range_ht, - const nir_alu_instr *instr, unsigned src) + const nir_alu_instr *alu, unsigned src) { - return analyze_expression(instr, src, range_ht, - nir_alu_src_type(instr, src)); + struct fp_query query_alloc[64]; + uint32_t result_alloc[64]; + + struct analysis_state state; + state.range_ht = range_ht; + util_dynarray_init_from_stack(&state.query_stack, query_alloc, sizeof(query_alloc)); + util_dynarray_init_from_stack(&state.result_stack, result_alloc, sizeof(result_alloc)); + state.query_size = sizeof(struct fp_query); + state.get_key = &get_fp_key; + state.process_query = &process_fp_query; + + push_fp_query(&state, alu, src, nir_type_invalid); + + return unpack_data(perform_analysis(&state)); } static uint32_t bitmask(uint32_t size) { -- 2.7.4