Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / graph_optimizer / eltwise_shrinking.cpp
1 /*
2 // Copyright (c) 2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "pass_manager.h"
18 #include "eltwise_inst.h"
19
20 using namespace cldnn;
21
22 void eltwise_shrinking::run(program_impl& p)
23 {
24     std::vector<program_node*> convs_to_shrink;
25
26     for (auto& node : p.get_processing_order())
27     {
28         if (node->is_type<eltwise>())
29         {
30             // TODO: make fp16 work
31             if (node->get_output_layout().data_type != data_types::i8 && node->get_output_layout().data_type != data_types::f32)
32             {
33                 if (node->get_output_layout().data_type != data_types::f16 || node->get_output_layout().format != format::yxfb)
34                 {
35                     continue;
36                 }
37             }
38
39             const auto eltw = std::static_pointer_cast<const eltwise>(node->get_primitive());
40             // TODO: support cases which already have stride!
41             if (eltw->stride.empty())
42             {
43                 bool can_shrink = true;
44                 int32_t stride_x = 0;
45                 int32_t stride_y = 0;
46                 convs_to_shrink.clear();
47                 auto users = node->get_users();
48                 for (auto user : users)
49                 {
50                     // currently we can shrink only if users are convolutions
51                     if (!user->is_type<convolution>())
52                     {
53                         can_shrink = false;
54                         break;
55                     }
56
57                     if (user->get_output_layout().format == format::b_fs_yx_fsv4)
58                     {
59                         // Workaround for VIS-1079
60                         // Currently, we don't have "conv + eltwise" optimization for
61                         // IMAD and it blocks us to run the whole ResNet-50.i8 topology in IMAD.
62                         // As workaround, this optimization will be temporary switched off for
63                         // "format == b_fs_yx_fsv4"(IMAD specific data layout).
64                         // TODO: Please, remove this code, when VIS - 1079 will be done.
65                         can_shrink = false;
66                         break;
67                     }
68
69                     const auto conv = std::static_pointer_cast<const convolution>(user->get_primitive());
70                     if (conv->weights.size() != 1)
71                     {
72                         can_shrink = false;
73                         break;
74                     }
75
76                     auto weights_node_ptr = p.get_node_ptr(conv->weights[0]);
77                     auto filter_size = weights_node_ptr->get_output_layout().size;
78                     // make sure this is conv 1x1
79                     if (filter_size.spatial[0] != 1 || filter_size.spatial[1] != 1)
80                     {
81                         can_shrink = false;
82                         break;
83                     }
84
85                     // make sure convolution can accept shrinked input by modifying stride
86                     if (conv->stride.spatial[0] > 1 || conv->stride.spatial[1] > 1)
87                     {
88                         if (stride_x == 0)
89                             stride_x = conv->stride.spatial[0];
90                         if (stride_y == 0)
91                             stride_y = conv->stride.spatial[1];
92
93                         // make sure stride across all eltwise's convolution users is the same
94                         if (conv->stride.spatial[0] != stride_x || conv->stride.spatial[1] != stride_y)
95                         {
96                             can_shrink = false;
97                             break;
98                         }
99                         convs_to_shrink.push_back(user);
100                     }
101                     else
102                     {
103                         can_shrink = false;
104                         break;
105                     }
106                 }
107                 if (can_shrink)
108                 {
109                     // add stride for every eltwise's inputs to have shrinked output
110                     auto e = const_cast<eltwise*>(&(*eltw));
111                     for (size_t user = 0; user < node->get_users().size(); user++)
112                     {
113                         e->stride.push_back({ 0,0,stride_x,stride_y });
114                     }
115                     node->recalc_output_layout();
116
117                     // change stride on every convolution
118                     for (size_t i = 0; i < convs_to_shrink.size(); i++)
119                     {
120                         const auto conv = std::static_pointer_cast<const convolution>(convs_to_shrink[i]->get_primitive());
121                         auto c = const_cast<convolution*>(&(*conv));
122                         c->stride.spatial[0] = 1;
123                         c->stride.spatial[1] = 1;
124                         // TODO: remove forcing "false" with_output_size if not needed
125                         c->with_output_size = false;
126                         convs_to_shrink[i]->recalc_output_layout();
127                     }
128                 }
129             }
130         }
131     }
132 }