2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "batch_norm_inst.h"
18 #include "primitive_type_base.h"
19 #include "error_handler.h"
20 #include "json_object.h"
21 #include "mutable_data_inst.h"
// Accessor for the primitive type id of the batch_norm primitive.
// The visible body declares a function-local static object of type
// primitive_type_base<batch_norm> named `instance`.
// NOTE(review): no return statement is visible in this extract (the embedded
// line numbering jumps from 27 to 31). Presumably the function hands back a
// handle to `instance` -- confirm against the complete file.
25 primitive_type_id batch_norm_type_id()
27 static primitive_type_base<batch_norm> instance;
// Computes the output layout for a batch_norm node.
//
// node: the batch_norm node whose output layout is requested.
// Returns the result of node.input().get_non_padded_output_layout().
//
// Forcing an output data type is not supported for this node: the assert below
// checks that the primitive's output_data_type converts to false. Being a plain
// assert, the check is compiled out when NDEBUG is defined.
31 layout batch_norm_inst::calc_output_layout(batch_norm_node const& node)
33 assert((bool)node.get_primitive()->output_data_type == false
34 && "Output data type forcing is not supported for batch_norm_node!");
35 return node.input().get_non_padded_output_layout();
// Builds a textual description of a batch_norm node.
//
// node: the batch_norm node to describe.
// Returns the contents of a std::stringstream into which node.desc_to_json()
// has been dumped.
//
// Batch-norm specific fields are collected in a local json_composite and added
// to the node description under the key "batch norm info".
// NOTE(review): block delimiters are not visible in this extract, so the exact
// extent of each `if` below cannot be determined from here -- confirm against
// the complete file.
38 std::string batch_norm_inst::to_string(batch_norm_node const& node)
// NOTE(review): variance_term is not read by any line visible here. The embedded
// numbering skips line 47, so a guard using it (presumably around the
// "variance_id" entry) may have been dropped from this extract -- confirm.
40 bool variance_term = node.variance_term();
42 std::stringstream primitive_description;
43 json_composite batch_norm_info;
// Ids of the mean / variance inputs, recorded when global statistics are used.
44 if (node.use_global_stats())
46 batch_norm_info.add("mean_id", node.mean().id());
49 batch_norm_info.add("variance_id", node.variance().id());
// Ids of the scale / shift inputs, recorded when scale-shift is used.
52 if (node.use_scale_shift())
54 batch_norm_info.add("scale_id", node.scale().id());
55 batch_norm_info.add("shift_id", node.shift().id());
// Id of the inverse-variance node, recorded when forwad_pass() is true.
57 if (node.forwad_pass())
59 batch_norm_info.add("inv_var", node.inv_variance().id());
// Epsilon value taken from the primitive descriptor.
61 batch_norm_info.add("epsilon", node.get_primitive()->epsilon);
// NOTE(review): desc_to_json() is called twice. If each call returns a freshly
// built object, the "batch norm info" entry added on the first line would not
// be present in the object dumped on the second -- verify what desc_to_json()
// returns, and hold the result in a local if that is the case.
63 node.desc_to_json()->add("batch norm info", batch_norm_info);
64 node.desc_to_json()->dump(primitive_description);
66 return primitive_description.str();
// Constructs a batch_norm primitive instance and validates its optional inputs.
//
// network: the network passed through to the parent constructor.
// node:    the batch_norm node passed through to the parent constructor and
//          inspected by the checks below.
//
// Checks are reported through the CLDNN_ERROR_* macros from error_handler.h.
// NOTE(review): the macros' failure behaviour is not visible here (presumably
// they raise an error tagged with node.id()) -- confirm. Block delimiters and
// the end of this constructor are also not visible in this extract, so the
// extent of each `if` must be confirmed against the complete file.
70 batch_norm_inst::typed_primitive_inst(network_impl& network, batch_norm_node const& node)
71 :parent(network, node)
// Global-statistics case: inspect the mean and variance inputs.
73 if (use_global_stats())
75 auto mean_format = node.mean().get_output_layout().format;
76 auto variance_format = node.variance().get_output_layout().format;
// Formats listed as supported for mean and variance: yxfb, bfyx, byxf.
78 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Mean format", mean_format.value, "supported mean formats", format::yxfb, format::bfyx, format::byxf);
79 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Variance format", variance_format.value, "supported variance formats", format::yxfb, format::bfyx, format::byxf);
// Mean and variance must agree on whether they are mutable_data nodes; the
// condition passed to CLDNN_ERROR_BOOL is true when exactly one of them is.
81 auto is_mean_mutable_data = node.mean().is_type<mutable_data>();
82 auto is_var_mutable_data = node.variance().is_type<mutable_data>();
84 CLDNN_ERROR_BOOL(node.id(), "mean and variance are not the same type", (is_mean_mutable_data != is_var_mutable_data), "");
// Scale-shift case: inspect the scale and shift inputs against the same list
// of formats (yxfb, bfyx, byxf).
87 if (use_scale_shift()) {
88 auto scale_format = node.scale().get_output_layout().format;
89 auto shift_format = node.shift().get_output_layout().format;
91 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Scale format", scale_format.value, "supported scale formats", format::yxfb, format::bfyx, format::byxf);
92 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Shift format", shift_format.value, "supported shift formats", format::yxfb, format::bfyx, format::byxf);
// The inverse-variance input is required to be a mutable_data node.
// NOTE(review): the embedded numbering jumps from 92 to 97, so a guard around
// this check (presumably on forwad_pass()) may be missing from this extract --
// confirm.
97 auto is_inv_var_mutable_data = node.inv_variance().is_type<mutable_data>();
98 CLDNN_ERROR_BOOL(node.id(), "inv_variance is not mutable_data type", !is_inv_var_mutable_data, "");