CLDNN_ERROR_NOT_EQUAL(arg.id(), "roi_pooling padding filling value", padding_filling_value, "padding mode", 0.0f, "Unknown padding mode in roi_pooling.");
CLDNN_ERROR_NOT_PROPER_FORMAT(arg.id(), "Input_layout.format", input_layout.format.value, "output_layout.format", output_layout.format);
- auto group_sz = primitive->group_sz;
- auto in_feat = input_layout.get_buffer_size().feature[0];
- auto out_feat = output_layout.get_buffer_size().feature[0];
-
- CLDNN_ERROR_LESS_THAN(arg.id(), "Group size", group_sz, "value", 0, "");
- if (group_sz) {
- CLDNN_ERROR_NOT_EQUAL(arg.id(), "input feture map", in_feat, "group_sz * group_sz * out_feat", group_sz * group_sz * out_feat, "");
- }
- CLDNN_ERROR_BOOL(arg.id(), "Batching", !hasSingleBatchOutput(arg.input()), "PS/ RoI Pooling doesn't support batching.");
-
auto roi_params = get_default_params<kernel_selector::roi_pooling_params>(arg);
auto roi_optional_params = get_default_optional_params<kernel_selector::roi_pooling_optional_params>(arg.get_program());
const auto& out = roi_params.output;
-
+
const auto roi_bfyx = convert_data_tensor(rois_layout);
const auto roi_bf = roi_bfyx.FlattenFeatureAndSpatials();
roi_params.inputs.push_back(roi_bf);
roi_params.output = { out.GetDims(), out.GetDType(), kernel_selector::data_layout::brfyx, out.GetViewOffset(), out.PhysicalSize(), out.GetPaddedVal() }; // TOOD: it's an hack - cldnn doesn't support roi pooling with batching
- roi_params.mode = cldnn_2_pool_type(primitive->mode);
- roi_params.pooledWidth = primitive->pooled_width;
- roi_params.pooledHeight = primitive->pooled_height;
- roi_params.spatialScale = primitive->spatial_scale;
- roi_params.groupSize = group_sz;
+ roi_params.mode = cldnn_2_pool_type(primitive->mode);
+ roi_params.position_sensitive = primitive->position_sensitive;
+ roi_params.pooledWidth = primitive->pooled_width;
+ roi_params.pooledHeight = primitive->pooled_height;
+ roi_params.spatialScale = primitive->spatial_scale;
+ roi_params.spatial_bins_x = primitive->spatial_bins_x;
+ roi_params.spatial_bins_y = primitive->spatial_bins_y;
auto& kernel_selector = kernel_selector::roi_pooling_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(roi_params, roi_optional_params);