2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "api/CPP/tensor.hpp"
21 #include "pass_manager.h"
23 #include "convolution_inst.h"
24 #include "eltwise_inst.h"
28 using namespace cldnn;
// Tries to fold an extra stride (`tensor`) into the convolution held by `node`.
//
// p      - program that owns the graph; used to look up the weights node and
//          passed along on the recursive call.
// node   - node whose primitive is cast to `convolution` below.
// tensor - in/out stride; run() passes the eltwise primitive's stride entry for
//          this dependency. On the successful non-1x1 path its spatial[0] and
//          spatial[1] are reset to 1 once the stride has been folded in.
//
// NOTE(review): this listing appears to have lost its brace-only lines and a few
// short statements (no `{`/`}` is visible, some `if` guards have no visible body,
// and `dep` is used without a visible declaration). The notes below mark each
// such gap; confirm the exact control flow against the upstream file.
30 void eltwise_remove_stride::conv_stride_extend(program_impl& p, program_node& node, cldnn::tensor& tensor)
32 // make sure we have only 1 user
// NOTE(review): the statement guarded by this check is not visible here;
// presumably an early return when the node has more than one user - confirm.
33 if (node.get_users().size() > 1)
// The primitive is obtained as pointer-to-const; the filter size is read from
// the output layout of the node behind the first weights entry.
36 const auto conv = std::static_pointer_cast<const convolution>(node.get_primitive());
37 auto weights_node_ptr = p.get_node_ptr(conv->weights[0]);
38 auto filter_size = weights_node_ptr->get_output_layout().size;
39 // make sure this is conv 1x1
40 if (filter_size.spatial[0] == 1 && filter_size.spatial[1] == 1)
// 1x1 filter: this convolution's own stride is not modified in the visible
// lines. The request is forwarded to a convolution dependency instead, and
// that dependency's output layout is recalculated afterwards.
42 auto deps = node.get_dependencies();
// NOTE(review): the declaration of `dep` is not visible; presumably a loop
// over `deps` - confirm, including whether the loop stops after the first match.
45 if (dep->is_type<convolution>())
47 conv_stride_extend(p, *dep, tensor);
48 dep->recalc_output_layout(true);
// const_cast is used to mutate the primitive in place: with_output_size is
// cleared and this node's output layout is recalculated.
52 auto c = const_cast<convolution*>(&(*conv));
53 c->with_output_size = false;
54 node.recalc_output_layout(true);
// NOTE(review): presumably the lines below are the non-1x1 (else) path - the
// keyword is not visible here; confirm.
// can_shrink_* is true when stride + (tensor - 1) does not exceed the filter
// size along that spatial axis.
58 bool can_shrink_x = (filter_size.spatial[0] - (conv->stride.spatial[0] + (tensor.spatial[0] - 1))) >= 0;
59 bool can_shrink_y = (filter_size.spatial[1] - (conv->stride.spatial[1] + (tensor.spatial[1] - 1))) >= 0;
60 if (can_shrink_x && can_shrink_y)
// Fold the extra stride into the convolution's stride, clear with_output_size,
// recalculate the layout, then reset the caller's stride to 1 on both axes.
62 auto c = const_cast<convolution*>(&(*conv));
63 c->stride.spatial[0] += tensor.spatial[0] - 1;
64 c->stride.spatial[1] += tensor.spatial[1] - 1;
65 c->with_output_size = false;
66 node.recalc_output_layout(true);
67 tensor.spatial[0] = 1;
68 tensor.spatial[1] = 1;
// Pass entry point: walks the program's processing order and, for eltwise nodes
// whose primitive has a non-empty stride list, calls conv_stride_extend() on
// each convolution dependency with the matching stride entry.
//
// NOTE(review): this listing appears to have lost its brace-only lines and a few
// short statements, and the end of the function is not visible here. The notes
// below mark each gap; confirm the exact control flow against the upstream file.
73 void eltwise_remove_stride::run(program_impl& p)
75 for (auto& node : p.get_processing_order())
77 if (node->is_type<eltwise>())
79 // TODO: make fp16 work
// Data-type filter: the inner check is reached only when the output type is
// neither i8 nor f32, and it fires unless the output is f16 in yxfb format.
// NOTE(review): the statement guarded by the inner check is not visible;
// presumably it skips this node - confirm.
80 if (node->get_output_layout().data_type != data_types::i8 && node->get_output_layout().data_type != data_types::f32)
82 if (node->get_output_layout().data_type != data_types::f16 || node->get_output_layout().format != format::yxfb)
88 const auto eltw = std::static_pointer_cast<const eltwise>(node->get_primitive());
89 if (!eltw->stride.empty())
// Index-based loop so that dependency i can be paired with stride[i].
91 auto deps = node->get_dependencies();
92 for (size_t i = 0; i < deps.size(); i++)
// NOTE(review): the declaration of `dep` is not visible; presumably deps[i] - confirm.
95 // TODO: add other primitives beside convolution here
96 if (dep->is_type<convolution>())
// The eltwise primitive is held as pointer-to-const; const_cast yields a
// mutable stride entry, which conv_stride_extend() takes by non-const
// reference and may reset to 1 on its spatial axes.
98 auto e = const_cast<eltwise*>(&(*eltw));
99 conv_stride_extend(p, *dep, e->stride[i]);