// Copyright (c) 2016-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fully_connected_kernel_mmad.h"
#include "kernel_selector_utils.h"

namespace kernel_selector {

namespace {
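    // Sub-group width used by the MMAD kernel; the feature dimension of the
    // global work size is aligned to this value in SetDefault().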
    static const size_t sub_group_size = 8;
}  // namespace

ParamsKey FullyConnectedKernelMMAD::GetSupportedKey() const {
    ParamsKey k;
    k.EnableInputDataType(Datatype::INT8);
    k.EnableInputDataType(Datatype::UINT8);

    k.EnableOutputDataType(Datatype::INT8);
    k.EnableOutputDataType(Datatype::UINT8);
    k.EnableOutputDataType(Datatype::F32);
    k.EnableOutputDataType(Datatype::F16);

    k.EnableInputWeightsType(WeightsType::INT8);

    k.EnableDifferentInputWeightsTypes();
    k.EnableDifferentTypes();

    k.EnableInputLayout(DataLayout::bfyx);
    k.EnableInputLayout(DataLayout::b_fs_yx_fsv32);
    k.EnableInputLayout(DataLayout::b_fs_zyx_fsv32);
    k.EnableOutputLayout(DataLayout::bf);

    k.EnableBiasPerOutput();
    k.EnableBiasPerFeature();
    k.EnableNonBiasTerm();
    k.EnableTensorOffset();
    k.EnableTensorPitches();
    k.EnableBatching();
    k.EnableQuantization(QuantizationType::SYMMETRIC);
    return k;
}

bool FullyConnectedKernelMMAD::Validate(const Params& params, const optional_params& options) const {
    if (!Parent::Validate(params, options))
        return false;

    auto fc_params = static_cast<const fully_connected_params&>(params);
    auto input = fc_params.inputs[0];
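    // Plain bfyx input is supported only when every spatial dimension,
    // including padding, collapses to 1, i.e. the input is already a flat
    // per-batch feature vector.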
    if (input.GetLayout() == DataLayout::bfyx &&
        (input.X().LogicalDimPadded() != 1 || input.Y().LogicalDimPadded() != 1 || input.Z().LogicalDimPadded() != 1)) {
        return false;
    }

    return true;
}

FullyConnectedKernelMMAD::DispatchData FullyConnectedKernelMMAD::SetDefault(const fully_connected_params& params,
                                                                            int) const {
    auto runInfo = Parent::SetDefault(params);

    const auto& out = params.output;

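    // Launch one work-item per output feature (aligned up to the sub-group
    // size so no partial sub-groups are created) and one row of work-items
    // per batch element.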
    std::vector<size_t> global = { Align(out.Feature().v, sub_group_size), out.Batch().v, 1 };
    auto local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);

    runInfo.gws0 = global[0];
    runInfo.gws1 = global[1];
    runInfo.gws2 = global[2];

    runInfo.lws0 = local[0];
    runInfo.lws1 = local[1];
    runInfo.lws2 = local[2];

    return runInfo;
}

JitConstants FullyConnectedKernelMMAD::GetJitConstants(const fully_connected_params& params,
                                                       const DispatchData& runInfo) const {
    auto jit = Parent::GetJitConstants(params, runInfo);

    auto& input = params.inputs[0];
    auto& weights = params.weights;

    jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", sub_group_size));
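    // Weights are reordered to the os_is_(z)yx_isa8_osv8_isv4 layout (see
    // GetKernelsData below), so the per-feature filter offset macro must use
    // the matching 2D or 3D variant of that layout.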
    if (input.GetDims().size() == 5) {
        jit.AddConstant(MakeJitConstant("FILTER_GET_OFFSET(f)", "GET_FILTER_OS_IS_ZYX_ISA8_OSV8_ISV4_INDEX(FILTER, f, 0, 0, 0, 0)"));
    } else {
        jit.AddConstant(MakeJitConstant("FILTER_GET_OFFSET(f)", "GET_FILTER_OS_IS_YX_ISA8_OSV8_ISV4_INDEX(FILTER, f, 0, 0, 0)"));
    }

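    // The MMAD dot-product instructions consume activations and weights as
    // 32-bit scalars, each packing four 8-bit values; the packed type only
    // carries the signedness of the underlying 8-bit data.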
    Datatype input_packed_type = Datatype::INT32;
    Datatype filter_packed_type = Datatype::INT32;

    if (input.GetDType() == Datatype::UINT8) {
        input_packed_type = Datatype::UINT32;
    } else if (input.GetDType() == Datatype::INT8) {
        input_packed_type = Datatype::INT32;
    }

    if (weights.GetDType() == WeightsType::UINT8) {
        filter_packed_type = Datatype::UINT32;
    } else if (weights.GetDType() == WeightsType::INT8) {
        filter_packed_type = Datatype::INT32;
    }

    jit.Merge(MakeTypeJitConstants(input_packed_type, "INPUT_PACKED"));
    jit.Merge(MakeTypeJitConstants(filter_packed_type, "FILTER_PACKED"));

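    // One block of the isa8_osv8_isv4 weights layout holds 4 * 8 * 8 = 256
    // values per spatial location (isv4 * isa8 * osv8), i.e. 32 input
    // features for 8 output features; the feature-block pitch is that block
    // repeated over the whole filter spatial size.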
    auto filter_spatial_size = weights.X().v * weights.Y().v * weights.Z().v;
    int filter_spatial_pitch = 4 * 8 * 8;

    jit.AddConstant(MakeJitConstant("FILTER_SPATIAL_SIZE", filter_spatial_size));
    jit.AddConstant(MakeJitConstant("MMAD_FILTER_SPATIAL_PITCH", filter_spatial_pitch));
    jit.AddConstant(MakeJitConstant("MMAD_FILTER_FBLOCK_PITCH", filter_spatial_size * filter_spatial_pitch));

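    // Input pitches are expressed in 32-feature blocks: for the fsv32 layouts
    // a step in X moves past one packed block of 32 features, the Y/Z pitches
    // are scaled accordingly, and MMAD_INPUT_FBLOCK_PITCH is the stride
    // between consecutive feature blocks. For plain bfyx the input is a flat
    // feature vector, so the next block of 32 features starts 32 elements on.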
    size_t input_x_pitch = input.X().pitch;
    size_t input_y_pitch = input.Y().pitch;
    size_t input_z_pitch = input.Z().pitch;

    if (input.GetLayout() == DataLayout::bfyx) {
        jit.AddConstant(MakeJitConstant("MMAD_INPUT_FBLOCK_PITCH", 32));
    } else if (input.GetLayout() == DataLayout::b_fs_yx_fsv32 || input.GetLayout() == DataLayout::b_fs_zyx_fsv32) {
        input_x_pitch = 32;
        input_y_pitch *= 32;
        input_z_pitch *= 32;
        jit.AddConstant(MakeJitConstant("MMAD_INPUT_FBLOCK_PITCH", input.Feature().pitch * 32));
    }

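    // For plain bfyx a feature count that is not a multiple of 32 leaves a
    // partial trailing block: FEATURE_BLOCKS_COUNT then covers only the full
    // blocks and HAS_FEATURE_LEFTOVERS tells the kernel to handle the
    // remaining features separately.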
    if (input.GetLayout() == DataLayout::bfyx && input.Feature().v % 32 != 0) {
        jit.AddConstant(MakeJitConstant("HAS_FEATURE_LEFTOVERS", true));
        jit.AddConstant(MakeJitConstant("FEATURE_BLOCKS_COUNT", input.Feature().v / 32));
    } else {
        jit.AddConstant(MakeJitConstant("FEATURE_BLOCKS_COUNT", CeilDiv(input.Feature().v, 32)));
    }

    jit.AddConstant(MakeJitConstant("MMAD_INPUT_SPATIAL_PITCH", input_x_pitch));
    jit.AddConstant(MakeJitConstant("MMAD_INPUT_X_PITCH", input_x_pitch));
    jit.AddConstant(MakeJitConstant("MMAD_INPUT_Y_PITCH", input_y_pitch));
    jit.AddConstant(MakeJitConstant("MMAD_INPUT_Z_PITCH", input_z_pitch));

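    // SPLIT_SPATIAL is set whenever a spatial dimension carries padding,
    // presumably so the generated kernel walks X/Y/Z with their individual
    // pitches instead of treating them as one dense range; SPATIAL_MAJOR
    // records whether the X channel precedes FEATURE in the input layout.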
    bool split_spatial = input.X().pad.Total() != 0 || input.Y().pad.Total() != 0 || input.Z().pad.Total() != 0;
    bool spatial_major = DataTensor::Channelndex(input.GetLayout(), Tensor::DataChannelName::X) <
                         DataTensor::Channelndex(input.GetLayout(), Tensor::DataChannelName::FEATURE);

    jit.AddConstant(MakeJitConstant("SPLIT_SPATIAL", split_spatial));
    jit.AddConstant(MakeJitConstant("SPATIAL_MAJOR", spatial_major));

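    // Generate JIT code for any fused post-operations (e.g. activation or
    // quantize); they are applied to the intermediate value exposed as
    // "dequantized", indexed by (b, f).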
    if (!params.fused_ops.empty()) {
        auto input_dt = GetActivationType(params);
        FusedOpsConfiguration conf = { "", {"b", "f", "0", "0"}, "dequantized", input_dt, 1 };
        jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
    }

    return jit;
}

KernelsData FullyConnectedKernelMMAD::GetKernelsData(const Params& params, const optional_params& options) const {
    auto fc_params = static_cast<const fully_connected_params&>(params);
    auto& input = fc_params.inputs[0];

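    // Request weights in the blocked isa8_osv8_isv4 layout; the 3D (zyx)
    // variant is used when the input tensor is 5-dimensional.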
    auto w_layout = WeightsLayout::os_is_yx_isa8_osv8_isv4;
    if (input.GetDims().size() == 5) {
        w_layout = WeightsLayout::os_is_zyx_isa8_osv8_isv4;
    }

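    // Build one kernel per auto-tune configuration and collect the first
    // (best) variant produced for each.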
    KernelsData res = {};
    for (size_t i = 0; i < autoTuneOptions.size(); i++) {
        KernelsData kd = GetTunedKernelsDataByIndex(params,
                                                    options,
                                                    input.GetLayout(),
                                                    w_layout,
                                                    FORCE_PRIORITY_9,
                                                    static_cast<int>(i));
        if (!kd.empty()) {
            res.emplace_back(kd[0]);
        }
    }
    return res;
}
}  // namespace kernel_selector