Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / activation / activation_kernel_base.cpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "activation_kernel_base.h"
18 #include "kernel_selector_utils.h"
19  
20 namespace kernel_selector 
21 {
22
23     ActivationKernelBase::DispatchData ActivationKernelBase::SetDefault(const activation_params& arg) const
24     {
25         const auto& out = arg.output;
26
27         DispatchData runInfo;
28         std::vector<size_t> global = { out.X().v, out.Y().v, out.Feature().v*out.Batch().v };
29         if (out.GetLayout() == DataLayout::yxfb)
30         {
31             global[0] = out.Feature().v*out.Batch().v;
32             global[1] = out.X().v;
33             global[2] = out.Y().v;
34         }
35         std::vector<size_t> local = GetOptimalLocalWorkGroupSizes(global);
36         runInfo.gws0 = global[0];
37         runInfo.gws1 = global[1];
38         runInfo.gws2 = global[2];
39         runInfo.lws0 = local[0];
40         runInfo.lws1 = local[1];
41         runInfo.lws2 = local[2];
42
43         runInfo.effiency = DONT_USE_IF_HAVE_SOMETHING_ELSE;
44         runInfo.fp16UnitUsed = out.GetDType() == Datatype::F16;
45
46         return runInfo;
47     }
48
49     JitConstants ActivationKernelBase::GetJitConstants(const activation_params& params, DispatchData) const
50     {
51         JitConstants jit = MakeBaseParamsJitConstants(params);
52
53         const auto& inputNlParams = params.inputActivationParams;
54
55         jit.AddConstants({
56             MakeJitConstant("PARAMS_NUM", GetActivationAdditionalParamsNumber(params.activation.function)),
57         });
58
59         if (!inputNlParams.empty())
60         {
61             jit.AddConstants({
62                 MakeJitConstant("ADDITIONAL_PARAMS", inputNlParams[0]),
63                 MakeJitConstant("PARAMETERIZED", ""),
64             });
65         }
66
67         return jit;
68     }
69
70     bool ActivationKernelBase::Validate(const Params& p, const optional_params& o) const
71     {
72         if (p.GetType() != KernelType::ACTIVATION ||
73             o.GetType() != KernelType::ACTIVATION)
74         {
75             return false;
76         }
77
78         return true;
79     }
80
81     KernelsData ActivationKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const
82     {
83         if (!Validate(params, options))
84         {
85             return{};
86         }
87
88         KernelData kd = KernelData::Default<activation_params>(params);
89
90         activation_params& newParams = *static_cast<activation_params*>(kd.params.get());
91         const std::string kernel_id = GetEntryPoint(kernelName, params.layerID, options);
92
93         auto runInfo = SetDefault(newParams);
94         auto cldnn_jit = GetJitConstants(newParams, runInfo);
95         auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
96         auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
97         
98         auto& kernel = kd.kernels[0];
99         FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point);
100         
101         if (newParams.gradient)
102             kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 1 });
103
104         if (!newParams.inputActivationParams.empty())
105         {
106             kernel.arguments.push_back({ ArgumentDescriptor::Types::SLOPE, 0 });
107         }
108
109         kd.estimatedTime = runInfo.effiency;
110
111         return{ kd };
112     }
113 }