2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "api/CPP/deconvolution.hpp"
20 #include "primitive_inst.h"
26 struct typed_program_node<deconvolution> : public typed_program_node_base<deconvolution>
28 using parent = typed_program_node_base<deconvolution>;
31 typed_program_node(std::shared_ptr<primitive> prim, program_impl& prog)
33 , split(this->get_primitive()->split())
34 , depthwise_sep_opt(false)
35 , groups(this->get_primitive()->groups)
37 support_padding(true);
41 void set_split(int32_t node_split) { split = node_split; }
42 int32_t get_split() const { return split; }
44 void set_depthwise_sep_opt(bool node_depthwise_sep_opt) { depthwise_sep_opt = node_depthwise_sep_opt; }
45 bool get_depthwise_sep_opt() const { return depthwise_sep_opt; }
47 void set_groups(uint32_t node_groups) { groups = node_groups; }
48 uint32_t get_groups() const { return groups; }
50 program_node& input() const { return get_dependency(0); }
52 program_node& weights(size_t idx = 0) const
54 if (static_cast<int32_t>(idx) >= get_split())
55 throw std::range_error("weights offset too big");
57 return get_dependency(1 + idx);
60 program_node& bias(size_t idx = 0) const
62 if (static_cast<int32_t>(idx) >= get_split())
63 throw std::range_error("bias offset too big");
65 return get_dependency(1 + this->get_split() + idx);
68 bool bias_term() const
70 if (get_primitive()->bias.size() != 0)
76 program_node& fused_sum(size_t idx = 0) const
78 if (static_cast<int32_t>(idx) > 0)
79 throw std::range_error("Only one input for fused sum is supported");
81 size_t d_idx = 1 + this->get_split() + idx;
82 d_idx += bias_term() ? this->get_split() : 0;
83 return get_dependency(d_idx);
86 bool has_fused_sum() const
88 size_t d_idx = 1 + this->get_split();
89 d_idx += bias_term() ? this->get_split() : 0;
90 return dependencies.size() == (d_idx + 1);
95 bool depthwise_sep_opt;
99 using deconvolution_node = typed_program_node<deconvolution>;
102 class typed_primitive_inst<deconvolution> : public typed_primitive_inst_base<deconvolution>
104 using parent = typed_primitive_inst_base<deconvolution>;
107 static layout calc_output_layout(deconvolution_node const& node);
108 static std::string to_string(deconvolution_node const& node);
111 typed_primitive_inst(network_impl& network, deconvolution_node const& node);
113 memory_impl& weights_memory(size_t index) const
115 if (node.get_groups() == 1) {
116 if (static_cast<int32_t>(index) >= node.get_split())
117 throw std::range_error("weights offset too big");
118 return dep_memory(1 + index);
120 else { // all weights are in one buffer
121 return dep_memory(1);
125 memory_impl& bias_memory(size_t index) const
127 if (node.get_groups() == 1) {
128 if (argument.bias.size() == 0 && static_cast<int32_t>(index) >= node.get_split())
129 throw std::range_error("no bias data");
130 if (static_cast<int32_t>(index) > node.get_split())
131 throw std::range_error("bias offset too big");
132 return dep_memory(1 + node.get_split() + index);
134 else { // all bias are in one buffer
135 return dep_memory(2);
139 bool bias_term() const
141 if (argument.bias.size() != 0)
148 using deconvolution_inst = typed_primitive_inst<deconvolution>;