libgccjit: Fix float vector comparison
authorAntoni Boucher <bouanto@zoho.com>
Sun, 20 Nov 2022 15:22:53 +0000 (10:22 -0500)
committerAntoni Boucher <bouanto@zoho.com>
Wed, 7 Dec 2022 00:40:17 +0000 (19:40 -0500)
Fix float vector comparison and add comparison tests to include float and
vectors.

gcc/testsuite:
PR jit/107770
* jit.dg/harness.h: Add new macro to to perform vector
comparisons
* jit.dg/test-expressions.c: Extend comparison tests to add float
types and vectors

gcc/jit:
PR jit/107770
* jit-playback.cc: Fix vector float comparison
* jit-playback.h: Update comparison function signature
* jit-recording.cc: Update call for "new_comparison" function
* jit-recording.h: Fix vector float comparison

Co-authored-by: Guillaume Gomez <guillaume1.gomez@gmail.com>
Signed-off-by: Guillaume Gomez <guillaume1.gomez@gmail.com>
gcc/jit/jit-playback.cc
gcc/jit/jit-playback.h
gcc/jit/jit-recording.cc
gcc/jit/jit-recording.h
gcc/testsuite/jit.dg/harness.h
gcc/testsuite/jit.dg/test-expressions.c

index 069ed70..96e9227 100644 (file)
@@ -1213,7 +1213,7 @@ playback::rvalue *
 playback::context::
 new_comparison (location *loc,
                enum gcc_jit_comparison op,
-               rvalue *a, rvalue *b)
+               rvalue *a, rvalue *b, type *vec_result_type)
 {
   // FIXME: type-checking, or coercion?
   enum tree_code inner_op;
@@ -1252,10 +1252,27 @@ new_comparison (location *loc,
   tree node_b = b->as_tree ();
   node_b = fold_const_var (node_b);
 
-  tree inner_expr = build2 (inner_op,
-                           boolean_type_node,
-                           node_a,
-                           node_b);
+  tree inner_expr;
+  tree a_type = TREE_TYPE (node_a);
+  if (VECTOR_TYPE_P (a_type))
+  {
+    /* Build a vector comparison.  See build_vec_cmp in c-typeck.cc for
+       reference.  */
+    tree t_vec_result_type = vec_result_type->as_tree ();
+    tree zero_vec = build_zero_cst (t_vec_result_type);
+    tree minus_one_vec = build_minus_one_cst (t_vec_result_type);
+    tree cmp_type = truth_type_for (a_type);
+    tree cmp = build2 (inner_op, cmp_type, node_a, node_b);
+    inner_expr = build3 (VEC_COND_EXPR, t_vec_result_type, cmp, minus_one_vec,
+                        zero_vec);
+  }
+  else
+  {
+    inner_expr = build2 (inner_op,
+                        boolean_type_node,
+                        node_a,
+                        node_b);
+  }
 
   /* Try to fold.  */
   inner_expr = fold (inner_expr);
index 1aeee2c..214f399 100644 (file)
@@ -162,7 +162,7 @@ public:
   rvalue *
   new_comparison (location *loc,
                  enum gcc_jit_comparison op,
-                 rvalue *a, rvalue *b);
+                 rvalue *a, rvalue *b, type *vec_result_type);
 
   rvalue *
   new_call (location *loc,
index 6ae5a66..2ce2722 100644 (file)
@@ -5836,7 +5836,8 @@ recording::comparison::replay_into (replayer *r)
   set_playback_obj (r->new_comparison (playback_location (r, m_loc),
                                       m_op,
                                       m_a->playback_rvalue (),
-                                      m_b->playback_rvalue ()));
+                                      m_b->playback_rvalue (),
+                                      m_type->playback_type ()));
 }
 
 /* Implementation of pure virtual hook recording::rvalue::visit_children
index 8610ea9..5d7c717 100644 (file)
@@ -1683,7 +1683,23 @@ public:
     m_op (op),
     m_a (a),
     m_b (b)
-  {}
+  {
+    type *a_type = a->get_type ();
+    vector_type *vec_type = a_type->dyn_cast_vector_type ();
+    if (vec_type != NULL)
+    {
+      type *element_type = vec_type->get_element_type ();
+      type *inner_type;
+      /* Vectors of floating-point values return a vector of integers of the
+         same size.  */
+      if (element_type->is_float ())
+       inner_type = ctxt->get_int_type (element_type->get_size (), false);
+      else
+       inner_type = element_type;
+      m_type = new vector_type (inner_type, vec_type->get_num_units ());
+      ctxt->record (m_type);
+    }
+  }
 
   void replay_into (replayer *r) final override;
 
