Fix tdpbf16ps testcase
authorHaochen Jiang <haochen.jiang@intel.com>
Fri, 24 Dec 2021 05:55:06 +0000 (13:55 +0800)
committerliuhongt <hongtao.liu@intel.com>
Tue, 28 Dec 2021 08:58:27 +0000 (16:58 +0800)
gcc/testsuite/ChangeLog:

* gcc.target/i386/amx-check.h (check_float_tile_register):
New check function for float to prevent precision loss.
* gcc.target/i386/amxbf16-dpbf16ps-2.c: Correct the type convert
and byte offset. Use the new check function.

gcc/testsuite/gcc.target/i386/amx-check.h
gcc/testsuite/gcc.target/i386/amxbf16-dpbf16ps-2.c

index 03616ff..434b0e5 100644 (file)
@@ -139,8 +139,27 @@ int check_tile_register (__tile* ref, __tile* target)
 
   for (i = 0; i < rows; i++)
     for (j = 0; j < colsb; j++)
-       if (ref->buf[i * colsb + j] != target->buf[i * colsb + j])
-           return 0;
+      if (ref->buf[i * colsb + j] != target->buf[i * colsb + j])
+       return 0;
+
+  return 1;
+}
+
+/* Compare float tile register value with __tile variable */
+int check_float_tile_register (__tile* ref, __tile* target)
+{
+  /* Tile register should be stored from tmm to
+     memory and compare with emulation results. */
+  int rows = target->rows;
+  int colsb = target->colsb / 4;
+  int i, j;
+  uint32_t *ref_buf = (uint32_t *) ref->buf;
+  uint32_t *target_buf = (uint32_t *) target->buf;
+
+  for (i = 0; i < rows; i++)
+    for (j = 0; j < colsb; j++)
+      if (abs(ref_buf[i * colsb + j] - target_buf[i * colsb + j]) > 1)
+       return 0;
 
   return 1;
 }
index f7002ca..b00bc13 100644 (file)
@@ -12,15 +12,25 @@ void test_amx_bf16_dpbf16ps ();
 /* Transformation functions between bf16/float */
 static uint16_t make_bf16 (float f)
 {
-  uint32_t u = (uint32_t)f;
-  u = (u >> 16) & 0xffff;
-  return (uint16_t)u;
+  union
+  {
+    float f;
+    uint32_t u;
+  } fu;
+  fu.f = f;
+  fu.u = (fu.u >> 16) & 0xffff;
+  return (uint16_t) fu.u;
 }
 
 static float make_f32 (uint16_t bf)
 {
-  uint32_t u = (uint32_t)(bf << 16);
-  return (float)u;
+  union
+  {
+    float f;
+    uint32_t u;
+  } fu;
+  fu.u = (uint32_t) bf << 16;
+  return fu.f;
 }
 
 /* Init tile buffer with bf16 pairs */
@@ -54,10 +64,10 @@ void calc_matrix_dpbf16ps (__tile *dst, __tile *src1, __tile *src2)
        for (t = 0; t < 2; t+=2)
          {    
            dst_buf[i * N + k] += 
-             (make_f32(src1_buf[i * 4 * N + 4 * j + t]) *
-             make_f32(src2_buf[j * 4 * K + 4 * k + t])) +
-             (make_f32(src1_buf[i * 4 * N + 4 * j + t + 1]) *
-             make_f32(src2_buf[j * 4 * K + 4 * k + t + 1]));
+             (make_f32(src1_buf[i * 2 * N + 2 * j + t]) *
+             make_f32(src2_buf[j * 2 * K + 2 * k + t])) +
+             (make_f32(src1_buf[i * 2 * N + 2 * j + t + 1]) *
+             make_f32(src2_buf[j * 2 * K + 2 * k + t + 1]));
          }
 
 }
@@ -80,6 +90,6 @@ void test_amx_bf16_dpbf16ps ()
   _tile_dpbf16ps (1, 2, 3);
   _tile_stored (1, dst_ref.buf, _STRIDE);
 
-  if (!check_tile_register (&dst_ref, &dst))
+  if (!check_float_tile_register (&dst_ref, &dst))
         abort();
 }