gallivm: Use range analysis to generate better fmin and fmax code
author: Ian Romanick <ian.d.romanick@intel.com>
Wed, 28 Apr 2021 22:41:56 +0000 (15:41 -0700)
committer: Marge Bot <eric+marge@anholt.net>
Tue, 4 May 2021 00:13:34 +0000 (00:13 +0000)
If it is known that one of the sources must be a number, then the (more
efficient) GALLIVM_NAN_RETURN_OTHER_SECOND_NONNAN path can be used.

v2: s/know to be/known to be/.  Noticed by Roland.

Reviewed-by: Roland Scheidegger <sroland@vmware.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10532>

src/gallium/auxiliary/gallivm/lp_bld_nir.c
src/gallium/auxiliary/gallivm/lp_bld_nir.h

index 149cce8..fd252bd 100644 (file)
@@ -35,6 +35,7 @@
 #include "lp_bld_debug.h"
 #include "lp_bld_printf.h"
 #include "nir_deref.h"
+#include "nir_search_helpers.h"
 
 static void visit_cf_list(struct lp_build_nir_context *bld_base,
                           struct exec_list *list);
@@ -516,13 +517,15 @@ do_quantize_to_f16(struct lp_build_nir_context *bld_base,
 }
 
 static LLVMValueRef do_alu_action(struct lp_build_nir_context *bld_base,
-                                  nir_op op, unsigned src_bit_size[NIR_MAX_VEC_COMPONENTS], LLVMValueRef src[NIR_MAX_VEC_COMPONENTS])
+                                  const nir_alu_instr *instr,
+                                  unsigned src_bit_size[NIR_MAX_VEC_COMPONENTS],
+                                  LLVMValueRef src[NIR_MAX_VEC_COMPONENTS])
 {
    struct gallivm_state *gallivm = bld_base->base.gallivm;
    LLVMBuilderRef builder = gallivm->builder;
    LLVMValueRef result;
 
-   switch (op) {
+   switch (instr->op) {
    case nir_op_b2f32:
       result = emit_b2f(bld_base, src[0], 32);
       break;
@@ -679,14 +682,31 @@ static LLVMValueRef do_alu_action(struct lp_build_nir_context *bld_base,
       break;
    case nir_op_fmax:
    case nir_op_fmin: {
-      enum gallivm_nan_behavior minmax_nan = GALLIVM_NAN_RETURN_OTHER;
+      enum gallivm_nan_behavior minmax_nan;
+      int first = 0;
+
+      /* If one of the sources is known to be a number (i.e., not NaN), then
+       * better code can be generated by passing that information along.
+       */
+      if (is_a_number(bld_base->range_ht, instr, 1,
+                      0 /* unused num_components */,
+                      NULL /* unused swizzle */)) {
+         minmax_nan = GALLIVM_NAN_RETURN_OTHER_SECOND_NONNAN;
+      } else if (is_a_number(bld_base->range_ht, instr, 0,
+                             0 /* unused num_components */,
+                             NULL /* unused swizzle */)) {
+         first = 1;
+         minmax_nan = GALLIVM_NAN_RETURN_OTHER_SECOND_NONNAN;
+      } else {
+         minmax_nan = GALLIVM_NAN_RETURN_OTHER;
+      }
 
-      if (op == nir_op_fmin) {
+      if (instr->op == nir_op_fmin) {
          result = lp_build_min_ext(get_flt_bld(bld_base, src_bit_size[0]),
-                                   src[0], src[1], minmax_nan);
+                                   src[first], src[1 - first], minmax_nan);
       } else {
          result = lp_build_max_ext(get_flt_bld(bld_base, src_bit_size[0]),
-                                   src[0], src[1], minmax_nan);
+                                   src[first], src[1 - first], minmax_nan);
       }
       break;
    }
@@ -1019,7 +1039,7 @@ static void visit_alu(struct lp_build_nir_context *bld_base, const nir_alu_instr
                src_chan[i] = src[i];
             src_chan[i] = cast_type(bld_base, src_chan[i], nir_op_infos[instr->op].input_types[i], src_bit_size[i]);
          }
-         result[c] = do_alu_action(bld_base, instr->op, src_bit_size, src_chan);
+         result[c] = do_alu_action(bld_base, instr, src_bit_size, src_chan);
          result[c] = cast_type(bld_base, result[c], nir_op_infos[instr->op].output_type, nir_dest_bit_size(instr->dest.dest));
       }
    }
@@ -2286,6 +2306,7 @@ bool lp_build_nir_llvm(
                                             _mesa_key_pointer_equal);
    bld_base->vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
                                             _mesa_key_pointer_equal);
+   bld_base->range_ht = _mesa_pointer_hash_table_create(NULL);
 
    func = (struct nir_function *)exec_list_get_head(&nir->functions);
 
@@ -2302,6 +2323,7 @@ bool lp_build_nir_llvm(
    free(bld_base->ssa_defs);
    ralloc_free(bld_base->vars);
    ralloc_free(bld_base->regs);
+   ralloc_free(bld_base->range_ht);
    return true;
 }
 
index 1a92bbc..e9da9d9 100644 (file)
@@ -57,6 +57,9 @@ struct lp_build_nir_context
    struct hash_table *regs;
    struct hash_table *vars;
 
+   /** Value range analysis hash table used in code generation. */
+   struct hash_table *range_ht;
+
    nir_shader *shader;
 
    void (*load_ubo)(struct lp_build_nir_context *bld_base,