2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "pass_manager.h"
20 #include "batch_norm_inst.h"
21 #include "reshape_inst.h"
23 using namespace cldnn;
25 //Some primitives require a specific shape for their inputs/parameters.
26 //We should check this and add a reshape where needed so the inputs comply.
28 //Example: the batch_norm primitive requires mean/variance/scale/shift to have shape {1, X, 1, 1}
29 void add_reshape_to_primitives::run(program_impl& p)
31 auto processing_order = p.get_processing_order();
33 for (auto& node : processing_order)
35 //if node is batch_norm and mean/var are given (i.e. use eltwise kernel to calculate batch_norm)
36 if (node->is_type<batch_norm>() &&
37 (!node->as<batch_norm>().calc_mean_var() && node->as<batch_norm>().use_global_stats()))
39 auto mean_layout = node->as<batch_norm>().mean().get_output_layout();
40 auto mean_size = mean_layout.size;
41 auto mean_x = mean_size.spatial[0];
42 auto mean_y = mean_size.spatial[1];
43 auto mean_b = mean_size.batch[0];
49 auto mean_name = node->as<batch_norm>().mean().id();
50 std::vector<int32_t> mean_sizes = mean_size.sizes();
51 int32_t mean_max_size = *std::max_element(std::begin(mean_sizes), std::end(mean_sizes));
53 auto r_prim = std::make_shared<reshape>("reshape_" + mean_name + "_" + node->id(), mean_name, tensor(1, mean_max_size, 1, 1));
54 auto& r_prim_node = p.get_or_create(r_prim);
56 p.add_intermediate(r_prim_node, *node, 1, true);
59 auto variance_size = node->as<batch_norm>().variance().get_output_layout().size;
60 auto variance_x = variance_size.spatial[0];
61 auto variance_y = variance_size.spatial[1];
62 auto variance_b = variance_size.batch[0];
68 auto variance_name = node->as<batch_norm>().variance().id();
69 std::vector<int32_t> variance_sizes = variance_size.sizes();
70 int32_t variance_max_size = *std::max_element(std::begin(variance_sizes), std::end(variance_sizes));
72 auto r_prim = std::make_shared<reshape>("reshape_" + variance_name + "_" + node->id(), variance_name, tensor(1, variance_max_size, 1, 1));
73 auto& r_prim_node = p.get_or_create(r_prim);
75 p.add_intermediate(r_prim_node, *node, 2, true);
78 if (node->as<batch_norm>().use_scale_shift())
80 auto scale_size = node->as<batch_norm>().scale().get_output_layout().size;
81 auto scale_x = scale_size.spatial[0];
82 auto scale_y = scale_size.spatial[1];
83 auto scale_b = scale_size.batch[0];
89 auto scale_name = node->as<batch_norm>().scale().id();
90 std::vector<int32_t> scale_sizes = scale_size.sizes();
91 int32_t scale_max_size = *std::max_element(std::begin(scale_sizes), std::end(scale_sizes));
93 auto r_prim = std::make_shared<reshape>("reshape_" + scale_name + "_" + node->id(), scale_name, tensor(1, scale_max_size, 1, 1));
94 auto& r_prim_node = p.get_or_create(r_prim);
96 p.add_intermediate(r_prim_node, *node, 3, true);
99 auto shift_size = node->as<batch_norm>().shift().get_output_layout().size;
100 auto shift_x = shift_size.spatial[0];
101 auto shift_y = shift_size.spatial[1];
102 auto shift_b = shift_size.batch[0];
108 auto shift_name = node->as<batch_norm>().shift().id();
109 std::vector<int32_t> shift_sizes = shift_size.sizes();
110 int32_t shift_max_size = *std::max_element(std::begin(shift_sizes), std::end(shift_sizes));
112 auto r_prim = std::make_shared<reshape>("reshape_" + shift_name + "_" + node->id(), shift_name, tensor(1, shift_max_size, 1, 1));
113 auto& r_prim_node = p.get_or_create(r_prim);
115 p.add_intermediate(r_prim_node, *node, 4, true);