Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / proposal.cpp
index c2cd1fb..7e98104 100644 (file)
@@ -1,5 +1,5 @@
 /*
-// Copyright (c) 2017-2018 Intel Corporation
+// Copyright (c) 2017-2019 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -39,10 +39,12 @@ primitive_type_id proposal_type_id()
 
 layout proposal_inst::calc_output_layout(proposal_node const& node)
 {
+    assert((bool)node.get_primitive()->output_data_type == false
+           && "Output data type forcing is not supported for proposal_node!");
     auto desc = node.get_primitive();
     layout input_layout = node.get_dependency(cls_scores_index).get_output_layout();
 
-    return layout(input_layout.data_type, format::bfyx, { desc->post_nms_topn, CLDNN_ROI_VECTOR_SIZE, 1, 1 });
+    return layout(input_layout.data_type, format::bfyx, { input_layout.size.batch[0] * desc->post_nms_topn, CLDNN_ROI_VECTOR_SIZE, 1, 1 });
 }
 
 static inline std::string stringify_vector(std::vector<float> v)
@@ -81,10 +83,12 @@ std::string proposal_inst::to_string(proposal_node const& node)
 
     std::stringstream primitive_description;
 
-    auto swap_xy = desc->swap_xy ? "true" : "false";
-    auto initial_clip = desc->initial_clip ? "true" : "false";
-    auto round_ratios = desc->round_ratios ? "true" : "false";
-    auto shift_anchors = desc->shift_anchors ? "true" : "false";
+    auto swap_xy         = desc->swap_xy         ? "true" : "false";
+    auto initial_clip    = desc->initial_clip    ? "true" : "false";
+    auto round_ratios    = desc->round_ratios    ? "true" : "false";
+    auto shift_anchors   = desc->shift_anchors   ? "true" : "false";
+    auto clip_before_nms = desc->clip_before_nms ? "true" : "false";
+    auto clip_after_nms  = desc->clip_after_nms  ? "true" : "false";
 
     json_composite proposal_info;
     proposal_info.add("cls score", stringify_port(node.cls_score()));
@@ -107,6 +111,8 @@ std::string proposal_inst::to_string(proposal_node const& node)
     params.add("initial clip", initial_clip);
     params.add("round ratios", round_ratios);
     params.add("shift anchors", shift_anchors);
+    params.add("clip_before_nms", clip_before_nms);
+    params.add("clip_after_nms", clip_after_nms);
     proposal_info.add("params", params);
 
     node_info->add("proposal info", proposal_info);