Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / include / batch_norm_inst.h
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "api/CPP/batch_norm.hpp"
20 #include "primitive_inst.h"
21 #include "mutable_data_inst.h"
22
23 namespace cldnn
24 {
25
26 template <>
27 struct typed_program_node<batch_norm> : public typed_program_node_base<batch_norm>
28 {
29     using parent = typed_program_node_base<batch_norm>;
30
31 public:
32     using parent::parent;
33
34     program_node& input() const { return get_dependency(0); }
35     program_node& mean() const { return get_dependency(1); }
36     program_node& variance() const { return get_dependency(2); }
37     program_node& scale() const 
38         { 
39                 if(get_dependencies().size() >= 5)
40                         return get_dependency(3); 
41                 else
42                         return get_dependency(1);
43         }
44     program_node& shift() const 
45         { 
46                 if (get_dependencies().size() >= 5)
47                         return get_dependency(4); 
48                 else
49                         return get_dependency(2);
50         }
51     program_node& inv_variance() const 
52         { 
53                 if (get_dependencies().size() == 2)
54                         return get_dependency(1);
55                 else if (get_dependencies().size() == 6)
56                         return get_dependency(5);
57                 else
58                         return get_dependency(3);
59         };
60     bool variance_term() const { return !get_primitive()->variance.empty(); }
61     bool use_global_stats() const { return !get_primitive()->mean.empty() && !get_primitive()->variance.empty(); };
62     bool use_scale_shift() const { return !get_primitive()->scale.empty() && !get_primitive()->shift.empty(); };
63     bool forwad_pass() const { return !get_primitive()->inv_variance.empty(); };
64     bool calc_mean_var() const { return (use_global_stats() && mean().is_type<mutable_data>() && variance().is_type<mutable_data>()); };
65
66 };
67
68 using batch_norm_node = typed_program_node<batch_norm>;
69
70 template <>
71 class typed_primitive_inst<batch_norm> : public typed_primitive_inst_base<batch_norm>
72 {
73     using parent = typed_primitive_inst_base<batch_norm>;
74
75 public:
76     static layout calc_output_layout(batch_norm_node const& node);
77     static std::string to_string(batch_norm_node const& node);
78
79 public:
80     typed_primitive_inst(network_impl& network, batch_norm_node const& node);
81
82     memory_impl& mean_memory() const { return dep_memory(1); }
83     memory_impl& variance_memory() const { return dep_memory(2); }
84     memory_impl& scale_memory() const 
85         {
86                 if (dependencies().size() >= 5)
87                         return dep_memory(3);
88                 else
89                         return dep_memory(1);
90         }
91     memory_impl& shift_memory() const 
92         {
93                 if (dependencies().size() >= 5)
94                         return dep_memory(4);
95                 else
96                         return dep_memory(2);
97         }
98     memory_impl& inv_variance_memory() const
99         {
100                 if (dependencies().size() == 2)
101                         return dep_memory(1);
102                 else if (dependencies().size() == 6)
103                         return dep_memory(5);
104                 else
105                         return dep_memory(3);
106         };
107     bool use_global_stats() const { return !argument.mean.empty() && !argument.variance.empty(); };
108     bool use_scale_shift() const { return !argument.scale.empty() && !argument.scale.empty(); };
109     bool forwad_pass() const { return !argument.inv_variance.empty(); };
110     bool calc_mean_var() const { return node.calc_mean_var(); };
111 };
112
113 using batch_norm_inst = typed_primitive_inst<batch_norm>;
114
115 }