Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / include / fused_conv_bn_scale_inst.h
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "api_extension/CPP/fused_conv_bn_scale.hpp"
20 #include "primitive_inst.h"
21
22 #include <memory>
23
24 namespace cldnn
25 {
26
27 template <>
28 struct typed_program_node<fused_conv_bn_scale> : public typed_program_node_base<fused_conv_bn_scale>
29 {
30     using parent = typed_program_node_base<fused_conv_bn_scale>;
31
32 public:
33     typed_program_node(std::shared_ptr<primitive> prim, program_impl& prog)
34         : parent(prim, prog)
35         , split(this->get_primitive()->split())
36     {
37     }
38
39     void set_split(int32_t node_split) { split = node_split; }
40     int32_t get_split() const { return split; }
41
42     program_node& input(size_t idx = 0) const
43     {
44         if (static_cast<int32_t>(idx) >= static_cast<int32_t>(desc->input.size()))
45             throw std::range_error("input index too big");
46
47         return get_dependency(idx);
48     }
49
50     program_node& weights(size_t idx = 0) const
51     {
52         if (static_cast<int32_t>(idx) >= this->get_split())
53             throw std::range_error("weights offset too big");
54
55         return get_dependency(desc->input.size() + idx);
56     }
57
58     program_node& bias(size_t idx = 0) const
59     { 
60         if (static_cast<int32_t>(idx) >= this->get_split())
61             throw std::range_error("bias offset too big");
62
63         return get_dependency(desc->input.size() + this->get_split() + idx);
64     }
65
66     program_node& weights_quantization_factors(size_t idx = 0) const
67     {
68         if (static_cast<int32_t>(idx) >= this->get_split())
69             throw std::range_error("quantization factor offset too big");
70
71         return get_dependency(desc->input.size() + 2*this->get_split() + idx);
72     }
73
74     program_node& output_calibration_factors(size_t idx = 0) const
75     {
76         if (static_cast<int32_t>(idx) >= this->get_split())
77             throw std::range_error("calibration factor offset too big");
78
79         return get_dependency(desc->input.size() + 3 * this->get_split() + idx);
80     }
81
82     bool bias_term() const
83     {
84         return get_primitive()->bias.size() > 0;
85     }
86
87     bool scale_bias_term() const
88     {
89         return !get_primitive()->scale_bias.empty();
90     }
91
92     bool is_fused_in_training() const
93     {
94         return !get_primitive()->inv_variance.empty();
95     }
96
97 private:
98     int32_t split;
99 };
100
101 using fused_conv_bn_scale_node = typed_program_node<fused_conv_bn_scale>;
102
103 template <>
104 class typed_primitive_inst<fused_conv_bn_scale> : public typed_primitive_inst_base<fused_conv_bn_scale>
105 {
106     using parent = typed_primitive_inst_base<fused_conv_bn_scale>;
107
108 public:
109     static layout calc_output_layout(fused_conv_bn_scale_node const& node);
110     static std::string to_string(fused_conv_bn_scale_node const& node);
111
112 public:
113     typed_primitive_inst(network_impl& network, fused_conv_bn_scale_node const& node);
114
115     memory_impl& weights_memory(size_t index) const
116     {
117         if (static_cast<int32_t>(index) >= node.get_split())
118             throw std::range_error("weights offset too big");
119         
120         return dep_memory(inputs_memory_count() + index);
121     }
122
123     memory_impl& bias_memory(size_t index) const
124     { 
125         if (static_cast<int32_t>(index) >= node.get_split())
126             throw std::range_error("bias offset too big");
127
128         return dep_memory(inputs_memory_count() + node.get_split() + index);
129     }
130
131     bool bias_term() const
132     {
133         return node.bias_term();
134     }
135
136     bool scale_bias_term() const
137     {
138         return node.scale_bias_term();
139     }
140
141     bool is_fused_in_training() const
142     {
143         return node.is_fused_in_training();
144     }
145 };
146
147 using fused_conv_bn_scale_inst = typed_primitive_inst<fused_conv_bn_scale>;
148
149 }