2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "api/CPP/convolution_grad_weights.hpp"
20 #include "primitive_inst.h"
26 struct typed_program_node<convolution_grad_weights> : public typed_program_node_base<convolution_grad_weights>
28 using parent = typed_program_node_base<convolution_grad_weights>;
// Creates the program node for a convolution_grad_weights primitive. The split count is
// copied from the primitive descriptor's split(), and depthwise_sep_opt starts out false
// (it can be changed later through set_depthwise_sep_opt()).
// NOTE(review): the initializer list below begins with ", split(...)", so a preceding
// initializer (presumably `: parent(prim, prog)`) and the constructor body are not
// visible in this listing — confirm against the full source.
31 typed_program_node(std::shared_ptr<primitive> prim, program_impl& prog)
33 , split(this->get_primitive()->split())
34 , depthwise_sep_opt(false)
39 void set_split(int32_t node_split) { split = node_split; }
40 int32_t get_split() const { return split; }
42 void set_depthwise_sep_opt(bool node_depthwise_sep_opt) { depthwise_sep_opt = node_depthwise_sep_opt; }
43 bool get_depthwise_sep_opt() const { return depthwise_sep_opt; }
45 program_node& input(size_t idx = 0) const { return get_dependency(idx); }
47 program_node& weights(size_t idx = 0) const
49 if (static_cast<int32_t>(idx) >= get_split())
50 throw std::range_error("weights offset too big");
52 return get_dependency(2 + idx);
55 program_node& bias(size_t idx = 0) const
57 if (static_cast<int32_t>(idx) >= get_split())
58 throw std::range_error("bias offset too big");
60 return get_dependency(2 + this->get_split() + idx);
63 program_node& prev_weights_grad(size_t idx = 0) const
65 if (static_cast<int32_t>(idx) >= get_split())
66 throw std::range_error("prev weights grad offset too big");
67 return get_dependency(2 + (bias_term() ? 2 : 1) * get_split() + idx);
70 program_node& prev_bias_grad(size_t idx = 0) const
72 if (static_cast<int32_t>(idx) >= get_split())
73 throw std::range_error("prev bias grad offset too big");
74 return get_dependency(2 + 3 * get_split() + idx);
// Decides whether momentum is in use, based on whether the primitive descriptor lists
// any prev_weights_grad inputs.
// NOTE(review): only the condition is visible in this listing; the branch statements
// (presumably `return true;` / `return false;`) are missing here — confirm against the
// full source.
77 bool use_momentum() const
79 if (get_primitive()->prev_weights_grad.size() != 0)
// Decides whether the primitive has bias inputs, based on whether the descriptor's bias
// list is non-empty. prev_weights_grad() uses this to size its offset.
// NOTE(review): only the condition is visible in this listing; the branch statements
// (presumably `return true;` / `return false;`) are missing here — confirm against the
// full source.
85 bool bias_term() const
87 if (get_primitive()->bias.size() != 0)
93 bool output_grad_w() const
95 return get_primitive()->output_grad_w;
100 bool depthwise_sep_opt;
// Short name for the convolution_grad_weights specialisation of typed_program_node.
103 using convolution_grad_weights_node = typed_program_node<convolution_grad_weights>;
106 class typed_primitive_inst<convolution_grad_weights> : public typed_primitive_inst_base<convolution_grad_weights>
108 using parent = typed_primitive_inst_base<convolution_grad_weights>;
// Declarations only; the definitions are not part of this listing.
// calc_output_layout: computes the output layout for the given node.
// to_string: produces a textual description of the given node.
// Constructor: builds the runtime instance of `node` inside `network`.
111 static layout calc_output_layout(convolution_grad_weights_node const& node);
112 static std::string to_string(convolution_grad_weights_node const& node);
115 typed_primitive_inst(network_impl& network, convolution_grad_weights_node const& node);
117 memory_impl& weights_memory(size_t index) const
119 if (static_cast<int32_t>(index) >= node.get_split())
120 throw std::range_error("weights offset too big");
122 return dep_memory(2 + index);
125 memory_impl& bias_memory(size_t index) const
127 if (argument.bias.size() == 0 && static_cast<int32_t>(index) >= node.get_split())
128 throw std::range_error("no bias data");
130 if (static_cast<int32_t>(index) > node.get_split())
131 throw std::range_error("bias offset too big");
133 return dep_memory(2 + node.get_split() + index);
136 memory_impl& prev_weights_grad(size_t index) const
138 if(argument.prev_weights_grad.size() == 0 && static_cast<int32_t>(index) >= node.get_split())
139 throw std::range_error("no prev weights grad data");
141 if (static_cast<int32_t>(index) >= node.get_split())
142 throw std::range_error("prev weights grad offset too big");
144 return dep_memory(2 + (bias_term() ? 2 : 1) * node.get_split() + index);
147 memory_impl& prev_bias_grad(size_t index) const
149 if (argument.prev_bias_grad.size() == 0 && static_cast<int32_t>(index) >= node.get_split())
150 throw std::range_error("no prev bias grad data");
152 if (static_cast<int32_t>(index) >= node.get_split())
153 throw std::range_error("prev bias grad offset too big");
155 return dep_memory(2 + 3 * node.get_split() + index);
// Decides whether momentum is in use, based on whether the primitive descriptor
// (`argument`) lists any prev_weights_grad inputs.
// NOTE(review): only the condition is visible in this listing; the branch statements
// (presumably `return true;` / `return false;`) are missing here — confirm against the
// full source.
158 bool use_momentum() const
160 if (argument.prev_weights_grad.size() != 0)
// Decides whether the primitive has bias inputs, based on whether the descriptor's
// (`argument`) bias list is non-empty. prev_weights_grad() uses this to size its offset.
// NOTE(review): only the condition is visible in this listing; the branch statements
// (presumably `return true;` / `return false;`) are missing here — confirm against the
// full source.
166 bool bias_term() const
168 if (argument.bias.size() != 0)
174 bool output_grad_w() const
176 return argument.output_grad_w;
// Short name for the convolution_grad_weights specialisation of typed_primitive_inst.
180 using convolution_grad_weights_inst = typed_primitive_inst<convolution_grad_weights>;