Merge pull request #3709 from nursik/develop
[platform/upstream/openblas.git] / test / compare_sgemm_sbgemm.c
1 /***************************************************************************
2 Copyright (c) 2020, The OpenBLAS Project
3 All rights reserved.
4 Redistribution and use in source and binary forms, with or without
5 modification, are permitted provided that the following conditions are
6 met:
7 1. Redistributions of source code must retain the above copyright
8 notice, this list of conditions and the following disclaimer.
9 2. Redistributions in binary form must reproduce the above copyright
10 notice, this list of conditions and the following disclaimer in
11 the documentation and/or other materials provided with the
12 distribution.
13 3. Neither the name of the OpenBLAS project nor the names of
14 its contributors may be used to endorse or promote products
15 derived from this software without specific prior written permission.
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 *****************************************************************************/
27 #include <stdio.h>
28 #include <stdint.h>
29 #include "../common.h"
30 #define SGEMM   BLASFUNC(sgemm)
31 #define SBGEMM   BLASFUNC(sbgemm)
32 typedef union
33 {
34   unsigned short v;
35   struct
36   {
37 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
38     unsigned short s:1;
39     unsigned short e:8;
40     unsigned short m:7;
41 #else
42     unsigned short m:7;
43     unsigned short e:8;
44     unsigned short s:1;
45 #endif
46   } bits;
47 } bfloat16_bits;
48
49 typedef union
50 {
51   float v;
52   struct
53   {
54 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
55     uint32_t s:1;
56     uint32_t e:8;
57     uint32_t m:23;
58 #else
59     uint32_t m:23;
60     uint32_t e:8;
61     uint32_t s:1;
62 #endif
63   } bits;
64 } float32_bits;
65
66 float
67 float16to32 (bfloat16_bits f16)
68 {
69   float32_bits f32;
70   f32.bits.s = f16.bits.s;
71   f32.bits.e = f16.bits.e;
72   f32.bits.m = (uint32_t) f16.bits.m << 16;
73   return f32.v;
74 }
75
76 int
77 main (int argc, char *argv[])
78 {
79   int m, n, k;
80   int i, j, l;
81   int x;
82   int ret = 0;
83   int loop = 100;
84   char transA = 'N', transB = 'N';
85   float alpha = 1.0, beta = 0.0;
86
87   for (x = 0; x <= loop; x++)
88     {
89       m = k = n = x;
90       float A[m * k];
91       float B[k * n];
92       float C[m * n];
93       bfloat16_bits AA[m * k], BB[k * n];
94       float DD[m * n], CC[m * n];
95
96       for (j = 0; j < m; j++)
97         {
98           for (i = 0; i < m; i++)
99             {
100               A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
101               B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
102               C[j * k + i] = 0;
103               AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16;
104               BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16;
105               CC[j * k + i] = 0;
106               DD[j * k + i] = 0;
107             }
108         }
109       SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
110              &m, B, &k, &beta, C, &m);
111       SBGEMM (&transA, &transB, &m, &n, &k, &alpha, AA,
112               &m, BB, &k, &beta, CC, &m);
113       for (i = 0; i < n; i++)
114         for (j = 0; j < m; j++)
115           for (l = 0; l < k; l++)
116             if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
117               ret++;
118       if (transA == 'N' && transB == 'N')
119         {
120           for (i = 0; i < n; i++)
121             for (j = 0; j < m; j++)
122               for (l = 0; l < k; l++)
123                 {
124                   DD[i * m + j] +=
125                     float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]);
126                 }
127           for (i = 0; i < n; i++)
128             for (j = 0; j < m; j++)
129               for (l = 0; l < k; l++)
130                 if (CC[i * m + j] != DD[i * m + j])
131                   ret++;
132         }
133     }
134   if (ret != 0)
135     fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret);
136   return ret;
137 }