[IE CLDNN] Fix linear_onnx Interpolate selection (#2769)
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / resample / resample_kernel_ref.cpp
1 // Copyright (c) 2016-2020 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <core/common/kernel_selector_utils.h>
16 #include "resample_kernel_ref.h"
17
18 #include <algorithm>
19 #include <vector>
20 #include <string>
21
22 namespace kernel_selector {
23
24 ParamsKey ResampleKernelRef::GetSupportedKey() const {
25     ParamsKey k;
26     k.EnableInputDataType(Datatype::UINT8);
27     k.EnableInputDataType(Datatype::INT8);
28     k.EnableInputDataType(Datatype::F16);
29     k.EnableInputDataType(Datatype::F32);
30     k.EnableOutputDataType(Datatype::UINT8);
31     k.EnableOutputDataType(Datatype::INT8);
32     k.EnableOutputDataType(Datatype::F16);
33     k.EnableOutputDataType(Datatype::F32);
34     k.EnableDifferentTypes();
35     k.EnableAllInputLayout();
36     k.EnableAllOutputLayout();
37     k.EnableTensorOffset();
38     k.EnableTensorPitches();
39     k.EnableBatching();
40     k.EnableReampleType(ResampleType::NEAREST_NEIGHBOR);
41     k.EnableReampleType(ResampleType::CAFFE_BILINEAR_INTERP);
42     k.EnableReampleType(ResampleType::BILINEAR_INTERP);
43     k.EnableReampleType(ResampleType::CUBIC);
44     k.EnableReampleType(ResampleType::LINEAR_ONNX);
45     return k;
46 }
47
48 KernelsData ResampleKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
49     return GetCommonKernelsData(params, options);
50 }
51
52 static size_t packing_factor(const resample_params& params) {
53     // TODO Add support for only input packing
54     bool in_out_8bit = (params.inputs[0].GetDType() == Datatype::UINT8 || params.inputs[0].GetDType() == Datatype::INT8) &&
55                        (params.output.GetDType() == Datatype::UINT8 || params.output.GetDType() == Datatype::INT8);
56
57     if (!in_out_8bit)
58         return 1;
59
60     auto get_layout_packing_factor = [](const DataLayout& layout) -> size_t {
61         switch (layout) {
62         case DataLayout::b_fs_yx_fsv16:
63             return 16;
64         case DataLayout::b_fs_yx_fsv4:
65             return 4;
66         default:
67             break;
68         }
69         return 1;
70     };
71
72     size_t input_factor = get_layout_packing_factor(params.inputs[0].GetLayout());
73     size_t output_factor = get_layout_packing_factor(params.output.GetLayout());
74
75     if (input_factor % output_factor == 0 || output_factor % input_factor == 0)
76         return std::min(input_factor, output_factor);
77     return 1;
78 }
79
80 static bool use_packing(const resample_params& params) {
81     if (params.resampleType != ResampleType::NEAREST_NEIGHBOR)
82         return false;
83
84     auto pack = packing_factor(params);
85     if (pack == 1)
86         return false;
87
88     if (params.inputs[0].Feature().pad.before % pack != 0 || params.output.Feature().pad.before % pack != 0)
89         return false;
90
91     auto packed_work_items = params.output.X().v * params.output.Y().v * params.output.Z().v
92         * CeilDiv(params.output.Feature().v, pack) * params.output.Batch().v;
93     // TODO Loosen this requirement to minimum EUs needed to saturate cache bandwidth
94     constexpr size_t max_work_items_per_eu = 32 * 7;
95     auto minimum_work_items = params.engineInfo.computeUnitsCount * max_work_items_per_eu;
96
97     if (packed_work_items < minimum_work_items)
98         return false;
99
100     return true;
101 }
102
103 JitConstants ResampleKernelRef::GetJitConstants(const resample_params& params) const {
104     JitConstants jit = ResampleKernelBase::GetJitConstants(params);
105
106     if (use_packing(params)) {
107         jit.AddConstant(MakeJitConstant("PACK_SIZE", packing_factor(params)));
108         jit.AddConstant(MakeJitConstant("FEATURE_PACKED_MODE", "1"));
109     }
110
111     if (!params.fused_ops.empty()) {
112         std::vector<std::string> idx_order;
113         if (DataTensor::ChannelsCount(params.output.GetLayout()) == 4) {
114             idx_order = {"batch", "OF_ID", "oy", "ox"};
115         } else if (DataTensor::ChannelsCount(params.output.GetLayout()) == 5) {
116             idx_order = {"batch", "OF_ID", "oz", "oy", "ox"};
117         }
118
119         FusedOpsConfiguration conf = {"", idx_order, "interp_val", GetAccumulatorType(params), 1};
120         jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
121     }
122
123     return jit;
124 }
125
126 ResampleKernelBase::DispatchData ResampleKernelRef::SetDefault(const resample_params& arg) const {
127     auto dispatchData = Parent::SetDefault(arg);
128
129     if (use_packing(arg)) {
130         auto pack = packing_factor(arg);
131         dispatchData.gws = { arg.output.X().v, arg.output.Y().v * arg.output.Z().v, CeilDiv(arg.output.Feature().v, pack) * arg.output.Batch().v };
132         dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, arg.engineInfo);
133     }
134
135     return dispatchData;
136 }
137 }  // namespace kernel_selector