Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / roi_pooling.cpp
index 0d45548..cbaca7b 100644 (file)
@@ -29,44 +29,35 @@ primitive_type_id roi_pooling_type_id()
 
 layout roi_pooling_inst::calc_output_layout(roi_pooling_node const& node)
 {
+    assert((bool)node.get_primitive()->output_data_type == false
+           && "Output data type forcing is not supported for roi_pooling_node!");
     auto desc = node.get_primitive();
     layout data_layout = node.input().get_output_layout();
-    int fm = data_layout.size.feature[0];
-
     layout rois_layout = node.rois().get_output_layout();
     int num_rois = rois_layout.size.batch[0];
+    int out_fm = desc->position_sensitive ? desc->output_dim : data_layout.size.feature[0];
 
-    int gss = desc->group_sz * desc->group_sz;
-
-
-    CLDNN_ERROR_LESS_THAN(node.id(), "Group size", desc->group_sz, "value", 0, "");
-    if (gss && fm % gss != 0)
-    {
-        CLDNN_ERROR_MESSAGE(node.id(), "group_sz must be either 0 (For RoIPooling) or satisfy fm % (group_sz^2) == 0");
-    }
-    
-    if (gss)
-    {
-        fm /= gss;
-    }
-
-    return layout(data_layout.data_type, format::bfyx, { num_rois, fm, desc->pooled_width, desc->pooled_height });
+    return layout(data_layout.data_type, format::bfyx, { num_rois, out_fm, desc->pooled_width, desc->pooled_height });
 }
 
 std::string roi_pooling_inst::to_string(roi_pooling_node const& node)
 {
     auto desc      = node.get_primitive();
     auto mode      = desc->mode == pooling_mode::max ? "max" : desc->mode == pooling_mode::bilinear ? "bilinear" : "average";
+    auto is_ps     = desc->position_sensitive ? "true" : "false";
     auto node_info = node.desc_to_json();
 
     std::stringstream primitive_description;
 
     json_composite roi_info;
     roi_info.add("mode", mode);
+    roi_info.add("position sensitive", is_ps);
     roi_info.add("pooled_w", desc->pooled_width);
     roi_info.add("pooled_h", desc->pooled_height);
     roi_info.add("spatial_scale", desc->spatial_scale);
-    roi_info.add("group_sz", desc->group_sz);
+    roi_info.add("output_dim", desc->output_dim);
+    roi_info.add("spatial_bins_x", desc->spatial_bins_x);
+    roi_info.add("spatial_bins_y", desc->spatial_bins_y);
 
     node_info->add("roi info", roi_info);
     node_info->dump(primitive_description);