inference-engine/thirdparty/clDNN/src/graph_optimizer/add_reshape_to_primitives.cpp
/*
// Copyright (c) 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "pass_manager.h"
#include "batch_norm_inst.h"
#include "reshape_inst.h"

#include <algorithm>  // for std::max_element

using namespace cldnn;

// Some primitives require a specific shape for their inputs/parameters.
// Check for this and insert a reshape where needed so the input becomes compliant.
//
// Example: the batch_norm primitive requires mean/variance/scale/shift to have shape {1, X, 1, 1}.
void add_reshape_to_primitives::run(program_impl& p)
{
    auto processing_order = p.get_processing_order();

    for (auto& node : processing_order)
    {
        // If the node is a batch_norm and mean/variance are provided as inputs
        // (i.e. the eltwise kernel is used to compute batch_norm)
        if (node->is_type<batch_norm>() &&
            (!node->as<batch_norm>().calc_mean_var() && node->as<batch_norm>().use_global_stats()))
        {
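            // The statistics/parameter tensors are dependencies of the batch_norm node:
            // mean is input #1, variance #2, scale #3 and shift #4, matching the
            // indices passed to add_intermediate() below.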
            auto mean_layout = node->as<batch_norm>().mean().get_output_layout();
            auto mean_size = mean_layout.size;
            auto mean_x = mean_size.spatial[0];
            auto mean_y = mean_size.spatial[1];
            auto mean_b = mean_size.batch[0];

            if (mean_x != 1
                || mean_y != 1
                || mean_b != 1)
            {
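                // Flatten the mean to {1, C, 1, 1}, taking its largest dimension as the
                // channel count C. The same pattern is applied to variance, scale and
                // shift below.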
                auto mean_name = node->as<batch_norm>().mean().id();
                std::vector<int32_t> mean_sizes = mean_size.sizes();
                int32_t mean_max_size = *std::max_element(std::begin(mean_sizes), std::end(mean_sizes));

                auto r_prim = std::make_shared<reshape>("reshape_" + mean_name + "_" + node->id(), mean_name, tensor(1, mean_max_size, 1, 1));
                auto& r_prim_node = p.get_or_create(r_prim);

                p.add_intermediate(r_prim_node, *node, 1, true);
            }

            auto variance_size = node->as<batch_norm>().variance().get_output_layout().size;
            auto variance_x = variance_size.spatial[0];
            auto variance_y = variance_size.spatial[1];
            auto variance_b = variance_size.batch[0];

            if (variance_x != 1
                || variance_y != 1
                || variance_b != 1)
            {
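                // Same flattening for the variance input (dependency #2).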
                auto variance_name = node->as<batch_norm>().variance().id();
                std::vector<int32_t> variance_sizes = variance_size.sizes();
                int32_t variance_max_size = *std::max_element(std::begin(variance_sizes), std::end(variance_sizes));

                auto r_prim = std::make_shared<reshape>("reshape_" + variance_name + "_" + node->id(), variance_name, tensor(1, variance_max_size, 1, 1));
                auto& r_prim_node = p.get_or_create(r_prim);

                p.add_intermediate(r_prim_node, *node, 2, true);
            }

            if (node->as<batch_norm>().use_scale_shift())
            {
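                // Scale and shift are only present when use_scale_shift() is true;
                // they are dependencies #3 and #4 and get the same treatment.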
                auto scale_size = node->as<batch_norm>().scale().get_output_layout().size;
                auto scale_x = scale_size.spatial[0];
                auto scale_y = scale_size.spatial[1];
                auto scale_b = scale_size.batch[0];

                if (scale_x != 1
                    || scale_y != 1
                    || scale_b != 1)
                {
                    auto scale_name = node->as<batch_norm>().scale().id();
                    std::vector<int32_t> scale_sizes = scale_size.sizes();
                    int32_t scale_max_size = *std::max_element(std::begin(scale_sizes), std::end(scale_sizes));

                    auto r_prim = std::make_shared<reshape>("reshape_" + scale_name + "_" + node->id(), scale_name, tensor(1, scale_max_size, 1, 1));
                    auto& r_prim_node = p.get_or_create(r_prim);

                    p.add_intermediate(r_prim_node, *node, 3, true);
                }

                auto shift_size = node->as<batch_norm>().shift().get_output_layout().size;
                auto shift_x = shift_size.spatial[0];
                auto shift_y = shift_size.spatial[1];
                auto shift_b = shift_size.batch[0];

                if (shift_x != 1
                    || shift_y != 1
                    || shift_b != 1)
                {
                    auto shift_name = node->as<batch_norm>().shift().id();
                    std::vector<int32_t> shift_sizes = shift_size.sizes();
                    int32_t shift_max_size = *std::max_element(std::begin(shift_sizes), std::end(shift_sizes));

                    auto r_prim = std::make_shared<reshape>("reshape_" + shift_name + "_" + node->id(), shift_name, tensor(1, shift_max_size, 1, 1));
                    auto& r_prim_node = p.get_or_create(r_prim);

                    p.add_intermediate(r_prim_node, *node, 4, true);
                }
            }
        }
    }
}