nir: Pass fully qualified type to nir_const_value_negative_equal
authorIan Romanick <ian.d.romanick@intel.com>
Thu, 13 Jun 2019 19:59:29 +0000 (12:59 -0700)
committerIan Romanick <ian.d.romanick@intel.com>
Mon, 8 Jul 2019 18:30:10 +0000 (11:30 -0700)
Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Suggested-by: Jason Ekstrand <jason@jlekstrand.net>
Reviewed-by: Matt Turner <mattst88@gmail.com>
src/compiler/nir/nir.h
src/compiler/nir/nir_instr_set.c
src/compiler/nir/tests/negative_equal_tests.cpp

index 92dad1c..90e35ee 100644 (file)
@@ -1032,8 +1032,7 @@ nir_ssa_alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
 bool nir_const_value_negative_equal(const nir_const_value *c1,
                                     const nir_const_value *c2,
                                     unsigned components,
-                                    nir_alu_type base_type,
-                                    unsigned bits);
+                                    nir_alu_type full_type);
 
 bool nir_alu_srcs_equal(const nir_alu_instr *alu1, const nir_alu_instr *alu2,
                         unsigned src1, unsigned src2);
index a19e846..5893dc5 100644 (file)
@@ -305,95 +305,73 @@ bool
 nir_const_value_negative_equal(const nir_const_value *c1,
                                const nir_const_value *c2,
                                unsigned components,
-                               nir_alu_type base_type,
-                               unsigned bits)
+                               nir_alu_type full_type)
 {
-   assert(base_type == nir_alu_type_get_base_type(base_type));
-   assert(base_type != nir_type_invalid);
-
-   /* This can occur for 1-bit Boolean values. */
-   if (bits == 1)
-      return false;
-
-   switch (base_type) {
-   case nir_type_float:
-      switch (bits) {
-      case 16:
-         for (unsigned i = 0; i < components; i++) {
-            if (_mesa_half_to_float(c1[i].u16) !=
-                -_mesa_half_to_float(c2[i].u16)) {
-               return false;
-            }
-         }
-
-         return true;
-
-      case 32:
-         for (unsigned i = 0; i < components; i++) {
-            if (c1[i].f32 != -c2[i].f32)
-               return false;
-         }
-
-         return true;
-
-      case 64:
-         for (unsigned i = 0; i < components; i++) {
-            if (c1[i].f64 != -c2[i].f64)
-               return false;
+   assert(nir_alu_type_get_base_type(full_type) != nir_type_invalid);
+   assert(nir_alu_type_get_type_size(full_type) != 0);
+
+   switch (full_type) {
+   case nir_type_float16:
+      for (unsigned i = 0; i < components; i++) {
+         if (_mesa_half_to_float(c1[i].u16) !=
+             -_mesa_half_to_float(c2[i].u16)) {
+            return false;
          }
+      }
 
-         return true;
+      return true;
 
-      default:
-         unreachable("unknown bit size");
+   case nir_type_float32:
+      for (unsigned i = 0; i < components; i++) {
+         if (c1[i].f32 != -c2[i].f32)
+            return false;
       }
 
-      break;
+      return true;
 
-   case nir_type_int:
-   case nir_type_uint:
-      switch (bits) {
-      case 8:
-         for (unsigned i = 0; i < components; i++) {
-            if (c1[i].i8 != -c2[i].i8)
-               return false;
-         }
+   case nir_type_float64:
+      for (unsigned i = 0; i < components; i++) {
+         if (c1[i].f64 != -c2[i].f64)
+            return false;
+      }
 
-         return true;
+      return true;
 
-      case 16:
-         for (unsigned i = 0; i < components; i++) {
-            if (c1[i].i16 != -c2[i].i16)
-               return false;
-         }
+   case nir_type_int8:
+   case nir_type_uint8:
+      for (unsigned i = 0; i < components; i++) {
+         if (c1[i].i8 != -c2[i].i8)
+            return false;
+      }
 
-         return true;
-         break;
+      return true;
 
-      case 32:
-         for (unsigned i = 0; i < components; i++) {
-            if (c1[i].i32 != -c2[i].i32)
-               return false;
-         }
+   case nir_type_int16:
+   case nir_type_uint16:
+      for (unsigned i = 0; i < components; i++) {
+         if (c1[i].i16 != -c2[i].i16)
+            return false;
+      }
 
-         return true;
+      return true;
 
-      case 64:
-         for (unsigned i = 0; i < components; i++) {
-            if (c1[i].i64 != -c2[i].i64)
-               return false;
-         }
+   case nir_type_int32:
+   case nir_type_uint32:
+      for (unsigned i = 0; i < components; i++) {
+         if (c1[i].i32 != -c2[i].i32)
+            return false;
+      }
 
-         return true;
+      return true;
 
-      default:
-         unreachable("unknown bit size");
+   case nir_type_int64:
+   case nir_type_uint64:
+      for (unsigned i = 0; i < components; i++) {
+         if (c1[i].i64 != -c2[i].i64)
+            return false;
       }
 
-      break;
-
-   case nir_type_bool:
-      return false;
+      return true;
 
    default:
       break;
@@ -449,7 +427,7 @@ nir_alu_srcs_negative_equal(const nir_alu_instr *alu1,
       return nir_const_value_negative_equal(const1,
                                             const2,
                                             nir_ssa_alu_instr_src_components(alu1, src1),
-                                            nir_op_infos[alu1->op].input_types[src1],
+                                            nir_op_infos[alu1->op].input_types[src1] |
                                             nir_src_bit_size(alu1->src[src1].src));
    }
 
index 2d3aa6a..86305ce 100644 (file)
 #include "util/half_float.h"
 
 static void count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS],
-                           nir_alu_type base_type, unsigned bits, int first);
+                           nir_alu_type full_type, int first);
 static void negate(nir_const_value dst[NIR_MAX_VEC_COMPONENTS],
                    const nir_const_value src[NIR_MAX_VEC_COMPONENTS],
-                   nir_alu_type base_type, unsigned bits, unsigned components);
+                   nir_alu_type full_type, unsigned components);
 
 class const_value_negative_equal_test : public ::testing::Test {
 protected:
@@ -68,89 +68,89 @@ TEST_F(const_value_negative_equal_test, float32_zero)
 {
    /* Verify that 0.0 negative-equals 0.0. */
    EXPECT_TRUE(nir_const_value_negative_equal(c1, c1, NIR_MAX_VEC_COMPONENTS,
-                                              nir_type_float32));
+                                              nir_type_float32));
 }
 
 TEST_F(const_value_negative_equal_test, float64_zero)
 {
    /* Verify that 0.0 negative-equals 0.0. */
    EXPECT_TRUE(nir_const_value_negative_equal(c1, c1, NIR_MAX_VEC_COMPONENTS,
-                                              nir_type_float64));
+                                              nir_type_float64));
 }
 
 /* Compare an object with non-zero values to itself.  This should always be
  * false.
  */
