2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "../C/fused_conv_bn_scale.h"
20 #include "api/CPP/primitive.hpp"
24 /// @addtogroup cpp_api C++ API
26 /// @addtogroup cpp_topology Network Topology
28 /// @addtogroup cpp_primitives Primitives
31 /// @brief Primitive that fuses convolution, batch normalization, scale and, optionally, ReLU activation.
32 struct fused_conv_bn_scale : public primitive_base<fused_conv_bn_scale, CLDNN_PRIMITIVE_DESC(fused_conv_bn_scale)>
34 CLDNN_DECLARE_PRIMITIVE(fused_conv_bn_scale)
36 /// @brief Constructs convolution primitive fused with batch norm and scale.
37 /// @param id This primitive id.
38 /// @param input Input primitive id.
39 /// @param weights List of primitive ids containing weights data.
40 /// @param bias List of primitive ids containing bias data.
41 /// @param epsilon Small value used to protect against division by zero.
42 /// @param scale_input Primitive id of the scale input, whose values are used as multipliers in the fused scale part.
43 /// @param scale_bias Primitive id containing bias data for fused scale part.
44 /// @param input_offset Defines a shift, relative to (0,0) position of the input buffer, where (0,0) point of the convolution window should start calculations.
45 /// @param stride Defines shift in input buffer between adjacent calculations of output values.
46 /// @param inv_variance Primitive id containing the inverse variance calculated by this primitive; used in the fused batch norm part. May be empty.
47 /// @param with_activation Enable Relu activation.
48 /// @param activation_slp Negative slope of the ReLU activation.
50 const primitive_id& id,
51 const primitive_id& input,
52 const std::vector<primitive_id>& weights,
53 const std::vector<primitive_id>& bias,
55 const primitive_id& scale_input,
56 const primitive_id& scale_bias = "",
57 tensor stride = { 1, 1, 1, 1 },
58 tensor dilation = { 1, 1, 1, 1 },
59 tensor input_offset = { 0,0,0,0 },
60 const primitive_id& inv_variance = "",
61 bool with_activation = false,
62 float activation_slp = 0.0f,
63 const padding& output_padding = padding()
65 :primitive_base(id, { input, scale_input }, output_padding)
66 , weights(_weights.cpp_ids)
68 , input_offset(input_offset)
71 , with_activation(with_activation)
72 , activation_negative_slope(activation_slp)
73 , with_output_size(false)
74 , scale_bias(scale_bias)
75 , inv_variance(inv_variance)
80 if ((bias.size() != 0) && (weights.size() != bias.size()))
81 throw std::runtime_error("convolution's weights/bias count does not match");
84 /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{fused_conv_bn_scale}
85 fused_conv_bn_scale(const dto* dto)
87 , weights(_weights.cpp_ids)
89 , input_offset(dto->input_offset)
91 , dilation(dto->dilation)
92 , with_activation(dto->with_activation != 0)
93 , activation_negative_slope(dto->activation_negative_slope)
94 , scale_bias(dto->scale_bias)
95 , inv_variance(dto->inv_variance)
96 , epsilon(dto->epsilon)
97 , _weights(dto->weights)
100 if (!dto->split || (weights.size() != bias.size() && bias.size() != 0) || dto->split != weights.size())
101 throw std::invalid_argument("Invalid convolution dto: bad split value");
104 /// @brief List of primitive ids containing weights data.
105 fixed_size_vector_ref weights;
106 /// @brief List of primitive ids containing bias data.
107 fixed_size_vector_ref bias;
108 /// @brief Defines a shift, relative to (0,0) position of the input buffer, where (0,0) point of the convolution window should start calculations.
110 /// @brief Defines shift in input buffer between adjacent calculations of output values.
112 /// @brief Defines gaps in the input - dilation rate k=1 is normal convolution, k=2 means skipping one pixel per input, k=4 means skipping 3 pixels.
113 /// As an example in one dimension, a filter w of size 3 would compute over input x the following: w[0]*x[0] + w[1]*x[1] + w[2]*x[2] for dilation of 1.
114 /// For dilation 2 the filter would instead compute w[0]*x[0] + w[1]*x[2] + w[2]*x[4].
116 /// @brief Enable Relu activation.
117 bool with_activation;
118 /// @brief Negative slope of the ReLU activation.
119 float activation_negative_slope;
120 /// @brief Indicates that the primitive has user-defined output size (non-zero value).
121 bool with_output_size;
122 /// @brief User-defined output data size of the primitive (w/o padding).
124 /// @brief Primitive id containing scale bias data for fused convolution.
125 primitive_id scale_bias;
126 /// @brief Primitive id containing the inverse variance of the fused convolution, kept for later gradient computation. May be empty.
127 primitive_id inv_variance;
128 /// @brief Epsilon of the fused batch normalization part (small value that protects against division by zero).
130 /// @brief Number of parts the computation is split into (equal to the number of weights primitives).
131 int32_t split() const { return static_cast<int32_t>(weights.size()); }
134 primitive_id_arr _weights;
135 primitive_id_arr _bias;
137 std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
139 std::vector<std::reference_wrapper<const primitive_id>> ret;
140 ret.reserve(weights.size() + bias.size() + !scale_bias.empty() + !inv_variance.empty());
141 for (auto& w : weights)
145 if (!scale_bias.empty())
146 ret.push_back(scale_bias);
147 if (!inv_variance.empty())
148 ret.push_back(inv_variance);
152 void update_dto(dto& dto) const override
154 dto.weights = _weights.ref();
155 dto.bias = _bias.ref();
156 dto.input_offset = input_offset;
158 dto.dilation = dilation;
160 dto.with_activation = with_activation;
161 dto.activation_negative_slope = activation_negative_slope;
162 dto.epsilon = epsilon;
163 dto.inv_variance = inv_variance.c_str();
164 dto.scale_bias = scale_bias.c_str();