index 7b70ce7..e423abe 100644 (file)
@@ -68,6 +68,21 @@ static char test[1024];
     }                                        \
   } while (0)
 
+#define CHECK_VECTOR_VALUE(LEN, ACTUAL, EXPECTED) \
+  do {                                       \
+    for (int __check_vector_it = 0; __check_vector_it < LEN; ++__check_vector_it) { \
+      if ((ACTUAL)[__check_vector_it] != (EXPECTED)[__check_vector_it]) { \
+          fail ("%s: %s: actual: %s != expected: %s (position %d)", \
+              test, __func__, #ACTUAL, #EXPECTED, __check_vector_it);  \
+        fprintf (stderr, "incorrect value\n"); \
+        abort ();                              \
+      } \
+    } \
+  pass ("%s: %s: actual: %s == expected: %s", \
+        test, __func__, #ACTUAL, #EXPECTED);  \
+  } while (0)
+
+
 #define CHECK_DOUBLE_VALUE(ACTUAL, EXPECTED) \
   do {                                       \
     double expected = (EXPECTED);           \
index f9cc64f..13b3baf 100644 (file)
@@ -383,15 +383,7 @@ make_test_of_comparison (gcc_jit_context *ctxt,
   gcc_jit_param *param_b =
     gcc_jit_context_new_param (ctxt, NULL, type, "b");
   gcc_jit_param *params[] = {param_a, param_b};
-  gcc_jit_type *bool_type =
-    gcc_jit_context_get_type (ctxt, GCC_JIT_TYPE_BOOL);
-  gcc_jit_function *test_fn =
-    gcc_jit_context_new_function (ctxt, NULL,
-                                 GCC_JIT_FUNCTION_EXPORTED,
-                                 bool_type,
-                                 funcname,
-                                 2, params,
-                                 0);
+
   gcc_jit_rvalue *comparison =
     gcc_jit_context_new_comparison (
       ctxt,
@@ -400,6 +392,16 @@ make_test_of_comparison (gcc_jit_context *ctxt,
       gcc_jit_param_as_rvalue (param_a),
       gcc_jit_param_as_rvalue (param_b));
 
+  gcc_jit_type *comparison_type = gcc_jit_rvalue_get_type(comparison);
+
+  gcc_jit_function *test_fn =
+    gcc_jit_context_new_function (ctxt, NULL,
+                                 GCC_JIT_FUNCTION_EXPORTED,
+                                 comparison_type,
+                                 funcname,
+                                 2, params,
+                                 0);
+
   gcc_jit_block *initial = gcc_jit_function_new_block (test_fn, "initial");
   gcc_jit_block_end_with_return (initial, NULL, comparison);
 
@@ -407,48 +409,103 @@ make_test_of_comparison (gcc_jit_context *ctxt,
     gcc_jit_rvalue_as_object (comparison));
 }
 
-static void
-make_tests_of_comparisons (gcc_jit_context *ctxt)
+static void run_test_of_comparison(gcc_jit_context *ctxt,
+                        gcc_jit_type *type,
+                        enum gcc_jit_comparison op,
+                        const char *funcname,
+                        const char *vec_funcname,
+                        const char *expected)
 {
-  gcc_jit_type *int_type =
-    gcc_jit_context_get_type (ctxt, GCC_JIT_TYPE_INT);
+  gcc_jit_type *vec_type =
+    gcc_jit_type_get_vector (type, 4);
 
   CHECK_STRING_VALUE (
     make_test_of_comparison (ctxt,
-                            int_type,
-                            GCC_JIT_COMPARISON_EQ,
-                            "test_COMPARISON_EQ_on_int"),
-    "a == b");
-  CHECK_STRING_VALUE (
-    make_test_of_comparison (ctxt,
-                            int_type,
-                            GCC_JIT_COMPARISON_NE,
-                            "test_COMPARISON_NE_on_int"),
-    "a != b");
-  CHECK_STRING_VALUE (
-    make_test_of_comparison (ctxt,
-                            int_type,
-                            GCC_JIT_COMPARISON_LT,
-                            "test_COMPARISON_LT_on_int"),
-    "a < b");
+                            type,
+                            op,
+                            funcname),
+    expected);
   CHECK_STRING_VALUE (
     make_test_of_comparison (ctxt,
-                            int_type,
-                            GCC_JIT_COMPARISON_LE,
-                            "test_COMPARISON_LE_on_int"),
-    "a <= b");
-  CHECK_STRING_VALUE (
-    make_test_of_comparison (ctxt,
-                            int_type,
-                            GCC_JIT_COMPARISON_GT,
-                            "test_COMPARISON_GT_on_int"),
-    "a > b");
-  CHECK_STRING_VALUE (
-    make_test_of_comparison (ctxt,
-                            int_type,
-                            GCC_JIT_COMPARISON_GE,
-                            "test_COMPARISON_GE_on_int"),
-    "a >= b");
+                            vec_type,
+                            op,
+                            vec_funcname),
+    expected);
+}
+
+static void
+make_tests_of_comparisons (gcc_jit_context *ctxt)
+{
+  gcc_jit_type *int_type =
+    gcc_jit_context_get_type (ctxt, GCC_JIT_TYPE_INT);
+  gcc_jit_type *float_type =
+    gcc_jit_context_get_type (ctxt, GCC_JIT_TYPE_FLOAT);
+
+  run_test_of_comparison(
+       ctxt,
+       int_type,
+       GCC_JIT_COMPARISON_EQ,
+       "test_COMPARISON_EQ_on_int",
+       "test_COMPARISON_EQ_on_vec_int",
+       "a == b");
+  run_test_of_comparison(
+       ctxt,
+       int_type,
+       GCC_JIT_COMPARISON_NE,
+       "test_COMPARISON_NE_on_int",
+       "test_COMPARISON_NE_on_vec_int",
+       "a != b");
+  run_test_of_comparison(
+       ctxt,
+       int_type,
+       GCC_JIT_COMPARISON_LT,
+       "test_COMPARISON_LT_on_int",
+       "test_COMPARISON_LT_on_vec_int",
+       "a < b");
+  run_test_of_comparison(
+       ctxt,
+       int_type,
+       GCC_JIT_COMPARISON_LE,
+       "test_COMPARISON_LE_on_int",
+       "test_COMPARISON_LE_on_vec_int",
+       "a <= b");
+  run_test_of_comparison(
+       ctxt,
+       int_type,
+       GCC_JIT_COMPARISON_GT,
+       "test_COMPARISON_GT_on_int",
+       "test_COMPARISON_GT_on_vec_int",
+       "a > b");
+  run_test_of_comparison(
+       ctxt,
+       int_type,
+       GCC_JIT_COMPARISON_GE,
+       "test_COMPARISON_GE_on_int",
+       "test_COMPARISON_GE_on_vec_int",
+       "a >= b");
+
+  // Float tests
+  run_test_of_comparison(
+       ctxt,
+       float_type,
+       GCC_JIT_COMPARISON_NE,
+       "test_COMPARISON_NE_on_float",
+       "test_COMPARISON_NE_on_vec_float",
+       "a != b");
+  run_test_of_comparison(
+       ctxt,
+       float_type,
+       GCC_JIT_COMPARISON_LT,
+       "test_COMPARISON_LT_on_float",
+       "test_COMPARISON_LT_on_vec_float",
+       "a < b");
+  run_test_of_comparison(
+       ctxt,
+       float_type,
+       GCC_JIT_COMPARISON_GT,
+       "test_COMPARISON_GT_on_float",
+       "test_COMPARISON_GT_on_vec_float",
+       "a > b");
 }
 
 static void