-#define compare_with_self(base_type, bits) \
-TEST_F(const_value_negative_equal_test, base_type ## bits ## _self)     \
+#define compare_with_self(full_type)                                    \
+TEST_F(const_value_negative_equal_test, full_type ## _self)             \
 {                                                                       \
-   count_sequence(c1, base_type, bits, 1);                              \
+   count_sequence(c1, full_type, 1);                                    \
    EXPECT_FALSE(nir_const_value_negative_equal(c1, c1,                  \
                                                NIR_MAX_VEC_COMPONENTS,  \
-                                               base_type, bits));       \
+                                               full_type));             \
 }
 
-compare_with_self(nir_type_float16)
-compare_with_self(nir_type_float32)
-compare_with_self(nir_type_float64)
-compare_with_self(nir_type_int8)
-compare_with_self(nir_type_uint8)
-compare_with_self(nir_type_int16)
-compare_with_self(nir_type_uint16)
-compare_with_self(nir_type_int32)
-compare_with_self(nir_type_uint32)
-compare_with_self(nir_type_int64)
-compare_with_self(nir_type_uint64)
+compare_with_self(nir_type_float16)
+compare_with_self(nir_type_float32)
+compare_with_self(nir_type_float64)
+compare_with_self(nir_type_int8)
+compare_with_self(nir_type_uint8)
+compare_with_self(nir_type_int16)
+compare_with_self(nir_type_uint16)
+compare_with_self(nir_type_int32)
+compare_with_self(nir_type_uint32)
+compare_with_self(nir_type_int64)
+compare_with_self(nir_type_uint64)
 
 /* Compare an object with the negation of itself.  This should always be true.
  */
