2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "api/CPP/batch_norm.hpp"
20 #include "primitive_inst.h"
21 #include "mutable_data_inst.h"
27 struct typed_program_node<batch_norm> : public typed_program_node_base<batch_norm>
29 using parent = typed_program_node_base<batch_norm>;
34 program_node& input() const { return get_dependency(0); }
35 program_node& mean() const { return get_dependency(1); }
36 program_node& variance() const { return get_dependency(2); }
37 program_node& scale() const
39 if(get_dependencies().size() >= 5)
40 return get_dependency(3);
42 return get_dependency(1);
44 program_node& shift() const
46 if (get_dependencies().size() >= 5)
47 return get_dependency(4);
49 return get_dependency(2);
51 program_node& inv_variance() const
53 if (get_dependencies().size() == 2)
54 return get_dependency(1);
55 else if (get_dependencies().size() == 6)
56 return get_dependency(5);
58 return get_dependency(3);
60 bool variance_term() const { return !get_primitive()->variance.empty(); }
61 bool use_global_stats() const { return !get_primitive()->mean.empty() && !get_primitive()->variance.empty(); };
62 bool use_scale_shift() const { return !get_primitive()->scale.empty() && !get_primitive()->shift.empty(); };
63 bool forwad_pass() const { return !get_primitive()->inv_variance.empty(); };
64 bool calc_mean_var() const { return (use_global_stats() && mean().is_type<mutable_data>() && variance().is_type<mutable_data>()); };
68 using batch_norm_node = typed_program_node<batch_norm>;
71 class typed_primitive_inst<batch_norm> : public typed_primitive_inst_base<batch_norm>
73 using parent = typed_primitive_inst_base<batch_norm>;
76 static layout calc_output_layout(batch_norm_node const& node);
77 static std::string to_string(batch_norm_node const& node);
80 typed_primitive_inst(network_impl& network, batch_norm_node const& node);
82 memory_impl& mean_memory() const { return dep_memory(1); }
83 memory_impl& variance_memory() const { return dep_memory(2); }
84 memory_impl& scale_memory() const
86 if (dependencies().size() >= 5)
91 memory_impl& shift_memory() const
93 if (dependencies().size() >= 5)
98 memory_impl& inv_variance_memory() const
100 if (dependencies().size() == 2)
101 return dep_memory(1);
102 else if (dependencies().size() == 6)
103 return dep_memory(5);
105 return dep_memory(3);
107 bool use_global_stats() const { return !argument.mean.empty() && !argument.variance.empty(); };
108 bool use_scale_shift() const { return !argument.scale.empty() && !argument.scale.empty(); };
109 bool forwad_pass() const { return !argument.inv_variance.empty(); };
110 bool calc_mean_var() const { return node.calc_mean_var(); };
113 using batch_norm_inst = typed_primitive_inst<batch_norm>;