2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "primitive.hpp"
23 /// @addtogroup cpp_api C++ API
25 /// @addtogroup cpp_topology Network Topology
27 /// @addtogroup cpp_primitives Primitives
30 /// @brief Performs backward fully connected layer (inner product) for weights and biases.
31 struct fully_connected_grad_weights
32 : public primitive_base<fully_connected_grad_weights> {
33 CLDNN_DECLARE_PRIMITIVE(fully_connected_grad_weights)
35 /// @brief Constructs fully connected layer for weights and biases.
36 /// @param id This primitive id.
37 /// @param input_grad Input gradient primitive id.
38 /// @param input Input primitive id.
39 /// @param weights Primitive id containing weights data.
40 /// @param bias Primitive id containing bias data. Provide an empty string if bias is not used.
41 /// @param fc_grad Id of primitive which uses weights and biases updated in this primitive.
42 /// This is for correct order of calculating. Leave empty if primitive is last in backward pass.
43 fully_connected_grad_weights(const primitive_id& id,
44 const primitive_id& input_grad,
45 const primitive_id& input,
46 const primitive_id& weights,
47 const primitive_id& bias = "",
48 const primitive_id& fc_grad = "",
49 const padding& output_padding = padding())
50 : primitive_base(id, {input_grad, input}, output_padding),
54 prev_weights_grad(""),
57 /// @brief Constructs fully connected layer for weights and biases with momentum optimizer.
58 /// @param id This primitive id.
59 /// @param input_grad Input gradient primitive id.
60 /// @param input Input primitive id.
61 /// @param weights Primitive id containing weights data.
62 /// @param bias Primitive id containing bias data. Provide an empty string if bias is not used.
63 /// @param prev_weights_grad Id of primitive which contains weights gradient data calculated in previous iteration. Used in momentum optimizer.
64 /// @param prev_bias_grad Id of primitive which contains bias gradient data calculated in previous iteration. Used in momentum optimizer.
65 /// @param fc_grad Id of primitive which uses weights and biases updated in this primitive. This is for correct order of calculating.
66 fully_connected_grad_weights(const primitive_id& id,
67 const primitive_id& input_grad,
68 const primitive_id& input,
69 const primitive_id& weights,
70 const primitive_id& bias,
71 const primitive_id& prev_weights_grad,
72 const primitive_id& prev_bias_grad,
73 const primitive_id& fc_grad = "",
74 const padding& output_padding = padding())
75 : primitive_base(id, {input_grad, input}, output_padding),
79 prev_weights_grad(prev_weights_grad),
80 prev_bias_grad(prev_bias_grad) {}
82 /// @brief Primitive id containing weights data.
84 /// @brief Primitive id containing bias data.
86 /// @brief Primitive id containing fully connected gradient data.
88 /// @brief Id of primitive containing weights gradient data calculated in previous iteration. Its memory size should be the same as that of weights.
89 primitive_id prev_weights_grad;
90 /// @brief Id of primitive containing bias gradient data calculated in previous iteration. Its memory size should be the same as that of biases.
91 primitive_id prev_bias_grad;
94 std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
95 std::vector<std::reference_wrapper<const primitive_id>> ret;
96 ret.reserve(1 + !bias.empty() + !fc_grad.empty() + !prev_weights_grad.empty() + !prev_bias_grad.empty());
98 ret.push_back(weights);
102 if (!prev_weights_grad.empty())
103 ret.push_back(prev_weights_grad);
104 if (!prev_bias_grad.empty())
105 ret.push_back(prev_bias_grad);
106 if (!fc_grad.empty())
107 ret.push_back(fc_grad);