Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / cl_kernels / include / mmad.cl
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 void FUNC(intel_sub_group_block_write_4)( __local uint* p, uint4 data )
18 {
19     p[ get_sub_group_local_id() ] = data.s0;
20     p += 8;
21     p[ get_sub_group_local_id() ] = data.s1;
22     p += 8;
23     p[ get_sub_group_local_id() ] = data.s2;
24     p += 8;
25     p[ get_sub_group_local_id() ] = data.s3;
26 }
27
28 uint4 FUNC(intel_sub_group_block_read_uint4)(const __local uint* p)
29 {
30     uint4 ret;
31     uint idx = get_sub_group_local_id();
32
33     ret.s0 = p[idx]; idx += get_max_sub_group_size();
34     ret.s1 = p[idx]; idx += get_max_sub_group_size();
35     ret.s2 = p[idx]; idx += get_max_sub_group_size();
36     ret.s3 = p[idx]; idx += get_max_sub_group_size();
37
38     return ret;
39 }
40
41 uint8 FUNC(intel_sub_group_block_read_uint8)(const __local uint* p)
42 {
43     uint8 ret;
44     uint idx = get_sub_group_local_id();
45
46     ret.s0 = p[idx]; idx += get_max_sub_group_size();
47     ret.s1 = p[idx]; idx += get_max_sub_group_size();
48     ret.s2 = p[idx]; idx += get_max_sub_group_size();
49     ret.s3 = p[idx]; idx += get_max_sub_group_size();
50     ret.s4 = p[idx]; idx += get_max_sub_group_size();
51     ret.s5 = p[idx]; idx += get_max_sub_group_size();
52     ret.s6 = p[idx]; idx += get_max_sub_group_size();
53     ret.s7 = p[idx]; idx += get_max_sub_group_size();
54
55     return ret;
56 }
57
58 inline int FUNC(mmad_4)(char4 input, char4 weight, int acc)
59 {
60         acc += (input[0] * weight[0]);
61         acc += (input[1] * weight[1]);
62         acc += (input[2] * weight[2]);
63         acc += (input[3] * weight[3]);
64         return acc;
65 }
66
67 inline int FUNC(mmad8)(int8 A_scalars, int8 B_vectors, int acc)
68 {
69         acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_char4(B_vectors[0]), acc);
70         acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_char4(B_vectors[1]), acc);
71         acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[2]), as_char4(B_vectors[2]), acc);
72         acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[3]), as_char4(B_vectors[3]), acc);
73         acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[4]), as_char4(B_vectors[4]), acc);
74         acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_char4(B_vectors[5]), acc);
75         acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_char4(B_vectors[6]), acc);
76         acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_char4(B_vectors[7]), acc);
77
78         return acc;
79 }
80
81 inline int4 FUNC(mmad4x8)(int4 A_vectors, int8 B_vectors, int4 acc)
82 {
83     int4 ret;
84     for(uint i = 0; i < 4; i++)
85     {
86         int8 A_scalars;
87         A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
88         A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
89         A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
90         A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
91         A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
92         A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
93         A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
94         A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
95         ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]);    
96     }
97     return ret;
98 }
99
100 inline int8 FUNC(mmad8x8)(int8 A_vectors, int8 B_vectors, int8 acc)
101 {
102     int8 ret;
103     for(uint i = 0; i < 8; i++)
104     {
105         int8 A_scalars;
106         A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
107         A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
108         A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
109         A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
110         A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
111         A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
112         A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
113         A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
114         ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]);    
115     }
116     return ret;
117 }
118
119 // TODO: remove it when cl_intel_subgroups_char extension will work
120 inline void FUNC(sub_group_block_write_uchar8)(__global uchar* outPtr, uchar8 v)
121 {
122 #ifdef cl_intel_subgroups_char
123     intel_sub_group_block_write_uc8(outPtr, v);
124 #else
125     uint idx = get_sub_group_local_id();
126
127         outPtr[idx] = v.s0; idx += get_max_sub_group_size();
128     outPtr[idx] = v.s1; idx += get_max_sub_group_size();
129     outPtr[idx] = v.s2; idx += get_max_sub_group_size();
130     outPtr[idx] = v.s3; idx += get_max_sub_group_size();
131     outPtr[idx] = v.s4; idx += get_max_sub_group_size();
132     outPtr[idx] = v.s5; idx += get_max_sub_group_size();
133     outPtr[idx] = v.s6; idx += get_max_sub_group_size();
134     outPtr[idx] = v.s7; idx += get_max_sub_group_size();
135 #endif
136 }
137
138 inline uchar8 FUNC(sub_group_block_read_uchar8)(const __global uchar* ptr)
139 {
140 #ifdef cl_intel_subgroups_char
141     return intel_sub_group_block_read_uc8(ptr);
142 #else
143     uint idx = get_sub_group_local_id();
144
145     uchar8 ret;
146
147     ret.s0 = ptr[idx]; idx += get_max_sub_group_size();
148     ret.s1 = ptr[idx]; idx += get_max_sub_group_size();
149     ret.s2 = ptr[idx]; idx += get_max_sub_group_size();
150     ret.s3 = ptr[idx]; idx += get_max_sub_group_size();
151     ret.s4 = ptr[idx]; idx += get_max_sub_group_size();
152     ret.s5 = ptr[idx]; idx += get_max_sub_group_size();
153     ret.s6 = ptr[idx]; idx += get_max_sub_group_size();
154     ret.s7 = ptr[idx]; idx += get_max_sub_group_size();
155
156     return ret;
157
158 #endif
159 }
160
161 //
162
163
164 #define MMAD_8(A, B, C) FUNC_CALL(mmad8)(A, B, C)
165 #define MMAD_4x8(A, B, C) FUNC_CALL(mmad4x8)(A, B, C)
166 #define MMAD_8x8(A, B, C) FUNC_CALL(mmad8x8)(A, B, C)
167 #define SLM_BLOCK_WRITE_4(A, B) (FUNC_CALL(intel_sub_group_block_write_4)(A, B))
168 #define SLM_BLOCK_READ_4(A) (FUNC_CALL(intel_sub_group_block_read_uint4)(A))
169 #define SLM_BLOCK_READ_8(A) (FUNC_CALL(intel_sub_group_block_read_uint8)(A))