Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / fully_connected / fully_connected_kernel_imad.cpp
1 /*
2 // Copyright (c) 2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "fully_connected_kernel_imad.h"
18
19 // IMAD Fully_Connected primitive implementation.
20 // Limitations are:
21 // 1. Input=Fx1x1 with Filter=1x1
22 // 2. No data padding
23
24 namespace kernel_selector
25 {
26     ParamsKey FullyConnectedKernelIMAD::GetSupportedKey() const
27     {
28         ParamsKey k;
29         k.EnableInputDataType(Datatype::INT8);
30         k.EnableInputDataType(Datatype::UINT8);
31         k.EnableOutputDataType(Datatype::INT8);
32         k.EnableOutputDataType(Datatype::UINT8);
33         k.EnableInputWeightsType(WeightsType::INT8);
34         k.EnableInputLayout(DataLayout::b_fs_yx_fsv4);
35         k.EnableOutputLayout(DataLayout::bf);
36         k.EnableBiasPerOutput();
37         k.EnableBiasPerFeature();
38         k.EnableNonBiasTerm();
39         k.EnableTensorOffset();
40         k.EnableTensorPitches();
41         k.EnableBatching();
42         k.EnableInt8Quantization();
43         k.EnableOutputCalibration();
44         return k;
45     }
46
47     FullyConnectedKernelIMAD::Parent::DispatchData
48     FullyConnectedKernelIMAD::SetDefault(const fully_connected_params& params, int) const
49     {
50         const int simdSize = 16;
51
52         auto runInfo = Parent::SetDefault(params);
53
54         runInfo.gws0 = RoundUp(params.output.Feature().v, simdSize);
55         runInfo.gws1 = params.output.Batch().v;
56         runInfo.gws2 = 1;
57
58         runInfo.lws0 = simdSize;
59         runInfo.lws1 = 1;
60         runInfo.lws2 = 1;
61
62         return runInfo;
63     } // SetDefault
64
65     bool FullyConnectedKernelIMAD::Validate(const Params& params, const optional_params& options) const
66     {
67         if (!Parent::Validate(params, options)) {
68             return false;
69         }
70
71         const auto& newParams = static_cast<const fully_connected_params&>(params);
72         const auto& in = newParams.inputs[0];
73         const auto& weights = newParams.weights;
74
75         if ((in.X().v != 1) ||
76             (in.Y().v != 1) ||
77             (weights.X().v != 1) ||
78             (weights.Y().v != 1)) {
79             // Currently only Input=Fx1x1 with Filter=1x1 is supported
80             return false;
81         }
82         if ((in.X().pad.before != 0) ||
83             (in.X().pad.after != 0) ||
84             (in.Y().pad.before != 0) ||
85             (in.Y().pad.after != 0)) {
86             // Padding is not supported
87             return false;
88         }
89         if (in.Feature().v % (4 * 8)) {
90             // Algorith requires 4 bytes read as one int
91             // with specific weight format os_is_yx_osv16_isv4
92             // wich will read 8 elements per reading
93             return false;
94         }
95
96         return true;
97     } // Validate
98
99     KernelsData FullyConnectedKernelIMAD::GetKernelsData(const Params& params, const optional_params& options) const
100     {
101
102         KernelsData res = {};
103         for (size_t i = 0; i < autoTuneOptions.size(); i++)
104         {
105             KernelsData kd = GetTunedKernelsDataByIndex(
106                                 params, options, DataLayout::b_fs_yx_fsv4,
107                                 { WeightsLayout::os_is_yx_osv16_isv4 },
108                                 FORCE_PRIORITY_1, (int)i);
109             if (!kd.empty())
110             {
111                 res.emplace_back(kd[0]);
112             }
113         }
114         return res;
115     }
116 }