2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "max_unpooling_inst.h"
18 #include "primitive_type_base.h"
19 #include "sliding_window_utils.h"
20 #include "error_handler.h"
21 #include "json_object.h"
25 primitive_type_id max_unpooling_type_id()
27 static primitive_type_base<max_unpooling> instance;
31 max_unpooling_node::typed_program_node(const std::shared_ptr<max_unpooling> prim, program_impl& prog)
34 can_share_buffer(false); // for max_unpooling initial zero values are significant
37 layout max_unpooling_inst::calc_output_layout(max_unpooling_node const& node)
39 assert((bool)node.get_primitive()->output_data_type == false
40 && "Output data type forcing is not supported for max_unpooling_node!");
41 auto desc = node.get_primitive();
43 auto input_layout = node.input().get_output_layout();
44 auto argmax_layout = node.argmax().get_output_layout();
46 CLDNN_ERROR_NOT_EQUAL(node.id(), "Argmax data type", static_cast<size_t>(argmax_layout.data_type), "expected to be fp32", static_cast<size_t>(data_types::f32), "Argmax data type is not fp32.");
48 if (desc->with_output_size)
50 tensor output_size(input_layout.size.batch[0], input_layout.size.feature[0],
51 desc->output_size.spatial[0], desc->output_size.spatial[1]);
52 return{ input_layout.data_type, input_layout.format, output_size };
55 auto input_offset = desc->input_offset;
56 auto stride = desc->stride;
57 auto window_size = desc->size;
59 CLDNN_ERROR_LESS_OR_EQUAL_THAN(node.id(), "stride spatial X", stride.spatial[0], "", 0, "Stride spatial X must be positive (>= 1)");
60 CLDNN_ERROR_LESS_OR_EQUAL_THAN(node.id(), "stride spatial Y", stride.spatial[1], "", 0, "Stride spatial Y must be positive (>= 1)");
61 CLDNN_ERROR_LESS_OR_EQUAL_THAN(node.id(), "window size spatial X", window_size.spatial[0], "", 0, "Size X (of pooling window) must be positive (>= 1)");
62 CLDNN_ERROR_LESS_OR_EQUAL_THAN(node.id(), "window size spatial Y", window_size.spatial[1], "", 0, "Size Y (of pooling window) must be positive (>= 1)");
63 CLDNN_ERROR_GREATER_THAN(node.id(), "Input offset spatial X", 2 * input_offset.spatial[0], "input layout size spatial X", input_layout.size.spatial[0], "Input offset is greater than input data range. There is no input data to process");
64 CLDNN_ERROR_GREATER_THAN(node.id(), "Input offset spatial Y", 2 * input_offset.spatial[1], "input layout size spatial Y", input_layout.size.spatial[1], "Input offset is greater than input data range. There is no input data to process");
65 CLDNN_ERROR_GREATER_THAN(node.id(), "Negate input offset spatial X", -input_offset.spatial[0], "input window size spatial X", window_size.spatial[0], "First pool is outside of image. please reduce input offset X");
66 CLDNN_ERROR_GREATER_THAN(node.id(), "Negate input offset spatial Y", -input_offset.spatial[1], "input window size spatial Y", window_size.spatial[1], "First pool is outside of image. please reduce input offset Y");
67 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input offset feature", input_offset.feature[0], "", 0, "Input offset in feature is not supported");
68 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input offset batch", input_offset.batch[0], "", 0, "Input offset in batch is not supported");
70 auto output_range = calc_sliding_window_needed_input_range(
71 input_layout.size, window_size, input_offset, stride, { 1, 1, 1, 1 }, true, 1);
73 tensor output_size(input_layout.size.batch[0], input_layout.size.feature[0],
74 output_range.spatial[0], output_range.spatial[1]);
75 return{ input_layout.data_type, input_layout.format, output_size };
78 std::string max_unpooling_inst::to_string(max_unpooling_node const& node)
80 auto desc = node.get_primitive();
81 auto node_info = node.desc_to_json();
82 auto& input = node.input();
83 auto& argmax = node.argmax();
85 std::stringstream primitive_description;
87 json_composite max_unmax_unpooling_info;
88 max_unmax_unpooling_info.add("input", input.id());
89 max_unmax_unpooling_info.add("argmax", argmax.id());
91 node_info->add("max unmax_unpooling info", max_unmax_unpooling_info);
92 node_info->dump(primitive_description);
94 return primitive_description.str();
97 max_unpooling_inst::typed_primitive_inst(network_impl& network, max_unpooling_node const& node)
98 :parent(network, node)