2 // Copyright (c) 2019 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "pass_manager.h"
18 #include "eltwise_inst.h"
20 using namespace cldnn;
22 void eltwise_shrinking::run(program_impl& p)
24 std::vector<program_node*> convs_to_shrink;
26 for (auto& node : p.get_processing_order())
28 if (node->is_type<eltwise>())
30 // TODO: make fp16 work
31 if (node->get_output_layout().data_type != data_types::i8 && node->get_output_layout().data_type != data_types::f32)
33 if (node->get_output_layout().data_type != data_types::f16 || node->get_output_layout().format != format::yxfb)
39 const auto eltw = std::static_pointer_cast<const eltwise>(node->get_primitive());
40 // TODO: support cases which already have stride!
41 if (eltw->stride.empty())
43 bool can_shrink = true;
46 convs_to_shrink.clear();
47 auto users = node->get_users();
48 for (auto user : users)
50 // currently we can shrink only if users are convolutions
51 if (!user->is_type<convolution>())
57 if (user->get_output_layout().format == format::b_fs_yx_fsv4)
59 // Workaround for VIS-1079
60 // Currently, we don't have "conv + eltwise" optimization for
61 // IMAD and it blocks us to run the whole ResNet-50.i8 topology in IMAD.
62 // As workaround, this optimization will be temporary switched off for
63 // "format == b_fs_yx_fsv4"(IMAD specific data layout).
64 // TODO: Please, remove this code, when VIS - 1079 will be done.
69 const auto conv = std::static_pointer_cast<const convolution>(user->get_primitive());
70 if (conv->weights.size() != 1)
76 auto weights_node_ptr = p.get_node_ptr(conv->weights[0]);
77 auto filter_size = weights_node_ptr->get_output_layout().size;
78 // make sure this is conv 1x1
79 if (filter_size.spatial[0] != 1 || filter_size.spatial[1] != 1)
85 // make sure convolution can accept shrinked input by modifying stride
86 if (conv->stride.spatial[0] > 1 || conv->stride.spatial[1] > 1)
89 stride_x = conv->stride.spatial[0];
91 stride_y = conv->stride.spatial[1];
93 // make sure stride across all eltwise's convolution users is the same
94 if (conv->stride.spatial[0] != stride_x || conv->stride.spatial[1] != stride_y)
99 convs_to_shrink.push_back(user);
109 // add stride for every eltwise's inputs to have shrinked output
110 auto e = const_cast<eltwise*>(&(*eltw));
111 for (size_t user = 0; user < node->get_users().size(); user++)
113 e->stride.push_back({ 0,0,stride_x,stride_y });
115 node->recalc_output_layout();
117 // change stride on every convolution
118 for (size_t i = 0; i < convs_to_shrink.size(); i++)
120 const auto conv = std::static_pointer_cast<const convolution>(convs_to_shrink[i]->get_primitive());
121 auto c = const_cast<convolution*>(&(*conv));
122 c->stride.spatial[0] = 1;
123 c->stride.spatial[1] = 1;
124 // TODO: remove forcing "false" with_output_size if not needed
125 c->with_output_size = false;
126 convs_to_shrink[i]->recalc_output_layout();