Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api_extension / CPP / fused_conv_bn_scale.hpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "../C/fused_conv_bn_scale.h"
20 #include "api/CPP/primitive.hpp"
21
22 namespace cldnn
23 {
24 /// @addtogroup cpp_api C++ API
25 /// @{
26 /// @addtogroup cpp_topology Network Topology
27 /// @{
28 /// @addtogroup cpp_primitives Primitives
29 /// @{
30
31 /// @brief Primitives that fuses convolution, batch norm, scale and optionally Relu.
32 struct fused_conv_bn_scale : public primitive_base<fused_conv_bn_scale, CLDNN_PRIMITIVE_DESC(fused_conv_bn_scale)>
33 {
34     CLDNN_DECLARE_PRIMITIVE(fused_conv_bn_scale)
35
36     /// @brief Constructs convolution primitive fused with batch norm and scale.
37     /// @param id This primitive id.
38     /// @param input Input primitive id.
39     /// @param weights List of primitive ids containing weights data.
40     /// @param bias List of primitive ids containing bias data.
41     /// @param epsilon Small number to protect from 0 dividing.
42     /// @param scale_input Scale input primitive id with values needed for product computation. Used in fused scale part.
43     /// @param scale_bias Primitive id containing bias data for fused scale part.
44     /// @param input_offset Defines a shift, relative to (0,0) position of the input buffer, where (0,0) point of the convolution window should start calculations.
45     /// @param stride Defines shift in input buffer between adjacent calculations of output values.
46     /// @param inv_variance Primitive id containing inverted variance calculated in this primitive. Used in fused batch norm part.
47     /// @param with_activation Enable Relu activation.
48     /// @param activation_slp Relu activation slope.
49     fused_conv_bn_scale(
50         const primitive_id& id,
51         const primitive_id& input,
52         const std::vector<primitive_id>& weights,
53         const std::vector<primitive_id>& bias,
54         float epsilon,
55         const primitive_id& scale_input,
56         const primitive_id& scale_bias = "",
57         tensor stride = { 1, 1, 1, 1 },
58         tensor dilation = { 1, 1, 1, 1 },
59         tensor input_offset = { 0,0,0,0 },
60         const primitive_id& inv_variance = "",
61         bool with_activation = false,
62         float activation_slp = 0.0f,
63         const padding& output_padding = padding()
64     )
65         :primitive_base(id, { input, scale_input }, output_padding)
66         , weights(_weights.cpp_ids)
67         , bias(_bias.cpp_ids)
68         , input_offset(input_offset)
69         , stride(stride)
70         , dilation(dilation)
71         , with_activation(with_activation)
72         , activation_negative_slope(activation_slp)
73         , with_output_size(false)
74         , scale_bias(scale_bias)
75         , inv_variance(inv_variance)
76         , epsilon(epsilon)
77         , _weights(weights)
78         , _bias(bias)
79     {
80         if ((bias.size() != 0) && (weights.size() != bias.size()))
81             throw std::runtime_error("convolution's weights/bias count does not match");
82     }
83
84     /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{fused_conv_bn_scale}
85     fused_conv_bn_scale(const dto* dto)
86         :primitive_base(dto)
87         , weights(_weights.cpp_ids)
88         , bias(_bias.cpp_ids)
89         , input_offset(dto->input_offset)
90         , stride(dto->stride)
91         , dilation(dto->dilation)
92         , with_activation(dto->with_activation != 0)
93         , activation_negative_slope(dto->activation_negative_slope)
94         , scale_bias(dto->scale_bias)
95         , inv_variance(dto->inv_variance)
96         , epsilon(dto->epsilon)
97         , _weights(dto->weights)
98         , _bias(dto->bias)
99     {
100         if (!dto->split || (weights.size() != bias.size() && bias.size() != 0) || dto->split != weights.size())
101             throw std::invalid_argument("Invalid convolution dto: bad split value");
102     }
103
104     /// @brief List of primitive ids containing weights data.
105     fixed_size_vector_ref weights;
106     /// @brief List of primitive ids containing bias data.
107     fixed_size_vector_ref bias;
108     /// @brief Defines a shift, relative to (0,0) position of the input buffer, where (0,0) point of the convolution window should start calculations.
109     tensor input_offset;
110     /// @brief Defines shift in input buffer between adjacent calculations of output values.
111     tensor stride;
112     /// @brief Defines gaps in the input - dilation rate k=1 is normal convolution, k=2 means skipping one pixel per input, k=4 means skipping 3 pixels.
113     /// As an example in one dimension, a filter w of size 3 would compute over input x the following: w[0]*x[0] + w[1]*x[1] + w[2]*x[2] for dilation of 1. 
114     /// For dilation 2 the filter would instead compute w[0]*x[0] + w[1]*x[2] + w[2]*x[4].
115     tensor dilation;
116     /// @brief Enable Relu activation.
117     bool with_activation;
118     /// @brief Relu activation slope.
119     float activation_negative_slope;
120     /// @brief Indicates that the primitive has user-defined output size (non-zero value).
121     bool with_output_size;
122     /// @brief User-defined output data size of the primitive (w/o padding).
123     tensor output_size;
124     /// @brief Primitive id containing scale bias data for fused convolution.
125     primitive_id scale_bias;
126     /// @brief Primitive id containing inverted variance used in future gradient computing for fused convolution.
127     primitive_id inv_variance;
128     /// @brief Epsilon for fused convolution.
129     float epsilon;
130     /// @brief On how many cards split the computation to.
131     int32_t split() const { return static_cast<int32_t>(weights.size()); }
132
133 protected:
134     primitive_id_arr _weights;
135     primitive_id_arr _bias;
136
137     std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
138     {
139         std::vector<std::reference_wrapper<const primitive_id>> ret;
140         ret.reserve(weights.size() + bias.size() + !scale_bias.empty() + !inv_variance.empty());
141         for (auto& w : weights)
142             ret.push_back(w);
143         for (auto& b : bias)
144             ret.push_back(b);
145         if (!scale_bias.empty())
146             ret.push_back(scale_bias);
147         if (!inv_variance.empty())
148             ret.push_back(inv_variance);
149         return ret;
150     }
151
152     void update_dto(dto& dto) const override
153     {
154         dto.weights = _weights.ref();
155         dto.bias = _bias.ref();
156         dto.input_offset = input_offset;
157         dto.stride = stride;
158         dto.dilation = dilation;
159         dto.split = split();
160         dto.with_activation = with_activation;
161         dto.activation_negative_slope = activation_negative_slope;
162         dto.epsilon = epsilon;
163         dto.inv_variance = inv_variance.c_str();
164         dto.scale_bias = scale_bias.c_str();
165     }
166 };
167 /// @}
168 /// @}
169 /// @}
170 }