@@ -502,6 +559,93 @@ verify_comparisons (gcc_jit_result *result)
   CHECK_VALUE (test_COMPARISON_GE_on_int (0, 0), 1);
   CHECK_VALUE (test_COMPARISON_GE_on_int (1, 2), 0);
   CHECK_VALUE (test_COMPARISON_GE_on_int (2, 1), 1);
+
+  typedef int __vector __attribute__ ((__vector_size__ (sizeof(int) * 2)));
+  typedef __vector (*test_vec_fn) (__vector, __vector);
+
+  __vector zero_zero = {0, 0};
+  __vector zero_one = {0, 1};
+  __vector one_zero = {1, 0};
+
+  __vector true_true = {-1, -1};
+  __vector false_true = {0, -1};
+  __vector true_false = {-1, 0};
+  __vector false_false = {0, 0};
+
+  test_vec_fn test_COMPARISON_EQ_on_vec_int =
+    (test_vec_fn)gcc_jit_result_get_code (result,
+                                     "test_COMPARISON_EQ_on_vec_int");
+  CHECK_NON_NULL (test_COMPARISON_EQ_on_vec_int);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_EQ_on_vec_int (zero_zero, zero_zero), true_true);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_EQ_on_vec_int (zero_one, one_zero), false_false);
+
+  test_vec_fn test_COMPARISON_NE_on_vec_int =
+    (test_vec_fn)gcc_jit_result_get_code (result,
+                                     "test_COMPARISON_NE_on_vec_int");
+  CHECK_NON_NULL (test_COMPARISON_NE_on_vec_int);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_NE_on_vec_int (zero_zero, zero_zero), false_false);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_NE_on_vec_int (zero_one, one_zero), true_true);
+
+  test_vec_fn test_COMPARISON_LT_on_vec_int =
+    (test_vec_fn)gcc_jit_result_get_code (result,
+                                     "test_COMPARISON_LT_on_vec_int");
+  CHECK_NON_NULL (test_COMPARISON_LT_on_vec_int);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_LT_on_vec_int (zero_zero, zero_zero), false_false);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_LT_on_vec_int (zero_one, one_zero), true_false);
+
+  test_vec_fn test_COMPARISON_LE_on_vec_int =
+    (test_vec_fn)gcc_jit_result_get_code (result,
+                                     "test_COMPARISON_LE_on_vec_int");
+  CHECK_NON_NULL (test_COMPARISON_LE_on_vec_int);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_LE_on_vec_int (zero_zero, zero_zero), true_true);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_LE_on_vec_int (zero_one, one_zero), true_false);
+
+  test_vec_fn test_COMPARISON_GT_on_vec_int =
+    (test_vec_fn)gcc_jit_result_get_code (result,
+                                     "test_COMPARISON_GT_on_vec_int");
+  CHECK_NON_NULL (test_COMPARISON_GT_on_vec_int);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_GT_on_vec_int (zero_zero, zero_zero), false_false);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_GT_on_vec_int (zero_one, one_zero), false_true);
+
+  test_vec_fn test_COMPARISON_GE_on_vec_int =
+    (test_vec_fn)gcc_jit_result_get_code (result,
+                                     "test_COMPARISON_GE_on_vec_int");
+  CHECK_NON_NULL (test_COMPARISON_GE_on_vec_int);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_GE_on_vec_int (zero_zero, zero_zero), true_true);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_GE_on_vec_int (zero_one, one_zero), false_true);
+
+  typedef float __vector_f __attribute__ ((__vector_size__ (sizeof(float) * 2)));
+  typedef __vector (*test_vec_f_fn) (__vector_f, __vector_f);
+
+  __vector_f zero_zero_f = {0, 0};
+  __vector_f zero_one_f = {0, 1};
+  __vector_f one_zero_f = {1, 0};
+
+  __vector_f true_true_f = {-1, -1};
+  __vector_f false_true_f = {0, -1};
+  __vector_f true_false_f = {-1, 0};
+  __vector_f false_false_f = {0, 0};
+
+  test_vec_f_fn test_COMPARISON_NE_on_vec_float =
+    (test_vec_f_fn)gcc_jit_result_get_code (result,
+                                     "test_COMPARISON_NE_on_vec_float");
+  CHECK_NON_NULL (test_COMPARISON_NE_on_vec_float);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_NE_on_vec_float (zero_zero_f, zero_zero_f), false_false_f);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_NE_on_vec_float (zero_one_f, one_zero_f), true_true_f);
+
+  test_vec_f_fn test_COMPARISON_LT_on_vec_float =
+    (test_vec_f_fn)gcc_jit_result_get_code (result,
+                                     "test_COMPARISON_LT_on_vec_float");
+  CHECK_NON_NULL (test_COMPARISON_LT_on_vec_float);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_LT_on_vec_float (zero_zero_f, zero_zero_f), false_false_f);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_LT_on_vec_float (zero_one_f, one_zero_f), true_false_f);
+
+  test_vec_f_fn test_COMPARISON_GT_on_vec_float =
+    (test_vec_f_fn)gcc_jit_result_get_code (result,
+                                     "test_COMPARISON_GT_on_vec_float");
+  CHECK_NON_NULL (test_COMPARISON_GT_on_vec_float);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_GT_on_vec_float (zero_zero_f, zero_zero_f), false_false_f);
+  CHECK_VECTOR_VALUE (2, test_COMPARISON_GT_on_vec_float (zero_one_f, one_zero_f), false_true_f);
 }
 
 /**********************************************************************