nir: Add is_null_constant to nir_constant
author: Jesse Natalie <jenatali@microsoft.com>
Thu, 18 May 2023 17:31:50 +0000 (10:31 -0700)
committer: Marge Bot <emma+marge@anholt.net>
Tue, 13 Jun 2023 00:43:36 +0000 (00:43 +0000)
Indicates that the values contained within are 0s, regardless of
type. Enables some optimizations.

Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23173>

src/compiler/nir/nir.h
src/compiler/nir/nir_clone.c
src/compiler/nir/nir_lower_io.c
src/compiler/nir/nir_opt_constant_folding.c
src/compiler/nir/nir_print.c
src/compiler/nir/nir_serialize.c
src/compiler/nir/nir_validate.c

index 0677930..cf7242f 100644 (file)
@@ -395,6 +395,9 @@ typedef struct nir_constant {
     */
    nir_const_value values[NIR_MAX_VEC_COMPONENTS];
 
+   /* Indicates all the values are 0s which can enable some optimizations */
+   bool is_null_constant;
+
    /* we could get this from the var->type but makes clone *much* easier to
     * not have to care about the type.
     */
index 84e1e40..5774cde 100644 (file)
@@ -136,6 +136,7 @@ nir_constant_clone(const nir_constant *c, nir_variable *nvar)
    nir_constant *nc = ralloc(nvar, nir_constant);
 
    memcpy(nc->values, c->values, sizeof(nc->values));
+   nc->is_null_constant = c->is_null_constant;
    nc->num_elements = c->num_elements;
    nc->elements = ralloc_array(nvar, nir_constant *, c->num_elements);
    for (unsigned i = 0; i < c->num_elements; i++) {
index ad83e9e..9e6e021 100644 (file)
@@ -2530,6 +2530,11 @@ static void
 write_constant(void *dst, size_t dst_size,
                const nir_constant *c, const struct glsl_type *type)
 {
+   if (c->is_null_constant) {
+      memset(dst, 0, dst_size);
+      return;
+   }
+
    if (glsl_type_is_vector_or_scalar(type)) {
       const unsigned num_components = glsl_get_vector_elements(type);
       const unsigned bit_size = glsl_get_bit_size(type);
index 4b27bc3..945055f 100644 (file)
@@ -122,6 +122,12 @@ const_value_for_deref(nir_deref_instr *deref)
    if (var->constant_initializer == NULL)
       goto fail;
 
+   if (var->constant_initializer->is_null_constant) {
+      /* Doesn't matter what casts are in the way, it's all zeros */
+      nir_deref_path_finish(&path);
+      return var->constant_initializer->values;
+   }
+
    nir_constant *c = var->constant_initializer;
    nir_const_value *v = NULL; /* Vector value for array-deref-of-vec */
 
index bb9588a..3a7864f 100644 (file)
@@ -714,9 +714,13 @@ print_var_decl(nir_variable *var, print_state *state)
    }
 
    if (var->constant_initializer) {
-      fprintf(fp, " = { ");
-      print_constant(var->constant_initializer, var->type, state);
-      fprintf(fp, " }");
+      if (var->constant_initializer->is_null_constant) {
+         fprintf(fp, " = null");
+      } else {
+         fprintf(fp, " = { ");
+         print_constant(var->constant_initializer, var->type, state);
+         fprintf(fp, " }");
+      }
    }
    if (glsl_type_is_sampler(var->type) && var->data.sampler.is_inline_sampler) {
       fprintf(fp, " = { %s, %s, %s }",
index 862f289..3004f1a 100644 (file)
@@ -187,11 +187,15 @@ read_constant(read_ctx *ctx, nir_variable *nvar)
 {
    nir_constant *c = ralloc(nvar, nir_constant);
 
+   static const nir_const_value zero_vals[ARRAY_SIZE(c->values)] = { 0 };
    blob_copy_bytes(ctx->blob, (uint8_t *)c->values, sizeof(c->values));
+   c->is_null_constant = memcmp(c->values, zero_vals, sizeof(c->values)) == 0;
    c->num_elements = blob_read_uint32(ctx->blob);
    c->elements = ralloc_array(nvar, nir_constant *, c->num_elements);
-   for (unsigned i = 0; i < c->num_elements; i++)
+   for (unsigned i = 0; i < c->num_elements; i++) {
       c->elements[i] = read_constant(ctx, nvar);
+      c->is_null_constant &= c->elements[i]->is_null_constant;
+   }
 
    return c;
 }
index b304916..81f6304 100644 (file)
@@ -963,7 +963,7 @@ validate_call_instr(nir_call_instr *instr, validate_state *state)
 
 static void
 validate_const_value(nir_const_value *val, unsigned bit_size,
-                     validate_state *state)
+                     bool is_null_constant, validate_state *state)
 {
    /* In order for block copies to work properly for things like instruction
     * comparisons and [de]serialization, we require the unused bits of the
@@ -971,24 +971,26 @@ validate_const_value(nir_const_value *val, unsigned bit_size,
     */
    nir_const_value cmp_val;
    memset(&cmp_val, 0, sizeof(cmp_val));
-   switch (bit_size) {
-   case 1:
-      cmp_val.b = val->b;
-      break;
-   case 8:
-      cmp_val.u8 = val->u8;
-      break;
-   case 16:
-      cmp_val.u16 = val->u16;
-      break;
-   case 32:
-      cmp_val.u32 = val->u32;
-      break;
-   case 64:
-      cmp_val.u64 = val->u64;
-      break;
-   default:
-      validate_assert(state, !"Invalid load_const bit size");
+   if (!is_null_constant) {
+      switch (bit_size) {
+      case 1:
+         cmp_val.b = val->b;
+         break;
+      case 8:
+         cmp_val.u8 = val->u8;
+         break;
+      case 16:
+         cmp_val.u16 = val->u16;
+         break;
+      case 32:
+         cmp_val.u32 = val->u32;
+         break;
+      case 64:
+         cmp_val.u64 = val->u64;
+         break;
+      default:
+         validate_assert(state, !"Invalid load_const bit size");
+      }
    }
    validate_assert(state, memcmp(val, &cmp_val, sizeof(cmp_val)) == 0);
 }
@@ -999,7 +1001,7 @@ validate_load_const_instr(nir_load_const_instr *instr, validate_state *state)
    validate_ssa_def(&instr->def, state);
 
    for (unsigned i = 0; i < instr->def.num_components; i++)
-      validate_const_value(&instr->value[i], instr->def.bit_size, state);
+      validate_const_value(&instr->value[i], instr->def.bit_size, false, state);
 }
 
 static void
@@ -1483,7 +1485,7 @@ validate_constant(nir_constant *c, const struct glsl_type *type,
       unsigned num_components = glsl_get_vector_elements(type);
       unsigned bit_size = glsl_get_bit_size(type);
       for (unsigned i = 0; i < num_components; i++)
-         validate_const_value(&c->values[i], bit_size, state);
+         validate_const_value(&c->values[i], bit_size, c->is_null_constant, state);
       for (unsigned i = num_components; i < NIR_MAX_VEC_COMPONENTS; i++)
          validate_assert(state, c->values[i].u64 == 0);
    } else {
@@ -1492,11 +1494,14 @@ validate_constant(nir_constant *c, const struct glsl_type *type,
          for (unsigned i = 0; i < c->num_elements; i++) {
             const struct glsl_type *elem_type = glsl_get_struct_field(type, i);
             validate_constant(c->elements[i], elem_type, state);
+            validate_assert(state, !c->is_null_constant || c->elements[i]->is_null_constant);
          }
       } else if (glsl_type_is_array_or_matrix(type)) {
          const struct glsl_type *elem_type = glsl_get_array_element(type);
-         for (unsigned i = 0; i < c->num_elements; i++)
+         for (unsigned i = 0; i < c->num_elements; i++) {
             validate_constant(c->elements[i], elem_type, state);
+            validate_assert(state, !c->is_null_constant || c->elements[i]->is_null_constant);
+         }
       } else {
          validate_assert(state, !"Invalid type for nir_constant");
       }