2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "pooling_inst.h"
18 #include "primitive_type_base.h"
19 #include "sliding_window_utils.h"
20 #include "error_handler.h"
21 #include "json_object.h"
25 primitive_type_id pooling_type_id()
27 static primitive_type_base<pooling> instance;
31 layout pooling_inst::calc_output_layout(parent::typed_node const& node)
33 assert((bool)node.get_primitive()->output_data_type == false
34 && "Output data type forcing is not supported for pooling_node!");
35 auto desc = node.get_primitive();
37 auto input_layout = node.input().get_output_layout();
39 auto input_offset = desc->input_offset;
40 auto stride = desc->stride;
41 auto window_size = desc->size;
43 if (!desc->argmax.empty())
44 CLDNN_ERROR_NOT_EQUAL(node.id(), "Pooling mode", static_cast<size_t>(desc->mode), "should be max_with_argmax", static_cast<size_t>(pooling_mode::max_with_argmax), "Pooling mode should be set to max_with_argmax when argmax primitive is present.");
46 if (desc->mode == pooling_mode::max_with_argmax)
48 CLDNN_ERROR_NOT_EQUAL(node.id(), "Argmax primitive", static_cast<size_t>(desc->argmax.empty()), "should not be empty", static_cast<size_t>(0), "Argmax primitive not present despite max_with_argmax mode.");
50 auto argmax_layout = node.argmax().get_output_layout();
51 CLDNN_ERROR_NOT_EQUAL(node.id(), "Argmax data type", static_cast<size_t>(argmax_layout.data_type), "expected to be fp32", static_cast<size_t>(data_types::f32), "Argmax data type is not fp32.");
52 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Input_layout.format", input_layout.format.value, "argmax_layout.format", argmax_layout.format);
55 if (desc->global_pooling) {
56 window_size.spatial[0] = input_layout.size.spatial[0];
57 window_size.spatial[1] = input_layout.size.spatial[1];
60 // TODO: Consider moving general parameter verification to arguments constructor.
61 CLDNN_ERROR_LESS_OR_EQUAL_THAN(node.id(), "stride spatial X", stride.spatial[0], "", 0, "Stride spatial X must be positive (>= 1)");
62 CLDNN_ERROR_LESS_OR_EQUAL_THAN(node.id(), "stride spatial Y", stride.spatial[1], "", 0, "Stride spatial Y must be positive (>= 1)");
63 CLDNN_ERROR_LESS_OR_EQUAL_THAN(node.id(), "window size spatial X", window_size.spatial[0], "", 0, "Size X (of pooling window) must be positive (>= 1)");
64 CLDNN_ERROR_LESS_OR_EQUAL_THAN(node.id(), "window size spatial Y", window_size.spatial[1], "", 0, "Size Y (of pooling window) must be positive (>= 1)");
65 CLDNN_ERROR_GREATER_THAN(node.id(), "Input offset spatial X", 2 * input_offset.spatial[0], "input layout size spatial X", input_layout.size.spatial[0], "Input offset is greater than input data range. There is no input data to process");
66 CLDNN_ERROR_GREATER_THAN(node.id(), "Input offset spatial Y", 2 * input_offset.spatial[1], "input layout size spatial Y", input_layout.size.spatial[1], "Input offset is greater than input data range. There is no input data to process");
67 CLDNN_ERROR_GREATER_THAN(node.id(), "Negate input offset spatial X", -input_offset.spatial[0], "input window size spatial X", window_size.spatial[0], "First pool is outside of image. please reduce input offset X");
68 CLDNN_ERROR_GREATER_THAN(node.id(), "Negate input offset spatial Y", -input_offset.spatial[1], "input window size spatial Y", window_size.spatial[1], "First pool is outside of image. please reduce input offset Y");
69 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input offset feature", input_offset.feature[0], "", 0, "Input offset in feature is not supported");
70 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input offset batch", input_offset.batch[0], "", 0, "Input offset in batch is not supported");
72 if (desc->with_output_size)
74 CLDNN_ERROR_LESS_OR_EQUAL_THAN(node.id(), "User-defined size of output X", desc->output_size.spatial[0], "", 0, "User-defined size of output layout (spatial X) must be positive (>= 1)");
75 CLDNN_ERROR_LESS_OR_EQUAL_THAN(node.id(), "User-defined size of output Y", desc->output_size.spatial[1], "", 0, "User-defined size of output layout (spatial Y) must be positive (>= 1)");
77 tensor output_size(input_layout.size.batch[0], input_layout.size.feature[0],
78 desc->output_size.spatial[0], desc->output_size.spatial[1]);
79 return { input_layout.data_type, input_layout.format, output_size };
82 // TODO: Check compatibility of output size calculation (with caffe).
83 auto output_range = calc_sliding_window_output_range<swor_mode::exceed_once_data>(
84 input_layout.size, window_size, input_offset, stride, {1, 1, 1, 1}, true, 1);
86 tensor output_size(input_layout.size.batch[0], input_layout.size.feature[0],
87 output_range.spatial[0], output_range.spatial[1]);
88 return{ input_layout.data_type, input_layout.format, output_size };
91 std::string pooling_inst::to_string(pooling_node const& node)
93 auto desc = node.get_primitive();
94 auto strd = desc->stride;
95 auto mode = desc->mode == pooling_mode::max ? "max" : "average";
96 auto node_info = node.desc_to_json();
97 auto kernel_size = desc->size;
99 std::stringstream primitive_description;
101 json_composite pooling_info;
102 pooling_info.add("mode", mode);
103 pooling_info.add("stride", strd.to_string());
104 pooling_info.add("kernel size", kernel_size.to_string());
105 pooling_info.add("input offset", desc->input_offset.to_string());
106 if (desc->with_output_size)
108 json_composite ud_out_size_info;
109 ud_out_size_info.add("size", desc->output_size.to_string());
110 pooling_info.add("with_user_defined_output_size", ud_out_size_info);
113 node_info->add("pooling info", pooling_info);
114 node_info->dump(primitive_description);
116 return primitive_description.str();