Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / batch_norm.cpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "batch_norm_inst.h"
18 #include "primitive_type_base.h"
19 #include "error_handler.h"
20 #include "json_object.h"
21 #include "mutable_data_inst.h"
22
23 namespace cldnn
24 {
25 primitive_type_id batch_norm_type_id()
26 {
27     static primitive_type_base<batch_norm> instance;
28     return &instance;
29 }
30
31 layout batch_norm_inst::calc_output_layout(batch_norm_node const& node)
32 {
33     assert((bool)node.get_primitive()->output_data_type == false
34            && "Output data type forcing is not supported for batch_norm_node!");
35     return node.input().get_non_padded_output_layout();
36 }
37
38 std::string batch_norm_inst::to_string(batch_norm_node const& node)
39 {
40     bool variance_term = node.variance_term();
41
42     std::stringstream primitive_description;
43     json_composite batch_norm_info;
44     if (node.use_global_stats())
45     {
46         batch_norm_info.add("mean_id", node.mean().id());
47         if (variance_term)
48         {
49             batch_norm_info.add("variance_id", node.variance().id());
50         }
51     }
52     if (node.use_scale_shift())
53     {
54         batch_norm_info.add("scale_id", node.scale().id());
55         batch_norm_info.add("shift_id", node.shift().id());
56     }
57     if (node.forwad_pass())
58     {
59         batch_norm_info.add("inv_var", node.inv_variance().id());
60     }
61     batch_norm_info.add("epsilon", node.get_primitive()->epsilon);
62
63     node.desc_to_json()->add("batch norm info", batch_norm_info);
64     node.desc_to_json()->dump(primitive_description);
65
66     return primitive_description.str();
67 }
68
69
70 batch_norm_inst::typed_primitive_inst(network_impl& network, batch_norm_node const& node)
71     :parent(network, node) 
72 {
73     if (use_global_stats())
74     {
75         auto mean_format = node.mean().get_output_layout().format;
76         auto variance_format = node.variance().get_output_layout().format;
77
78         CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Mean format", mean_format.value, "supported mean formats", format::yxfb, format::bfyx, format::byxf);
79         CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Variance format", variance_format.value, "supported variance formats", format::yxfb, format::bfyx, format::byxf);
80
81                 auto is_mean_mutable_data = node.mean().is_type<mutable_data>();
82                 auto is_var_mutable_data = node.variance().is_type<mutable_data>();
83
84                 CLDNN_ERROR_BOOL(node.id(), "mean and variance are not the same type", (is_mean_mutable_data != is_var_mutable_data), "");
85     }
86
87         if (use_scale_shift()) {
88                 auto scale_format = node.scale().get_output_layout().format;
89                 auto shift_format = node.shift().get_output_layout().format;
90
91                 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Scale format", scale_format.value, "supported scale formats", format::yxfb, format::bfyx, format::byxf);
92                 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Shift format", shift_format.value, "supported shift formats", format::yxfb, format::bfyx, format::byxf);
93         }
94
95         if (forwad_pass())
96         {
97                 auto is_inv_var_mutable_data = node.inv_variance().is_type<mutable_data>();
98                 CLDNN_ERROR_BOOL(node.id(), "inv_variance is not mutable_data type", !is_inv_var_mutable_data, "");
99         }
100 }
101 }