2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "api_extension/CPP/fused_conv_eltwise.hpp"
20 #include "primitive_inst.h"
28 struct typed_program_node<fused_conv_eltwise> : public typed_program_node_base<fused_conv_eltwise>
30 using parent = typed_program_node_base<fused_conv_eltwise>;
33 typed_program_node(std::shared_ptr<primitive> prim, program_impl& prog)
35 , split(this->get_primitive()->split())
36 , depthwise_sep_opt(false)
38 , conv_input_qf(this->get_primitive()->conv.input_quantization_factor)
39 , conv_output_qf(this->get_primitive()->conv.output_quantization_factor)
43 void set_split(int32_t node_split) { split = node_split; }
44 int32_t get_split() const { return split; }
46 void set_depthwise_sep_opt(bool node_depthwise_sep_opt) { depthwise_sep_opt = node_depthwise_sep_opt; }
47 bool get_depthwise_sep_opt() const { return depthwise_sep_opt; }
49 void set_transposed(bool node_transposed) { transposed = node_transposed; }
50 bool get_transposed() const { return transposed; }
52 program_node& input(size_t idx = 0) const
54 if (static_cast<int32_t>(idx) >= static_cast<int32_t>(desc->input.size()))
55 throw std::range_error("input index too big");
57 return get_dependency(idx);
60 program_node& weights(size_t idx = 0) const
62 if (static_cast<int32_t>(idx) >= this->get_split())
63 throw std::range_error("weights offset too big");
65 return get_dependency(desc->input.size() + idx);
68 program_node& bias(size_t idx = 0) const
70 if (static_cast<int32_t>(idx) >= this->get_split())
71 throw std::range_error("bias offset too big");
73 return get_dependency(desc->input.size() + this->get_split() + idx);
76 program_node& weights_quantization_factors(size_t idx = 0) const
78 if (static_cast<int32_t>(idx) >= this->get_split())
79 throw std::range_error("quantization factor offset too big");
81 return get_dependency(desc->input.size() + 2 * this->get_split() + idx);
84 program_node& conv_output_calibration_factors(size_t idx = 0) const
86 if (static_cast<int32_t>(idx) >= this->get_split())
87 throw std::range_error("calibration factor offset too big");
89 return get_dependency(desc->input.size() + 3 * this->get_split() + idx);
92 program_node& eltw_output_calibration_factors() const
94 return get_dependency(desc->input.size() + 4 * this->get_split());
97 bool bias_term() const
99 return get_primitive()->conv.bias.size() > 0;
102 bool weights_quantization_term() const
104 return get_primitive()->conv.weights_quantization_factors.size() > 0;
107 bool conv_output_calibration_term() const
109 return get_primitive()->conv.output_calibration_factors.size() > 0;
112 bool eltw_output_calibration_term() const
114 return get_primitive()->eltw.output_calibration_factors.size() > 0;
117 float get_conv_input_qf() const { return conv_input_qf; }
118 float get_conv_output_qf() const { return conv_output_qf; }
119 float get_eltw_output_qf() const { return eltw_output_qf; }
123 bool depthwise_sep_opt;
126 float conv_output_qf;
127 float eltw_output_qf;
130 using fused_conv_eltwise_node = typed_program_node<fused_conv_eltwise>;
133 class typed_primitive_inst<fused_conv_eltwise> : public typed_primitive_inst_base<fused_conv_eltwise>
135 using parent = typed_primitive_inst_base<fused_conv_eltwise>;
138 static layout calc_output_layout(fused_conv_eltwise_node const& node);
139 static std::string to_string(fused_conv_eltwise_node const& node);
142 typed_primitive_inst(network_impl& network, fused_conv_eltwise_node const& node);
144 memory_impl& weights_memory(size_t index) const
146 if (static_cast<int32_t>(index) >= node.get_split())
147 throw std::range_error("weights offset too big");
149 return dep_memory(2 + index);
152 memory_impl& bias_memory(size_t index) const
154 if (static_cast<int32_t>(index) >= node.get_split())
155 throw std::range_error("bias offset too big");
157 return dep_memory(2 + node.get_split() + index);
160 memory_impl& weights_quantization_factors_memory(size_t index) const
162 if (static_cast<int32_t>(index) >= node.get_split())
163 throw std::range_error("quantization factors offset too big");
165 return dep_memory(2 + 2*node.get_split() + index);
168 memory_impl& output_calibration_factors_memory(size_t index) const
170 if (static_cast<int32_t>(index) >= node.get_split())
171 throw std::range_error("quantization factors offset too big");
173 return dep_memory(2 + 3 * node.get_split() + index);
176 memory_impl& eltw_output_calibration_factors_memory() const
178 return dep_memory(2 + 4 * node.get_split());
181 bool bias_term() const
183 return node.bias_term();
186 bool weights_quantization_factors_term() const
188 return node.weights_quantization_term();
191 bool conv_output_calibration_factors_term() const
193 return node.conv_output_calibration_term();
196 bool eltw_output_calibration_factors_term() const
198 return node.eltw_output_calibration_term();
202 using fused_conv_eltwise_inst = typed_primitive_inst<fused_conv_eltwise>;