Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / graph_optimizer / eltwise_remove_stride.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18
19 #include "api/CPP/tensor.hpp"
20
21 #include "pass_manager.h"
22
23 #include "convolution_inst.h"
24 #include "eltwise_inst.h"
25
26 #include <memory>
27
28 using namespace cldnn;
29
30 void eltwise_remove_stride::conv_stride_extend(program_impl& p, program_node& node, cldnn::tensor& tensor)
31 {
32     // make sure we have only 1 user
33     if (node.get_users().size() > 1)
34         return;
35
36     const auto conv = std::static_pointer_cast<const convolution>(node.get_primitive());
37     auto weights_node_ptr = p.get_node_ptr(conv->weights[0]);
38     auto filter_size = weights_node_ptr->get_output_layout().size;
39     // make sure this is conv 1x1
40     if (filter_size.spatial[0] == 1 && filter_size.spatial[1] == 1)
41     {
42         auto deps = node.get_dependencies();
43         for (auto dep : deps)
44         {
45             if (dep->is_type<convolution>())
46             {
47                 conv_stride_extend(p, *dep, tensor);
48                 dep->recalc_output_layout(true);
49                 break;
50             }
51         }
52         auto c = const_cast<convolution*>(&(*conv));
53         c->with_output_size = false;
54         node.recalc_output_layout(true);
55     }
56     else
57     {
58         bool can_shrink_x = (filter_size.spatial[0] - (conv->stride.spatial[0] + (tensor.spatial[0] - 1))) >= 0;
59         bool can_shrink_y = (filter_size.spatial[1] - (conv->stride.spatial[1] + (tensor.spatial[1] - 1))) >= 0;
60         if (can_shrink_x && can_shrink_y)
61         {
62             auto c = const_cast<convolution*>(&(*conv));
63             c->stride.spatial[0] += tensor.spatial[0] - 1;
64             c->stride.spatial[1] += tensor.spatial[1] - 1;
65             c->with_output_size = false;
66             node.recalc_output_layout(true);
67             tensor.spatial[0] = 1;
68             tensor.spatial[1] = 1;
69         }
70     }
71 }
72
73 void eltwise_remove_stride::run(program_impl& p)
74 {
75     for (auto& node : p.get_processing_order())
76     {
77         if (node->is_type<eltwise>())
78         {
79             // TODO: make fp16 work
80             if (node->get_output_layout().data_type != data_types::i8 && node->get_output_layout().data_type != data_types::f32)
81             {
82                 if (node->get_output_layout().data_type != data_types::f16 || node->get_output_layout().format != format::yxfb)
83                 {
84                     continue;
85                 }
86             }
87
88             const auto eltw = std::static_pointer_cast<const eltwise>(node->get_primitive());
89             if (!eltw->stride.empty())
90             {
91                 auto deps = node->get_dependencies();
92                 for (size_t i = 0; i < deps.size(); i++)
93                 {
94                     auto dep = deps[i];
95                     // TODO: add other primitives beside convolution here
96                     if (dep->is_type<convolution>())
97                     {
98                         auto e = const_cast<eltwise*>(&(*eltw));
99                         conv_stride_extend(p, *dep, e->stride[i]);
100                     }
101                 }
102             }
103         }
104     }
105 }