-#define compare_with_negation(base_type, bits) \
-TEST_F(const_value_negative_equal_test, base_type ## bits ## _trivially_true) \
+#define compare_with_negation(full_type)                                \
+TEST_F(const_value_negative_equal_test, full_type ## _trivially_true)   \
 {                                                                       \
-   count_sequence(c1, base_type, bits, 1);                              \
-   negate(c2, c1, base_type, bits, NIR_MAX_VEC_COMPONENTS);             \
+   count_sequence(c1, full_type, 1);                                    \
+   negate(c2, c1, full_type, NIR_MAX_VEC_COMPONENTS);                   \
    EXPECT_TRUE(nir_const_value_negative_equal(c1, c2,                   \
                                               NIR_MAX_VEC_COMPONENTS,   \
-                                              base_type, bits));        \
+                                              full_type));              \
 }
 
-compare_with_negation(nir_type_float16)
-compare_with_negation(nir_type_float32)
-compare_with_negation(nir_type_float64)
-compare_with_negation(nir_type_int8)
-compare_with_negation(nir_type_uint8)
-compare_with_negation(nir_type_int16)
-compare_with_negation(nir_type_uint16)
-compare_with_negation(nir_type_int32)
-compare_with_negation(nir_type_uint32)
-compare_with_negation(nir_type_int64)
-compare_with_negation(nir_type_uint64)
+compare_with_negation(nir_type_float16)
+compare_with_negation(nir_type_float32)
+compare_with_negation(nir_type_float64)
+compare_with_negation(nir_type_int8)
+compare_with_negation(nir_type_uint8)
+compare_with_negation(nir_type_int16)
+compare_with_negation(nir_type_uint16)
+compare_with_negation(nir_type_int32)
+compare_with_negation(nir_type_uint32)
+compare_with_negation(nir_type_int64)
+compare_with_negation(nir_type_uint64)
 
 /* Compare fewer than the maximum possible components.  All of the components
  * that are compared a negative-equal, but the extra components are not.
  */
-#define compare_fewer_components(base_type, bits) \
-TEST_F(const_value_negative_equal_test, base_type ## bits ## _fewer_components) \
+#define compare_fewer_components(full_type)                             \
+TEST_F(const_value_negative_equal_test, full_type ## _fewer_components) \
 {                                                                       \
-   count_sequence(c1, base_type, bits, 1);                              \
-   negate(c2, c1, base_type, bits, 3);                                  \
-   EXPECT_TRUE(nir_const_value_negative_equal(c1, c2, 3, base_type, bits)); \
+   count_sequence(c1, full_type, 1);                                    \
+   negate(c2, c1, full_type, 3);                                        \
+   EXPECT_TRUE(nir_const_value_negative_equal(c1, c2, 3, full_type));   \
    EXPECT_FALSE(nir_const_value_negative_equal(c1, c2,                  \
                                                NIR_MAX_VEC_COMPONENTS,  \
-                                               base_type, bits));       \
+                                               full_type));             \
 }
 
-compare_fewer_components(nir_type_float16)
-compare_fewer_components(nir_type_float32)
-compare_fewer_components(nir_type_float64)
-compare_fewer_components(nir_type_int8)
-compare_fewer_components(nir_type_uint8)
-compare_fewer_components(nir_type_int16)
-compare_fewer_components(nir_type_uint16)
-compare_fewer_components(nir_type_int32)
-compare_fewer_components(nir_type_uint32)
-compare_fewer_components(nir_type_int64)
-compare_fewer_components(nir_type_uint64)
+compare_fewer_components(nir_type_float16)
+compare_fewer_components(nir_type_float32)
+compare_fewer_components(nir_type_float64)
+compare_fewer_components(nir_type_int8)
+compare_fewer_components(nir_type_uint8)
+compare_fewer_components(nir_type_int16)
+compare_fewer_components(nir_type_uint16)
+compare_fewer_components(nir_type_int32)
+compare_fewer_components(nir_type_uint32)
+compare_fewer_components(nir_type_int64)
+compare_fewer_components(nir_type_uint64)
 
 TEST_F(alu_srcs_negative_equal_test, trivial_float)
 {
@@ -221,65 +221,53 @@ TEST_F(alu_srcs_negative_equal_test, trivial_negation_int)
 }
 
 static void
