Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / include / convolution_grad_weights_inst.h
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "api/CPP/convolution_grad_weights.hpp"
20 #include "primitive_inst.h"
21
22 namespace cldnn
23 {
24
25 template <>
26 struct typed_program_node<convolution_grad_weights> : public typed_program_node_base<convolution_grad_weights>
27 {
28     using parent = typed_program_node_base<convolution_grad_weights>;
29
30 public:
31     typed_program_node(std::shared_ptr<primitive> prim, program_impl& prog)
32         : parent(prim, prog)
33         , split(this->get_primitive()->split())
34         , depthwise_sep_opt(false)
35     {
36     }
37
38     
39     void set_split(int32_t node_split) { split = node_split; }
40     int32_t get_split() const { return split; }
41
42     void set_depthwise_sep_opt(bool node_depthwise_sep_opt) { depthwise_sep_opt = node_depthwise_sep_opt; }
43     bool get_depthwise_sep_opt() const { return depthwise_sep_opt; }
44
45     program_node& input(size_t idx = 0) const { return get_dependency(idx); }
46
47     program_node& weights(size_t idx = 0) const
48     {
49         if (static_cast<int32_t>(idx) >= get_split())
50             throw std::range_error("weights offset too big");
51
52         return get_dependency(2 + idx);
53     }
54
55     program_node& bias(size_t idx = 0) const
56     { 
57         if (static_cast<int32_t>(idx) >= get_split())
58             throw std::range_error("bias offset too big");
59
60         return get_dependency(2 + this->get_split() + idx);
61     }
62
63     program_node& prev_weights_grad(size_t idx = 0) const
64     {
65         if (static_cast<int32_t>(idx) >= get_split())
66             throw std::range_error("prev weights grad offset too big");
67         return get_dependency(2 + (bias_term() ? 2 : 1) * get_split() + idx);
68     }
69
70     program_node& prev_bias_grad(size_t idx = 0) const
71     {
72         if (static_cast<int32_t>(idx) >= get_split())
73             throw std::range_error("prev bias grad offset too big");
74         return get_dependency(2 + 3 * get_split() + idx);
75     }
76
77     bool use_momentum() const
78     {
79         if (get_primitive()->prev_weights_grad.size() != 0)
80             return true;
81         else
82             return false;
83     }
84
85     bool bias_term() const
86     {
87         if (get_primitive()->bias.size() != 0)
88             return true;
89         else
90             return false;
91     }
92
93     bool output_grad_w() const
94     {
95         return get_primitive()->output_grad_w;
96     }
97
98 private:
99     int32_t split;
100     bool depthwise_sep_opt;
101 };
102
103 using convolution_grad_weights_node = typed_program_node<convolution_grad_weights>;
104
105 template <>
106 class typed_primitive_inst<convolution_grad_weights> : public typed_primitive_inst_base<convolution_grad_weights>
107 {
108     using parent = typed_primitive_inst_base<convolution_grad_weights>;
109
110 public:
111     static layout calc_output_layout(convolution_grad_weights_node const& node);
112     static std::string to_string(convolution_grad_weights_node const& node);
113
114 public:
115     typed_primitive_inst(network_impl& network, convolution_grad_weights_node const& node);
116
117     memory_impl& weights_memory(size_t index) const
118     {
119         if (static_cast<int32_t>(index) >= node.get_split())
120             throw std::range_error("weights offset too big");
121
122         return dep_memory(2 + index);
123     }
124
125     memory_impl& bias_memory(size_t index) const
126     {
127         if (argument.bias.size() == 0 && static_cast<int32_t>(index) >= node.get_split())
128             throw std::range_error("no bias data");
129
130         if (static_cast<int32_t>(index) > node.get_split())
131             throw std::range_error("bias offset too big");
132
133         return dep_memory(2 + node.get_split() + index);
134     }
135
136     memory_impl& prev_weights_grad(size_t index) const
137     {
138         if(argument.prev_weights_grad.size() == 0 && static_cast<int32_t>(index) >= node.get_split())
139             throw std::range_error("no prev weights grad data");
140
141         if (static_cast<int32_t>(index) >= node.get_split())
142             throw std::range_error("prev weights grad offset too big");
143
144         return dep_memory(2 + (bias_term() ? 2 : 1) * node.get_split() + index);
145     }
146
147     memory_impl& prev_bias_grad(size_t index) const
148     {
149         if (argument.prev_bias_grad.size() == 0 && static_cast<int32_t>(index) >= node.get_split())
150             throw std::range_error("no prev bias grad data");
151
152         if (static_cast<int32_t>(index) >= node.get_split())
153             throw std::range_error("prev bias grad offset too big");
154
155         return dep_memory(2 + 3 * node.get_split() + index);
156     }
157
158     bool use_momentum() const
159     {
160         if (argument.prev_weights_grad.size() != 0)
161             return true;
162         else
163             return false;
164     }
165
166     bool bias_term() const
167     {
168         if (argument.bias.size() != 0)
169             return true;
170         else
171             return false;
172     }
173
174     bool output_grad_w() const
175     {
176         return argument.output_grad_w;
177     }
178 };
179
180 using convolution_grad_weights_inst = typed_primitive_inst<convolution_grad_weights>;
181
182 }