2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "api_extension/CPP/fused_conv_bn_scale.hpp"
20 #include "primitive_inst.h"
// Specialization of typed_program_node for the fused_conv_bn_scale primitive.
// It derives from typed_program_node_base<fused_conv_bn_scale>.
28 struct typed_program_node<fused_conv_bn_scale> : public typed_program_node_base<fused_conv_bn_scale>
// Shorthand for the base class of this specialization.
30 using parent = typed_program_node_base<fused_conv_bn_scale>;
// Constructs the node from a primitive and the owning program.
// The `split` member is initialized from the primitive's split() value.
// NOTE(review): the first member-initializer (presumably the base-class
// initializer) and the constructor body are not visible in this excerpt.
33 typed_program_node(std::shared_ptr<primitive> prim, program_impl& prog)
35 , split(this->get_primitive()->split())
39 void set_split(int32_t node_split) { split = node_split; }
40 int32_t get_split() const { return split; }
42 program_node& input(size_t idx = 0) const
44 if (static_cast<int32_t>(idx) >= static_cast<int32_t>(desc->input.size()))
45 throw std::range_error("input index too big");
47 return get_dependency(idx);
50 program_node& weights(size_t idx = 0) const
52 if (static_cast<int32_t>(idx) >= this->get_split())
53 throw std::range_error("weights offset too big");
55 return get_dependency(desc->input.size() + idx);
58 program_node& bias(size_t idx = 0) const
60 if (static_cast<int32_t>(idx) >= this->get_split())
61 throw std::range_error("bias offset too big");
63 return get_dependency(desc->input.size() + this->get_split() + idx);
66 program_node& weights_quantization_factors(size_t idx = 0) const
68 if (static_cast<int32_t>(idx) >= this->get_split())
69 throw std::range_error("quantization factor offset too big");
71 return get_dependency(desc->input.size() + 2*this->get_split() + idx);
74 program_node& output_calibration_factors(size_t idx = 0) const
76 if (static_cast<int32_t>(idx) >= this->get_split())
77 throw std::range_error("calibration factor offset too big");
79 return get_dependency(desc->input.size() + 3 * this->get_split() + idx);
82 bool bias_term() const
84 return get_primitive()->bias.size() > 0;
87 bool scale_bias_term() const
89 return !get_primitive()->scale_bias.empty();
92 bool is_fused_in_training() const
94 return !get_primitive()->inv_variance.empty();
// Convenience alias for the fused_conv_bn_scale program-node specialization.
101 using fused_conv_bn_scale_node = typed_program_node<fused_conv_bn_scale>;
// Specialization of typed_primitive_inst for the fused_conv_bn_scale primitive.
// It derives from typed_primitive_inst_base<fused_conv_bn_scale>.
104 class typed_primitive_inst<fused_conv_bn_scale> : public typed_primitive_inst_base<fused_conv_bn_scale>
// Shorthand for the base class of this specialization.
106 using parent = typed_primitive_inst_base<fused_conv_bn_scale>;
109 static layout calc_output_layout(fused_conv_bn_scale_node const& node);
110 static std::string to_string(fused_conv_bn_scale_node const& node);
113 typed_primitive_inst(network_impl& network, fused_conv_bn_scale_node const& node);
115 memory_impl& weights_memory(size_t index) const
117 if (static_cast<int32_t>(index) >= node.get_split())
118 throw std::range_error("weights offset too big");
120 return dep_memory(inputs_memory_count() + index);
123 memory_impl& bias_memory(size_t index) const
125 if (static_cast<int32_t>(index) >= node.get_split())
126 throw std::range_error("bias offset too big");
128 return dep_memory(inputs_memory_count() + node.get_split() + index);
131 bool bias_term() const
133 return node.bias_term();
136 bool scale_bias_term() const
138 return node.scale_bias_term();
141 bool is_fused_in_training() const
143 return node.is_fused_in_training();
// Convenience alias for the fused_conv_bn_scale primitive-instance specialization.
147 using fused_conv_bn_scale_inst = typed_primitive_inst<fused_conv_bn_scale>;