// Copyright (c) 2016 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Stores a uint4 per work-item into local memory in sub-group block layout.
//
// Component k of `data` is written to p[k * get_max_sub_group_size() + lane],
// where lane is get_sub_group_local_id(). This is the same indexing that
// FUNC(intel_sub_group_block_read_uint4) uses when reading, so a value written
// here is read back unchanged by that helper.
//
// p    - base of the destination region in local memory; must have room for
//        4 * get_max_sub_group_size() uints.
// data - the four values this work-item contributes.
void FUNC(intel_sub_group_block_write_4)( __local uint* p, uint4 data )
{
    const uint stride = get_max_sub_group_size();
    uint idx = get_sub_group_local_id();

    // Each component goes to its own row; without advancing idx every store
    // would land on the same slot and only data.s3 would survive.
    p[idx] = data.s0; idx += stride;
    p[idx] = data.s1; idx += stride;
    p[idx] = data.s2; idx += stride;
    p[idx] = data.s3;
}
// Loads four uints for the calling work-item from local memory.
//
// Component k of the result is p[lane + k * get_max_sub_group_size()], where
// lane is get_sub_group_local_id().
uint4 FUNC(intel_sub_group_block_read_uint4)(const __local uint* p)
{
    const uint lane   = get_sub_group_local_id();
    const uint stride = get_max_sub_group_size();

    return (uint4)(p[lane],
                   p[lane + stride],
                   p[lane + 2 * stride],
                   p[lane + 3 * stride]);
}
// Loads eight uints for the calling work-item from local memory.
//
// Component k of the result is p[lane + k * get_max_sub_group_size()], where
// lane is get_sub_group_local_id().
uint8 FUNC(intel_sub_group_block_read_uint8)(const __local uint* p)
{
    const uint lane   = get_sub_group_local_id();
    const uint stride = get_max_sub_group_size();

    return (uint8)(p[lane],
                   p[lane + stride],
                   p[lane + 2 * stride],
                   p[lane + 3 * stride],
                   p[lane + 4 * stride],
                   p[lane + 5 * stride],
                   p[lane + 6 * stride],
                   p[lane + 7 * stride]);
}
// Four-element multiply-accumulate on signed 8-bit values.
//
// Returns acc plus the sum of input.sK * weight.sK for K = 0..3. Products are
// added to the accumulator one at a time, in component order.
inline int FUNC(mmad_4)(char4 input, char4 weight, int acc)
{
    // Left-to-right association keeps the accumulation order acc, s0, s1, s2, s3.
    return acc
         + (input.s0 * weight.s0)
         + (input.s1 * weight.s1)
         + (input.s2 * weight.s2)
         + (input.s3 * weight.s3);
}
// Multiply-accumulate over eight packed int lanes.
//
// Each int component of A_scalars and B_vectors is reinterpreted with
// as_char4 as four signed bytes; matching components are fed through
// FUNC(mmad_4) in order s0..s7, threading the accumulator through every call.
// Returns the final accumulator.
inline int FUNC(mmad8)(int8 A_scalars, int8 B_vectors, int acc)
{
    int sum = acc;

    sum = FUNC_CALL(mmad_4)(as_char4(A_scalars.s0), as_char4(B_vectors.s0), sum);
    sum = FUNC_CALL(mmad_4)(as_char4(A_scalars.s1), as_char4(B_vectors.s1), sum);
    sum = FUNC_CALL(mmad_4)(as_char4(A_scalars.s2), as_char4(B_vectors.s2), sum);
    sum = FUNC_CALL(mmad_4)(as_char4(A_scalars.s3), as_char4(B_vectors.s3), sum);
    sum = FUNC_CALL(mmad_4)(as_char4(A_scalars.s4), as_char4(B_vectors.s4), sum);
    sum = FUNC_CALL(mmad_4)(as_char4(A_scalars.s5), as_char4(B_vectors.s5), sum);
    sum = FUNC_CALL(mmad_4)(as_char4(A_scalars.s6), as_char4(B_vectors.s6), sum);
    sum = FUNC_CALL(mmad_4)(as_char4(A_scalars.s7), as_char4(B_vectors.s7), sum);

    return sum;
}
// Four-row variant of the sub-group multiply-accumulate.
//
// For each component i in 0..3 of A_vectors, the value held by sub-group
// lanes 0..7 is collected with sub_group_broadcast into an int8, which is
// then combined with this work-item's B_vectors by FUNC(mmad8), starting from
// acc[i]. The four results are returned as an int4.
inline int4 FUNC(mmad4x8)(int4 A_vectors, int8 B_vectors, int4 acc)
{
    int4 result;

    for (uint row = 0; row < 4; ++row)
    {
        const int a = A_vectors[row];

        // One value of A from each of the first eight lanes of the sub-group.
        const int8 gathered = (int8)(sub_group_broadcast(a, 0),
                                     sub_group_broadcast(a, 1),
                                     sub_group_broadcast(a, 2),
                                     sub_group_broadcast(a, 3),
                                     sub_group_broadcast(a, 4),
                                     sub_group_broadcast(a, 5),
                                     sub_group_broadcast(a, 6),
                                     sub_group_broadcast(a, 7));

        result[row] = FUNC_CALL(mmad8)(gathered, B_vectors, acc[row]);
    }

    return result;
}
// Eight-row variant of the sub-group multiply-accumulate.
//
// For each component i in 0..7 of A_vectors, the value held by sub-group
// lanes 0..7 is collected with sub_group_broadcast into an int8, which is
// then combined with this work-item's B_vectors by FUNC(mmad8), starting from
// acc[i]. The eight results are returned as an int8.
inline int8 FUNC(mmad8x8)(int8 A_vectors, int8 B_vectors, int8 acc)
{
    int8 result;

    for (uint row = 0; row < 8; ++row)
    {
        const int a = A_vectors[row];

        // One value of A from each of the first eight lanes of the sub-group.
        const int8 gathered = (int8)(sub_group_broadcast(a, 0),
                                     sub_group_broadcast(a, 1),
                                     sub_group_broadcast(a, 2),
                                     sub_group_broadcast(a, 3),
                                     sub_group_broadcast(a, 4),
                                     sub_group_broadcast(a, 5),
                                     sub_group_broadcast(a, 6),
                                     sub_group_broadcast(a, 7));

        result[row] = FUNC_CALL(mmad8)(gathered, B_vectors, acc[row]);
    }

    return result;
}
// TODO: remove this wrapper once the cl_intel_subgroups_char extension works.
//
// Writes eight uchars per work-item to global memory. When
// cl_intel_subgroups_char is defined the call is forwarded to
// intel_sub_group_block_write_uc8; otherwise component k of `v` is stored to
// outPtr[lane + k * get_max_sub_group_size()], with lane being
// get_sub_group_local_id().
inline void FUNC(sub_group_block_write_uchar8)(__global uchar* outPtr, uchar8 v)
{
#ifdef cl_intel_subgroups_char
    intel_sub_group_block_write_uc8(outPtr, v);
#else
    const uint lane   = get_sub_group_local_id();
    const uint stride = get_max_sub_group_size();

    outPtr[lane]              = v.s0;
    outPtr[lane + stride]     = v.s1;
    outPtr[lane + 2 * stride] = v.s2;
    outPtr[lane + 3 * stride] = v.s3;
    outPtr[lane + 4 * stride] = v.s4;
    outPtr[lane + 5 * stride] = v.s5;
    outPtr[lane + 6 * stride] = v.s6;
    outPtr[lane + 7 * stride] = v.s7;
#endif
}
// Reads eight uchars per work-item from global memory. When
// cl_intel_subgroups_char is defined the call is forwarded to
// intel_sub_group_block_read_uc8; otherwise component k of the result is
// ptr[lane + k * get_max_sub_group_size()], with lane being
// get_sub_group_local_id().
inline uchar8 FUNC(sub_group_block_read_uchar8)(const __global uchar* ptr)
{
#ifdef cl_intel_subgroups_char
    return intel_sub_group_block_read_uc8(ptr);
#else
    const uint lane   = get_sub_group_local_id();
    const uint stride = get_max_sub_group_size();

    return (uchar8)(ptr[lane],
                    ptr[lane + stride],
                    ptr[lane + 2 * stride],
                    ptr[lane + 3 * stride],
                    ptr[lane + 4 * stride],
                    ptr[lane + 5 * stride],
                    ptr[lane + 6 * stride],
                    ptr[lane + 7 * stride]);
#endif
}
// Short aliases for the helpers defined in this file.
//
// Multiply-accumulate helpers: A and B are the operand vectors, C is the
// accumulator passed through to the underlying function.
#define MMAD_8(A, B, C)         (FUNC_CALL(mmad8)(A, B, C))
#define MMAD_4x8(A, B, C)       (FUNC_CALL(mmad4x8)(A, B, C))
#define MMAD_8x8(A, B, C)       (FUNC_CALL(mmad8x8)(A, B, C))

// Local-memory block access helpers: A is the __local uint pointer, B the
// uint4 value to store.
#define SLM_BLOCK_WRITE_4(A, B) (FUNC_CALL(intel_sub_group_block_write_4)(A, B))
#define SLM_BLOCK_READ_4(A)     (FUNC_CALL(intel_sub_group_block_read_uint4)(A))
#define SLM_BLOCK_READ_8(A)     (FUNC_CALL(intel_sub_group_block_read_uint8)(A))