2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "../C/fully_connected.h"
20 #include "primitive.hpp"
24 /// @addtogroup cpp_api C++ API
26 /// @addtogroup cpp_topology Network Topology
28 /// @addtogroup cpp_primitives Primitives
31 /// @brief Performs forward fully connected layer (inner product).
32 /// Also supports built-in Relu @CLDNN_PRIMITIVE_DESC{activation} available by setting it in arguments.
34 /// - Equation: Input[F x Y x F] x Output(X) == Weights(B x F x X x F) has to be fulfilled
35 /// - Bias has to be linear data [1,1,1,X], where X is equal to number of outputs.
38 /// <caption id = "multi_row">Format support</caption>
39 /// <tr><th>Data type <th>activation format <th>weights format
40 /// <tr><td rowspan="7">F32 <td rowspan="4">bfyx <td>yxfb
42 /// <tr> <td>bs_xs_xsv8_bsv8
43 /// <tr> <td>bs_x_bsv16
44 /// <tr> <td rowspan="3">yxfb <td>bfyx
46 /// <tr> <td>bs_xs_xsv8_bsv8
47 /// <tr><td rowspan="4">F16 <td rowspan="3">bfyx <td>yxfb
49 /// <tr> <td>bs_x_bsv16
50 /// <tr> <td >yxfb <td>bfyx
53 struct fully_connected : public primitive_base<fully_connected, CLDNN_PRIMITIVE_DESC(fully_connected)> {
54 CLDNN_DECLARE_PRIMITIVE(fully_connected)
56 /// @brief Constructs fully connected layer.
57 /// @param id This primitive id.
58 /// @param input Input primitive id.
59 /// @param weights Primitive id containing weights data.
60 /// @param bias Primitive id containing bias data. Provide empty string if using Relu without bias.
61 /// @param with_activation Enable Relu activation.
62 /// @param activation_slp Relu activation slope.
63 fully_connected(const primitive_id& id,
64 const primitive_id& input,
65 const primitive_id& weights,
66 const primitive_id& bias = "",
67 bool with_activation = false,
68 float activation_slp = 0.0f,
69 const padding& output_padding = padding())
70 : primitive_base(id, {input}, output_padding),
73 weights_quantization_factors(""),
74 output_calibration_factors(""),
75 input_quantization_factor(1.0f),
76 output_quantization_factor(1.0f),
77 with_activation(with_activation),
78 activation_negative_slope(activation_slp) {}
80 /// @brief Constructs fully connected layer.
81 /// @param id This primitive id.
82 /// @param input Input primitive id.
83 /// @param weights Primitive id containing weights data.
84 /// @param bias Primitive id containing bias data. Provide empty string if using Relu without bias.
85 /// @param w_quantization_factor Primitive id containing weights quanitization factors per output feature map.
86 /// @param i_quantization_factor Input quantization factor
87 /// @param o_quantization_factor Output quantization factor
88 /// @param with_activation Enable Relu activation.
89 /// @param activation_slp Relu activation slope.
90 fully_connected(const primitive_id& id,
91 const primitive_id& input,
92 const primitive_id& weights,
93 const primitive_id& bias,
94 const primitive_id& w_quantization_factor,
95 const float i_quantization_factor,
96 const float o_quantization_factor,
97 bool with_activation = false,
98 float activation_slp = 0.0f,
99 const padding& output_padding = padding())
100 : primitive_base(id, {input}, output_padding),
103 weights_quantization_factors(w_quantization_factor),
104 output_calibration_factors(""),
105 input_quantization_factor(i_quantization_factor),
106 output_quantization_factor(o_quantization_factor),
107 with_activation(with_activation),
108 activation_negative_slope(activation_slp) {}
110 /// @brief Constructs fully connected layer.
111 /// @param id This primitive id.
112 /// @param input Input primitive id.
113 /// @param weights Primitive id containing weights data.
114 /// @param bias Primitive id containing bias data. Provide empty string if using Relu without bias.
115 /// @param w_quantization_factor Primitive id containing weights quanitization factors per output feature map.
116 /// @param output_calibration_factors Primitive id containing output calibration factors per output feature map.
117 /// @param i_quantization_factor Input quantization factor
118 /// @param with_activation Enable Relu activation.
119 /// @param activation_slp Relu activation slope.
120 fully_connected(const primitive_id& id,
121 const primitive_id& input,
122 const primitive_id& weights,
123 const primitive_id& bias,
124 const primitive_id& w_quantization_factor,
125 const primitive_id& output_calibration_factors,
126 const float i_quantization_factor,
127 bool with_activation = false,
128 float activation_slp = 0.0f,
129 const padding& output_padding = padding())
130 : primitive_base(id, {input}, output_padding),
133 weights_quantization_factors(w_quantization_factor),
134 output_calibration_factors(output_calibration_factors),
135 input_quantization_factor(i_quantization_factor),
136 output_quantization_factor(1.0f),
137 with_activation(with_activation),
138 activation_negative_slope(activation_slp) {}
140 /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{fully_connected}
141 fully_connected(const dto* dto)
142 : primitive_base(dto),
143 weights(dto->weights),
145 weights_quantization_factors(dto->weights_quantization_factors),
146 output_calibration_factors(dto->output_calibration_factors),
147 input_quantization_factor(dto->input_quantization_factor),
148 output_quantization_factor(dto->output_quantization_factor),
149 with_activation(dto->with_activation != 0),
150 activation_negative_slope(dto->activation_negative_slope) {}
152 /// @brief Primitive id containing weights data.
153 primitive_id weights;
154 /// @brief Primitive id containing bias data.
156 /// @brief Primitive id containing weights quanitization factors per output feature map.
157 primitive_id weights_quantization_factors;
158 /// @brief Primitive id containing output quanitization factors per output feature map.
159 primitive_id output_calibration_factors;
160 /// @brief Input quantization factor
161 float input_quantization_factor;
162 /// @brief Output quantization factor
163 float output_quantization_factor;
164 /// @brief Enable Relu activation.
165 bool with_activation;
166 /// @brief Relu activation slope.
167 float activation_negative_slope;
170 std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
171 std::vector<std::reference_wrapper<const primitive_id>> ret;
172 ret.push_back(weights);
177 if (!weights_quantization_factors.empty())
178 ret.push_back(weights_quantization_factors);
180 if (!output_calibration_factors.empty())
181 ret.push_back(output_calibration_factors);
186 void update_dto(dto& dto) const override {
187 dto.weights = weights.c_str();
188 dto.bias = bias.c_str();
189 dto.weights_quantization_factors = weights_quantization_factors.c_str();
190 dto.output_calibration_factors = output_calibration_factors.c_str();
191 dto.input_quantization_factor = input_quantization_factor;
192 dto.output_quantization_factor = output_quantization_factor;
193 dto.with_activation = with_activation;
194 dto.activation_negative_slope = activation_negative_slope;