-count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS], nir_alu_type base_type, unsigned bits, int first)
+count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS],
+               nir_alu_type full_type, int first)
 {
-   switch (base_type) {
-   case nir_type_float:
-      switch (bits) {
-      case 16:
-         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
-            c[i].u16 = _mesa_float_to_half(float(i + first));
+   switch (full_type) {
+   case nir_type_float16:
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+         c[i].u16 = _mesa_float_to_half(float(i + first));
 
-         break;
-
-      case 32:
-         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
-            c[i].f32 = float(i + first);
-
-         break;
-
-      case 64:
-         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
-            c[i].f64 = double(i + first);
-
-         break;
+      break;
 
-      default:
-         unreachable("unknown bit size");
-      }
+   case nir_type_float32:
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+         c[i].f32 = float(i + first);
 
       break;
 
-   case nir_type_int:
-   case nir_type_uint:
-      switch (bits) {
-      case 8:
-         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
-            c[i].i8 = i + first;
+   case nir_type_float64:
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+         c[i].f64 = double(i + first);
 
-         break;
+      break;
 
-      case 16:
-         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
-            c[i].i16 = i + first;
+   case nir_type_int8:
+   case nir_type_uint8:
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+         c[i].i8 = i + first;
 
-         break;
+      break;
 
-      case 32:
-         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
-            c[i].i32 = i + first;
+   case nir_type_int16:
+   case nir_type_uint16:
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+         c[i].i16 = i + first;
 
-         break;
+      break;
 
-      case 64:
-         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
-            c[i].i64 = i + first;
+   case nir_type_int32:
+   case nir_type_uint32:
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+         c[i].i32 = i + first;
 
-         break;
+      break;
 
-      default:
-         unreachable("unknown bit size");
-      }
+   case nir_type_int64:
+   case nir_type_uint64:
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+         c[i].i64 = i + first;
 
       break;
 
@@ -292,65 +280,52 @@ count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS], nir_alu_type base_type
 static void
 negate(nir_const_value dst[NIR_MAX_VEC_COMPONENTS],
        const nir_const_value src[NIR_MAX_VEC_COMPONENTS],
-       nir_alu_type base_type, unsigned bits, unsigned components)
+       nir_alu_type full_type, unsigned components)
 {
-   switch (base_type) {
-   case nir_type_float:
-      switch (bits) {
-      case 16:
-         for (unsigned i = 0; i < components; i++)
-            dst[i].u16 = _mesa_float_to_half(-_mesa_half_to_float(src[i].u16));
+   switch (full_type) {
+   case nir_type_float16:
+      for (unsigned i = 0; i < components; i++)
+         dst[i].u16 = _mesa_float_to_half(-_mesa_half_to_float(src[i].u16));
 
-         break;
-
-      case 32:
-         for (unsigned i = 0; i < components; i++)
-            dst[i].f32 = -src[i].f32;
-
-         break;
-
-      case 64:
-         for (unsigned i = 0; i < components; i++)
-            dst[i].f64 = -src[i].f64;
-
-         break;
+      break;
 
-      default:
-         unreachable("unknown bit size");
-      }
+   case nir_type_float32:
+      for (unsigned i = 0; i < components; i++)
+         dst[i].f32 = -src[i].f32;
 
       break;
 
-   case nir_type_int:
-   case nir_type_uint:
-      switch (bits) {
-      case 8:
-         for (unsigned i = 0; i < components; i++)
-            dst[i].i8 = -src[i].i8;
+   case nir_type_float64:
+      for (unsigned i = 0; i < components; i++)
+         dst[i].f64 = -src[i].f64;
 
-         break;
+      break;
 
-      case 16:
-         for (unsigned i = 0; i < components; i++)
-            dst[i].i16 = -src[i].i16;
+   case nir_type_int8:
+   case nir_type_uint8:
+      for (unsigned i = 0; i < components; i++)
+         dst[i].i8 = -src[i].i8;
 
-         break;
+      break;
 
-      case 32:
-         for (unsigned i = 0; i < components; i++)
-            dst[i].i32 = -src[i].i32;
+   case nir_type_int16:
+   case nir_type_uint16:
+      for (unsigned i = 0; i < components; i++)
+         dst[i].i16 = -src[i].i16;
 
-         break;
+      break;
 
-      case 64:
-         for (unsigned i = 0; i < components; i++)
-            dst[i].i64 = -src[i].i64;
+   case nir_type_int32:
+   case nir_type_uint32:
+      for (unsigned i = 0; i < components; i++)
+         dst[i].i32 = -src[i].i32;
 
-         break;
+      break;
 
-      default:
-         unreachable("unknown bit size");
-      }
+   case nir_type_int64:
+   case nir_type_uint64:
+      for (unsigned i = 0; i < components; i++)
+         dst[i].i64 = -src[i].i64;
 
       break;