Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / fully_connected_grad_weights.hpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "primitive.hpp"
20 #include <vector>
21
22 namespace cldnn {
23 /// @addtogroup cpp_api C++ API
24 /// @{
25 /// @addtogroup cpp_topology Network Topology
26 /// @{
27 /// @addtogroup cpp_primitives Primitives
28 /// @{
29
30 /// @brief Performs backward fully connected layer (inner product) for weights and biases.
31 struct fully_connected_grad_weights
32     : public primitive_base<fully_connected_grad_weights> {
33     CLDNN_DECLARE_PRIMITIVE(fully_connected_grad_weights)
34
35     /// @brief Constructs fully connected layer for weights and biases.
36     /// @param id This primitive id.
37     /// @param input Input gradient primitive id.
38     /// @param input Input primitive id.
39     /// @param weights Primitive id containing weights data.
40     /// @param bias Primitive id containing bias data. Provide empty string if using Relu without bias.
41     /// @param fc_grad Id of primitive which uses weights and biases updated in this primitive.
42     /// This is for correct order of calculating. Leave empty if primitive is last in backward pass.
43     fully_connected_grad_weights(const primitive_id& id,
44                                  const primitive_id& input_grad,
45                                  const primitive_id& input,
46                                  const primitive_id& weights,
47                                  const primitive_id& bias = "",
48                                  const primitive_id& fc_grad = "",
49                                  const padding& output_padding = padding())
50         : primitive_base(id, {input_grad, input}, output_padding),
51           weights(weights),
52           bias(bias),
53           fc_grad(fc_grad),
54           prev_weights_grad(""),
55           prev_bias_grad("") {}
56
57     /// @brief Constructs fully connected layer for weights and biases with momentum optimizer.
58     /// @param id This primitive id.
59     /// @param input Input gradient primitive id.
60     /// @param input Input primitive id.
61     /// @param weights Primitive id containing weights data.
62     /// @param bias Primitive id containing bias data. Provide empty string if using Relu without bias.
63     /// @param prev_weights_grad Id of primitive which contains weights gradient data calculated in previous iteration. Used in momentum optimizer.
64     /// @param prev_bias_grad Id of primitive which contains bias gradient data calculated in previous iteration. Used in momentum optimizer.
65     /// @param fc_grad Id of primitive which uses weights and biases updated in this primitive. This is for correct order of calculating.
66     fully_connected_grad_weights(const primitive_id& id,
67                                  const primitive_id& input_grad,
68                                  const primitive_id& input,
69                                  const primitive_id& weights,
70                                  const primitive_id& bias,
71                                  const primitive_id& prev_weights_grad,
72                                  const primitive_id& prev_bias_grad,
73                                  const primitive_id& fc_grad = "",
74                                  const padding& output_padding = padding())
75         : primitive_base(id, {input_grad, input}, output_padding),
76           weights(weights),
77           bias(bias),
78           fc_grad(fc_grad),
79           prev_weights_grad(prev_weights_grad),
80           prev_bias_grad(prev_bias_grad) {}
81
82     /// @brief Primitive id containing weights data.
83     primitive_id weights;
84     /// @brief Primitive id containing bias data.
85     primitive_id bias;
86     /// @brief Primitive id containing fully connected gradient data.
87     primitive_id fc_grad;
88     /// @brief Id of primitive containing weights gradient data calculated in previous iteration. It's memory size should be same as weights.
89     primitive_id prev_weights_grad;
90     /// @brief Id of primitive containing bias gradient data calculated in previous iteration. It's memory size should be same as biases.
91     primitive_id prev_bias_grad;
92
93 protected:
94     std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
95         std::vector<std::reference_wrapper<const primitive_id>> ret;
96         ret.reserve(1 + !bias.empty() + !fc_grad.empty() + !prev_weights_grad.empty() + !prev_bias_grad.empty());
97
98         ret.push_back(weights);
99         if (!bias.empty())
100             ret.push_back(bias);
101
102         if (!prev_weights_grad.empty())
103             ret.push_back(prev_weights_grad);
104         if (!prev_bias_grad.empty())
105             ret.push_back(prev_bias_grad);
106         if (!fc_grad.empty())
107             ret.push_back(fc_grad);
108
109         return ret;
110     }
111 };
112 /// @}
113 /// @}
114 /// @}
115 }  // namespace cldnn