Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / fully_connected / fully_connected_kernel_fb_io_block.cpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "fully_connected_kernel_fb_io_block.h"
18
19 namespace kernel_selector 
20 {
21     ParamsKey FullyConnected_fb_io_block::GetSupportedKey() const
22     {
23         ParamsKey k;
24         k.EnableInputDataType(Datatype::F16);
25         k.EnableOutputDataType(Datatype::F16);
26         k.EnableInputWeightsType(WeightsType::F16);
27         k.EnableInputWeightsType(WeightsType::F32);
28         k.EnableAllInputLayout();
29         k.EnableOutputLayout(DataLayout::fb);
30         k.EnableBatching();
31         k.EnableBiasPerFeature();
32         k.EnableNonBiasTerm();
33         k.EnableSubGroup();
34         return k;
35     }
36
37
38     FullyConnected_fb_io_block::DispatchData FullyConnected_fb_io_block::SetDefault(const fully_connected_params& arg, int ) const
39     {
40         auto kd = FullyConnectedKernelBase::SetDefault(arg);
41         const auto& output = arg.output;
42         
43         auto batch_size = output.Batch().v;
44         auto response_size = output.Feature().v;
45
46         constexpr uint32_t unit_byte_size = sizeof(short);
47         const char* chunk_type = "uint";
48         constexpr uint32_t chunk_byte_size = sizeof(uint32_t);
49         constexpr uint32_t sub_group_size = 16;
50         constexpr uint32_t units_per_chunk = chunk_byte_size / unit_byte_size;
51         constexpr uint32_t units_per_sg_read = sub_group_size * units_per_chunk;
52
53         
54         // Number of response groups. Each group (except last) writes units_per_sg_read responses
55         // for at least one input data set from batch.
56         auto rg_count = CeilDiv(response_size, units_per_sg_read);
57
58         kd.lws0 = sub_group_size;
59         // Number of work items needed to process all response groups.
60         kd.gws0 = rg_count * sub_group_size;
61         kd.lws1 = 1;
62         kd.gws1 = batch_size / units_per_sg_read;
63
64         kd.unit_byte_size    = unit_byte_size;
65         kd.chunk_type        = chunk_type;
66         kd.chunk_byte_size   = chunk_byte_size;
67         kd.units_per_chunk   = units_per_chunk;
68         kd.bytes_per_sg_read = sub_group_size * chunk_byte_size;
69         kd.units_per_sg_read = units_per_sg_read;
70         kd.rg_count          = (uint32_t)rg_count;
71         kd.last_rg_size      = response_size % units_per_sg_read;
72         return kd;
73     }
74
75     JitConstants FullyConnected_fb_io_block::GetJitConstants(const fully_connected_params& params, const FullyConnectedKernelBase::DispatchData& run_info) const
76     {
77         auto cldnn_jit = FullyConnectedKernelBase::GetJitConstants(params, run_info);
78         cldnn_jit.AddConstants({
79             MakeJitConstant("SUB_GROUP_SIZE",        run_info.lws0),
80             MakeJitConstant("WORK_ITEMS_PER_BATCH",  run_info.gws1),
81             MakeJitConstant("UNIT_BYTE_SIZE",        run_info.unit_byte_size),
82             MakeJitConstant("CHUNK_TYPE",            run_info.chunk_type),
83             MakeJitConstant("CHUNK_BYTE_SIZE",       run_info.chunk_byte_size),
84             MakeJitConstant("UNITS_PER_CHUNK",       run_info.units_per_chunk),
85             MakeJitConstant("BYTES_PER_SG_READ",     run_info.bytes_per_sg_read),
86             MakeJitConstant("UNITS_PER_SG_READ",     run_info.units_per_sg_read),
87             MakeJitConstant("RG_COUNT",              run_info.rg_count),
88             MakeJitConstant("LAST_RG_SIZE",          run_info.last_rg_size),
89         });
90         return cldnn_jit;
91     }
92
93     bool FullyConnected_fb_io_block::Validate(const Params& p, const optional_params& o) const
94     {
95         if (!FullyConnectedKernelBase::Validate(p, o))
96         {
97             return false;
98         }
99
100         const auto& params = static_cast<const fully_connected_params&>(p);
101
102         const auto& output = params.output;
103         const auto responseSize = output.Feature().v;
104         const auto batches = output.Batch().v;
105         const auto xSize = output.LogicalSize() / batches;
106
107         constexpr uint32_t subGroupSize         = 16;
108         constexpr uint32_t bytesPerElement      = sizeof(short);
109         constexpr uint32_t chunkSizeInBytes     = sizeof(uint32_t);
110         constexpr uint32_t chunkSizeInElements  = chunkSizeInBytes / bytesPerElement;
111         constexpr uint32_t elementsPerBlockRead = subGroupSize * chunkSizeInElements;
112
113         const bool bSupportedBatch = 
114             (batches > 0) && 
115             ((batches % 8) == 0) &&
116             ((batches % elementsPerBlockRead) == 0);
117
118         const bool bSupportedFeature = 
119             (responseSize > 0) && 
120             (((responseSize * bytesPerElement) % 4) == 0) &&
121             ((xSize % 8) == 0);
122
123         if (!bSupportedBatch ||
124             !bSupportedFeature)
125         {
126             return false;
127         }
128
129         return true;
130     }
131
132     KernelsData FullyConnected_fb_io_block::GetKernelsData(const Params& params, const optional_params& optParams) const
133     {
134         assert(params.GetType() == KernelType::FULLY_CONNECTED);
135
136         const auto& orgParams = static_cast<const fully_connected_params&>(params);
137
138         float estimated_time =
139             orgParams.inputs[0].GetDType() == Datatype::F16 && orgParams.output.Batch().v >= 16 ?
140             FORCE_PRIORITY_3 : FORCE_PRIORITY_5;
141
142         // TODO: it should be fb_io. but the original code use this kernel with yxfb and yxio 
143         //       (fb == fyxb flatten fyx, not yxfb flatten yxf).
144         //       the order of the add operation cause some numeric changes. in order to avoid them right now we use yxfb/oiyx instead.
145         // return GetCommonKernelsData(params, optParams, DataLayout::fb, WeightsLayout::io, estimated_time);
146         //return GetCommonKernelsData(params, optParams, DataLayout::yxfb, { WeightsLayout::yxio }, estimated_time);
147
148         KernelsData res = {};
149         for (size_t i = 0; i < autoTuneOptions.size(); i++)
150         {
151             KernelsData kd = GetTunedKernelsDataByIndex(params, optParams, DataLayout::yxfb, { WeightsLayout::yxio }, estimated_time, (int)i);
152             if (!kd.empty())
153             {
154                 res.emplace_back(kd[0]);
155             }
156         }
157
158         return res;
159         }
160 }