db61e0bde4c093de5b796adffaa305f94378a8d2
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / fully_connected / fully_connected_kernel_fb_io_b8_f8.cpp
1 // Copyright (c) 2016 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15
16 #include "fully_connected_kernel_fb_io_b8_f8.h"
17
18 namespace kernel_selector {
19 ParamsKey FullyConnected_fb_io_b8_f8::GetSupportedKey() const {
20     ParamsKey k;
21     k.EnableInputDataType(Datatype::F32);
22     k.EnableInputDataType(Datatype::F16);
23     k.EnableOutputDataType(Datatype::F32);
24     k.EnableOutputDataType(Datatype::F16);
25     k.EnableInputWeightsType(WeightsType::F32);
26     k.EnableInputWeightsType(WeightsType::F16);
27     k.EnableAllInputLayout();
28     k.EnableOutputLayout(DataLayout::fb);
29     k.EnableBatching();
30     k.EnableBiasPerFeature();
31     k.EnableNonBiasTerm();
32     k.EnableSubGroup();
33     return k;
34 }
35
36 FullyConnected_fb_io_b8_f8::DispatchData FullyConnected_fb_io_b8_f8::SetDefault(const fully_connected_params& arg,
37                                                                                 int) const {
38     auto kd = FullyConnectedBlockKernelBase::SetDefault(arg);
39
40     const auto& output = arg.output;
41
42     size_t groups_per_batches = GetLocalGroupsSize(arg);
43     kd.gws0 =
44         Align(output.LogicalSize() / (GetNeuronsPerWorkItem(arg) * GetBatchesPerWorkItem(arg) * groups_per_batches), 8);
45     kd.gws1 = groups_per_batches;
46     kd.lws0 = 8;
47     kd.lws1 = 1;
48
49     return kd;
50 }
51
52 bool FullyConnected_fb_io_b8_f8::Validate(const Params& p, const optional_params& o) const {
53     if (!FullyConnectedBlockKernelBase::Validate(p, o)) {
54         return false;
55     }
56
57     const auto& params = static_cast<const fully_connected_params&>(p);
58
59     const auto& output = params.output;
60     const auto batches = output.Batch().v;
61     const auto x_size = output.LogicalSize() / batches;
62
63     const auto& input = params.inputs[0];
64     const auto input_x_size = input.LogicalSize() / input.Batch().v;
65     const bool proper_input_aligment = (input_x_size % 8) == 0;
66     const bool proper_output_aligment =
67         (output.LogicalSize() /
68          (GetNeuronsPerWorkItem(params) * GetBatchesPerWorkItem(params) * GetLocalGroupsSize(params)) % 8) == 0;
69     const bool bSupportedBatch = (batches % 8) == 0;
70     const bool bSupportedFeature = (x_size % 8) == 0;
71
72     if (!bSupportedBatch || !bSupportedFeature || !proper_input_aligment || !proper_output_aligment) {
73         return false;
74     }
75
76     return true;
77 }
78
79 KernelsData FullyConnected_fb_io_b8_f8::GetKernelsData(const Params& params, const optional_params& optParams) const {
80     assert(params.GetType() == KernelType::FULLY_CONNECTED);
81     KernelsData res = {};
82     const auto& orgParams = static_cast<const fully_connected_params&>(params);
83
84     float estimated_time = orgParams.inputs[0].GetDType() == Datatype::F16 && orgParams.output.Batch().v >= 16
85                                ? FORCE_PRIORITY_3
86                                : FORCE_PRIORITY_5;
87
88     for (size_t i = 0; i < autoTuneOptions.size(); i++) {
89         KernelsData kd =
90             GetTunedKernelsDataByIndex(params, optParams, DataLayout::fb, {WeightsLayout::io}, estimated_time, static_cast<int>(i));
91         if (!kd.empty()) {
92             res.emplace_back(kd[0]);
93         }
94     }
95
96     return res;
97 }
98 }  // namespace kernel_selector