updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / bfloat16_utils.hpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef BFLOAT16_UTILS_HPP
18 #define BFLOAT16_UTILS_HPP
19
20 #include "nstl.hpp"
21 #include "jit_avx512_core_bf16cvt.hpp"
22
23 namespace mkldnn {
24 namespace impl {
25 namespace cpu {
26 namespace bf16_cvt_utils {
27
28 union f32_bf16_t {
29     float vfloat;
30     mkldnn_bfloat16_t vbfloat[2];
31 };
32
33 extern jit_avx512_core_cvt_ps_to_bf16_t cvt_one_ps_to_bf16;
34 extern jit_avx512_core_cvt_ps_to_bf16_t cvt_ps_to_bf16_;
35 extern jit_avx512_core_cvt_bf16_to_ps_t cvt_bf16_to_ps_;
36 extern jit_avx512_core_add_cvt_ps_to_bf16_t add_cvt_ps_to_bf16_;
37
38 inline mkldnn_bfloat16_t cvt_float_to_bfloat16(float inp) {
39     assert(mayiuse(avx512_core));
40     mkldnn_bfloat16_t out;
41     jit_call_t p;
42     p.inp = (void *)&inp;
43     p.out = (void *)&out;
44     cvt_one_ps_to_bf16.jit_ker(&p);
45     return out;
46 }
47
48 inline void cvt_float_to_bfloat16(mkldnn_bfloat16_t *out, const float *inp) {
49     assert(mayiuse(avx512_core));
50     jit_call_t p;
51     p.inp = (void *)inp;
52     p.out = (void *)out;
53     cvt_one_ps_to_bf16.jit_ker(&p);
54 }
55
56 inline void cvt_float_to_bfloat16(mkldnn_bfloat16_t *out, const float *inp,
57         size_t size) {
58     assert(mayiuse(avx512_core));
59     jit_call_t p_;
60     p_.inp = (void *)inp;
61     p_.out = (void *)out;
62     p_.size = size;
63     cvt_ps_to_bf16_.jit_ker(&p_);
64  }
65
66 inline float cvt_bfloat16_to_float(mkldnn_bfloat16_t inp) {
67     assert(mayiuse(avx512_core));
68     f32_bf16_t cvt = {0};
69     cvt.vbfloat[1] = inp;
70     return cvt.vfloat;
71 }
72
73 inline void cvt_bfloat16_to_float(float *out, const mkldnn_bfloat16_t *inp) {
74     assert(mayiuse(avx512_core));
75     f32_bf16_t cvt = {0};
76     cvt.vbfloat[1] = *inp;
77     *out = cvt.vfloat;
78 }
79
80 inline void cvt_bfloat16_to_float(float *out, const mkldnn_bfloat16_t *inp,
81         size_t size) {
82     assert(mayiuse(avx512_core));
83     jit_call_t p_;
84     p_.inp = (void *)inp;
85     p_.out = (void *)out;
86     p_.size = size;
87     cvt_bf16_to_ps_.jit_ker(&p_);
88 }
89
90 // performs element-by-element sum of inp and add float arrays and stores
91 // result to bfloat16 out array with downconversion
92 inline void add_floats_and_cvt_to_bfloat16(mkldnn_bfloat16_t *out,
93         const float *inp0,
94         const float *inp1,
95         size_t size) {
96     assert(mayiuse(avx512_core));
97     jit_call_t p_;
98     p_.inp = (void *)inp0;
99     p_.add = (void *)inp1;
100     p_.out = (void *)out;
101     p_.size = size;
102     add_cvt_ps_to_bf16_.jit_ker(&p_);
103 }
104
105 inline mkldnn_bfloat16_t approx_bfloat16_lowest() {
106     /* jit fails to convert FLT_MIN to bfloat16.
107      * It converst FLT_MIN to -INF. Truncate FLT_MIN
108      * to bfloat16 to get a value close to minimum bfloat16*/
109     f32_bf16_t f_raw = {0};
110     f_raw.vfloat = nstl::numeric_limits<float>::lowest();
111     f_raw.vbfloat[0] = 0;
112     return f_raw.vbfloat[1];
113 }
114
115 inline bool is_float_representable_in_bfloat16(float x) {
116     f32_bf16_t cvt = {0};
117     cvt.vfloat = x;
118     return cvt.vbfloat[0] == 0;
119 }
120
121 }
122 }
123 }
124 }
125
126 #endif