From d27733402834533862852ed13ed696f6bc70c83e Mon Sep 17 00:00:00 2001 From: Vladimir Gavrilov Date: Wed, 14 Oct 2020 16:47:43 +0300 Subject: [PATCH] nGraph implementation of NMS-5 (without `evaluate()`) (#2651) * Written nGraph NMS-5 without evaluate(). * Used NGRAPH_RTTI_DECLARATION. --- .../include/legacy/ngraph_ops/nms_ie.hpp | 28 ++ .../src/legacy_api/src/ngraph_ops/nms_ie.cpp | 71 +++++ .../core/include/ngraph/op/non_max_suppression.hpp | 171 ++++++++++- ngraph/core/include/ngraph/opsets/opset5_tbl.hpp | 2 +- ngraph/core/src/op/non_max_suppression.cpp | 326 +++++++++++++++++++++ ngraph/test/type_prop/non_max_suppression.cpp | 221 ++++++++++++++ 6 files changed, 817 insertions(+), 2 deletions(-) diff --git a/inference-engine/src/legacy_api/include/legacy/ngraph_ops/nms_ie.hpp b/inference-engine/src/legacy_api/include/legacy/ngraph_ops/nms_ie.hpp index 1aff882..49e689e 100644 --- a/inference-engine/src/legacy_api/include/legacy/ngraph_ops/nms_ie.hpp +++ b/inference-engine/src/legacy_api/include/legacy/ngraph_ops/nms_ie.hpp @@ -58,5 +58,33 @@ public: std::shared_ptr clone_with_new_inputs(const OutputVector & new_args) const override; }; +class INFERENCE_ENGINE_API_CLASS(NonMaxSuppressionIE3) : public Op { +public: + NGRAPH_RTTI_DECLARATION; + + NonMaxSuppressionIE3(const Output& boxes, + const Output& scores, + const Output& max_output_boxes_per_class, + const Output& iou_threshold, + const Output& score_threshold, + const Output& soft_nms_sigma, + int center_point_box, + bool sort_result_descending, + const ngraph::element::Type& output_type = ngraph::element::i64); + + void validate_and_infer_types() override; + + bool visit_attributes(AttributeVisitor& visitor) override; + + std::shared_ptr clone_with_new_inputs(const OutputVector & new_args) const override; + + int m_center_point_box; + bool m_sort_result_descending = true; + element::Type m_output_type; + +private: + int64_t max_boxes_output_from_input() const; +}; + } // namespace op } // namespace ngraph diff --git a/inference-engine/src/legacy_api/src/ngraph_ops/nms_ie.cpp b/inference-engine/src/legacy_api/src/ngraph_ops/nms_ie.cpp index 2aba26e..4e83869 100644 --- a/inference-engine/src/legacy_api/src/ngraph_ops/nms_ie.cpp +++ b/inference-engine/src/legacy_api/src/ngraph_ops/nms_ie.cpp @@ -101,3 +101,74 @@ void op::NonMaxSuppressionIE2::validate_and_infer_types() { m_output_type); set_output_type(0, nms->output(0).get_element_type(), nms->output(0).get_partial_shape()); } + +NGRAPH_RTTI_DEFINITION(op::NonMaxSuppressionIE3, "NonMaxSuppressionIE", 3); + +op::NonMaxSuppressionIE3::NonMaxSuppressionIE3(const Output& boxes, + const Output& scores, + const Output& max_output_boxes_per_class, + const Output& iou_threshold, + const Output& score_threshold, + const Output& soft_nms_sigma, + int center_point_box, + bool sort_result_descending, + const ngraph::element::Type& output_type) + : Op({boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, soft_nms_sigma}), + m_center_point_box(center_point_box), m_sort_result_descending(sort_result_descending), m_output_type(output_type) { + constructor_validate_and_infer_types(); +} + +std::shared_ptr op::NonMaxSuppressionIE3::clone_with_new_inputs(const ngraph::OutputVector &new_args) const { + check_new_args_count(this, new_args); + return make_shared(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), + new_args.at(4), new_args.at(5), m_center_point_box, m_sort_result_descending, + m_output_type); +} + +bool op::NonMaxSuppressionIE3::visit_attributes(AttributeVisitor& visitor) { + visitor.on_attribute("center_point_box", m_center_point_box); + visitor.on_attribute("sort_result_descending", m_sort_result_descending); + visitor.on_attribute("output_type", m_output_type); + return true; +} + +static constexpr size_t boxes_port = 0; +static constexpr size_t scores_port = 1; +static constexpr size_t max_output_boxes_per_class_port = 2; + +int64_t op::NonMaxSuppressionIE3::max_boxes_output_from_input() const { + int64_t max_output_boxes{0}; + + const auto max_output_boxes_input = + as_type_ptr(input_value(2).get_node_shared_ptr()); + max_output_boxes = max_output_boxes_input->cast_vector().at(0); + + return max_output_boxes; +} + +void op::NonMaxSuppressionIE3::validate_and_infer_types() { + const auto boxes_ps = get_input_partial_shape(boxes_port); + const auto scores_ps = get_input_partial_shape(scores_port); + + // NonMaxSuppression produces triplets + // that have the following format: [batch_index, class_index, box_index] + PartialShape out_shape = {Dimension::dynamic(), 3}; + + if (boxes_ps.rank().is_static() && scores_ps.rank().is_static()) { + const auto num_boxes_boxes = boxes_ps[1]; + const auto max_output_boxes_per_class_node = input_value(max_output_boxes_per_class_port).get_node_shared_ptr(); + if (num_boxes_boxes.is_static() && scores_ps[0].is_static() && scores_ps[1].is_static() && + op::is_constant(max_output_boxes_per_class_node)) { + const auto num_boxes = num_boxes_boxes.get_length(); + const auto num_classes = scores_ps[1].get_length(); + const auto max_output_boxes_per_class = max_boxes_output_from_input(); + + out_shape[0] = std::min(num_boxes, max_output_boxes_per_class) * num_classes * + scores_ps[0].get_length(); + } + } + + set_output_type(0, m_output_type, out_shape); + set_output_type(1, element::f32, out_shape); + set_output_type(2, m_output_type, Shape{1}); +} diff --git a/ngraph/core/include/ngraph/op/non_max_suppression.hpp b/ngraph/core/include/ngraph/op/non_max_suppression.hpp index 3bc8c15..c4cc697 100644 --- a/ngraph/core/include/ngraph/op/non_max_suppression.hpp +++ b/ngraph/core/include/ngraph/op/non_max_suppression.hpp @@ -235,6 +235,156 @@ namespace ngraph clone_with_new_inputs(const OutputVector& new_args) const override; }; } // namespace v4 + + namespace v5 + { + /// \brief NonMaxSuppression operation + /// + class NGRAPH_API NonMaxSuppression : public Op + { + public: + NGRAPH_RTTI_DECLARATION; + enum class BoxEncodingType + { + CORNER, + CENTER + }; + + NonMaxSuppression() = default; + + /// \brief Constructs a NonMaxSuppression operation with default values in the last + /// 4 inputs. + /// + /// \param boxes Node producing the box coordinates + /// \param scores Node producing the box scores + /// \param box_encoding Specifies the format of boxes data encoding + /// \param sort_result_descending Specifies whether it is necessary to sort selected + /// boxes across batches + /// \param output_type Specifies the output tensor type + NonMaxSuppression(const Output& boxes, + const Output& scores, + const BoxEncodingType box_encoding = BoxEncodingType::CORNER, + const bool sort_result_descending = true, + const ngraph::element::Type& output_type = ngraph::element::i64); + + /// \brief Constructs a NonMaxSuppression operation with default values in the last. + /// 3 inputs. + /// + /// \param boxes Node producing the box coordinates + /// \param scores Node producing the box scores + /// \param max_output_boxes_per_class Node producing maximum number of boxes to be + /// selected per class + /// \param box_encoding Specifies the format of boxes data encoding + /// \param sort_result_descending Specifies whether it is necessary to sort selected + /// boxes across batches + /// \param output_type Specifies the output tensor type + NonMaxSuppression(const Output& boxes, + const Output& scores, + const Output& max_output_boxes_per_class, + const BoxEncodingType box_encoding = BoxEncodingType::CORNER, + const bool sort_result_descending = true, + const ngraph::element::Type& output_type = ngraph::element::i64); + + /// \brief Constructs a NonMaxSuppression operation with default values in the last. + /// 2 inputs. + /// + /// \param boxes Node producing the box coordinates + /// \param scores Node producing the box scores + /// \param max_output_boxes_per_class Node producing maximum number of boxes to be + /// selected per class + /// \param iou_threshold Node producing intersection over union threshold + /// \param box_encoding Specifies the format of boxes data encoding + /// \param sort_result_descending Specifies whether it is necessary to sort selected + /// boxes across batches + /// \param output_type Specifies the output tensor type + NonMaxSuppression(const Output& boxes, + const Output& scores, + const Output& max_output_boxes_per_class, + const Output& iou_threshold, + const BoxEncodingType box_encoding = BoxEncodingType::CORNER, + const bool sort_result_descending = true, + const ngraph::element::Type& output_type = ngraph::element::i64); + + /// \brief Constructs a NonMaxSuppression operation with default value in the last. + /// input. + /// + /// \param boxes Node producing the box coordinates + /// \param scores Node producing the box scores + /// \param max_output_boxes_per_class Node producing maximum number of boxes to be + /// selected per class + /// \param iou_threshold Node producing intersection over union threshold + /// \param score_threshold Node producing minimum score threshold + /// \param box_encoding Specifies the format of boxes data encoding + /// \param sort_result_descending Specifies whether it is necessary to sort selected + /// boxes across batches + /// \param output_type Specifies the output tensor type + NonMaxSuppression(const Output& boxes, + const Output& scores, + const Output& max_output_boxes_per_class, + const Output& iou_threshold, + const Output& score_threshold, + const BoxEncodingType box_encoding = BoxEncodingType::CORNER, + const bool sort_result_descending = true, + const ngraph::element::Type& output_type = ngraph::element::i64); + + /// \brief Constructs a NonMaxSuppression operation. + /// + /// \param boxes Node producing the box coordinates + /// \param scores Node producing the box scores + /// \param max_output_boxes_per_class Node producing maximum number of boxes to be + /// selected per class + /// \param iou_threshold Node producing intersection over union threshold + /// \param score_threshold Node producing minimum score threshold + /// \param soft_nms_sigma Node specifying the sigma parameter for Soft-NMS + /// \param box_encoding Specifies the format of boxes data encoding + /// \param sort_result_descending Specifies whether it is necessary to sort selected + /// boxes across batches + /// \param output_type Specifies the output tensor type + NonMaxSuppression(const Output& boxes, + const Output& scores, + const Output& max_output_boxes_per_class, + const Output& iou_threshold, + const Output& score_threshold, + const Output& soft_nms_sigma, + const BoxEncodingType box_encoding = BoxEncodingType::CORNER, + const bool sort_result_descending = true, + const ngraph::element::Type& output_type = ngraph::element::i64); + + bool visit_attributes(AttributeVisitor& visitor) override; + void validate_and_infer_types() override; + + std::shared_ptr + clone_with_new_inputs(const OutputVector& new_args) const override; + + BoxEncodingType get_box_encoding() const { return m_box_encoding; } + void set_box_encoding(const BoxEncodingType box_encoding) + { + m_box_encoding = box_encoding; + } + bool get_sort_result_descending() const { return m_sort_result_descending; } + void set_sort_result_descending(const bool sort_result_descending) + { + m_sort_result_descending = sort_result_descending; + } + + element::Type get_output_type() const { return m_output_type; } + void set_output_type(const element::Type& output_type) + { + m_output_type = output_type; + } + using Node::set_output_type; + + protected: + BoxEncodingType m_box_encoding = BoxEncodingType::CORNER; + bool m_sort_result_descending = true; + ngraph::element::Type m_output_type = ngraph::element::i64; + void validate(); + int64_t max_boxes_output_from_input() const; + float iou_threshold_from_input() const; + float score_threshold_from_input() const; + float soft_nms_sigma_from_input() const; + }; + } // namespace v5 } // namespace op NGRAPH_API @@ -274,4 +424,23 @@ namespace ngraph "AttributeAdapter", 1}; const DiscreteTypeInfo& get_type_info() const override { return type_info; } }; -} // namespace ngraph + + NGRAPH_API + std::ostream& operator<<(std::ostream& s, + const op::v5::NonMaxSuppression::BoxEncodingType& type); + + template <> + class NGRAPH_API AttributeAdapter + : public EnumAttributeAdapterBase + { + public: + AttributeAdapter(op::v5::NonMaxSuppression::BoxEncodingType& value) + : EnumAttributeAdapterBase(value) + { + } + + static constexpr DiscreteTypeInfo type_info{ + "AttributeAdapter", 1}; + const DiscreteTypeInfo& get_type_info() const override { return type_info; } + }; +} // namespace ngraph \ No newline at end of file diff --git a/ngraph/core/include/ngraph/opsets/opset5_tbl.hpp b/ngraph/core/include/ngraph/opsets/opset5_tbl.hpp index e2102bd..bddd7c8 100644 --- a/ngraph/core/include/ngraph/opsets/opset5_tbl.hpp +++ b/ngraph/core/include/ngraph/opsets/opset5_tbl.hpp @@ -157,7 +157,6 @@ NGRAPH_OP(CTCLoss, ngraph::op::v4) NGRAPH_OP(HSwish, ngraph::op::v4) NGRAPH_OP(Interpolate, ngraph::op::v4) NGRAPH_OP(Mish, ngraph::op::v4) -NGRAPH_OP(NonMaxSuppression, ngraph::op::v4) NGRAPH_OP(ReduceL1, ngraph::op::v4) NGRAPH_OP(ReduceL2, ngraph::op::v4) NGRAPH_OP(SoftPlus, ngraph::op::v4) @@ -167,6 +166,7 @@ NGRAPH_OP(Swish, ngraph::op::v4) NGRAPH_OP(GatherND, ngraph::op::v5) NGRAPH_OP(LogSoftmax, ngraph::op::v5) NGRAPH_OP(LSTMSequence, ngraph::op::v5) +NGRAPH_OP(NonMaxSuppression, ngraph::op::v5) NGRAPH_OP(GRUSequence, ngraph::op::v5) NGRAPH_OP(RNNSequence, ngraph::op::v5) NGRAPH_OP(Round, ngraph::op::v5) diff --git a/ngraph/core/src/op/non_max_suppression.cpp b/ngraph/core/src/op/non_max_suppression.cpp index f07a5e3..d631545 100644 --- a/ngraph/core/src/op/non_max_suppression.cpp +++ b/ngraph/core/src/op/non_max_suppression.cpp @@ -15,6 +15,7 @@ //***************************************************************************** #include "ngraph/op/non_max_suppression.hpp" +#include #include "ngraph/attribute_visitor.hpp" #include "ngraph/op/constant.hpp" #include "ngraph/op/util/op_types.hpp" @@ -530,3 +531,328 @@ void op::v4::NonMaxSuppression::validate_and_infer_types() } set_output_type(0, m_output_type, out_shape); } + +// ------------------------------ V5 ------------------------------ + +NGRAPH_RTTI_DEFINITION(op::v5::NonMaxSuppression, "NonMaxSuppression", 5); + +op::v5::NonMaxSuppression::NonMaxSuppression( + const Output& boxes, + const Output& scores, + const op::v5::NonMaxSuppression::BoxEncodingType box_encoding, + const bool sort_result_descending, + const element::Type& output_type) + : Op({boxes, + scores, + op::Constant::create(element::i64, Shape{}, {0}), + op::Constant::create(element::f32, Shape{}, {.0f}), + op::Constant::create(element::f32, Shape{}, {.0f}), + op::Constant::create(element::f32, Shape{}, {.0f})}) + , m_box_encoding{box_encoding} + , m_sort_result_descending{sort_result_descending} + , m_output_type{output_type} +{ + constructor_validate_and_infer_types(); +} + +op::v5::NonMaxSuppression::NonMaxSuppression( + const Output& boxes, + const Output& scores, + const Output& max_output_boxes_per_class, + const op::v5::NonMaxSuppression::BoxEncodingType box_encoding, + const bool sort_result_descending, + const element::Type& output_type) + : Op({boxes, + scores, + max_output_boxes_per_class, + op::Constant::create(element::f32, Shape{}, {.0f}), + op::Constant::create(element::f32, Shape{}, {.0f}), + op::Constant::create(element::f32, Shape{}, {.0f})}) + , m_box_encoding{box_encoding} + , m_sort_result_descending{sort_result_descending} + , m_output_type{output_type} +{ + constructor_validate_and_infer_types(); +} + +op::v5::NonMaxSuppression::NonMaxSuppression( + const Output& boxes, + const Output& scores, + const Output& max_output_boxes_per_class, + const Output& iou_threshold, + const op::v5::NonMaxSuppression::BoxEncodingType box_encoding, + const bool sort_result_descending, + const element::Type& output_type) + : Op({boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + op::Constant::create(element::f32, Shape{}, {.0f}), + op::Constant::create(element::f32, Shape{}, {.0f})}) + , m_box_encoding{box_encoding} + , m_sort_result_descending{sort_result_descending} + , m_output_type{output_type} +{ + constructor_validate_and_infer_types(); +} + +op::v5::NonMaxSuppression::NonMaxSuppression( + const Output& boxes, + const Output& scores, + const Output& max_output_boxes_per_class, + const Output& iou_threshold, + const Output& score_threshold, + const op::v5::NonMaxSuppression::BoxEncodingType box_encoding, + const bool sort_result_descending, + const element::Type& output_type) + : Op({boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + op::Constant::create(element::f32, Shape{}, {.0f})}) + , m_box_encoding{box_encoding} + , m_sort_result_descending{sort_result_descending} + , m_output_type{output_type} +{ + constructor_validate_and_infer_types(); +} + +op::v5::NonMaxSuppression::NonMaxSuppression( + const Output& boxes, + const Output& scores, + const Output& max_output_boxes_per_class, + const Output& iou_threshold, + const Output& score_threshold, + const Output& soft_nms_sigma, + const op::v5::NonMaxSuppression::BoxEncodingType box_encoding, + const bool sort_result_descending, + const element::Type& output_type) + : Op({boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + soft_nms_sigma}) + , m_box_encoding{box_encoding} + , m_sort_result_descending{sort_result_descending} + , m_output_type{output_type} +{ + constructor_validate_and_infer_types(); +} + +shared_ptr + op::v5::NonMaxSuppression::clone_with_new_inputs(const OutputVector& new_args) const +{ + check_new_args_count(this, new_args); + NODE_VALIDATION_CHECK(this, + new_args.size() >= 2 && new_args.size() <= 6, + "Number of inputs must be 2, 3, 4, 5 or 6"); + + const auto& arg2 = new_args.size() > 2 + ? new_args.at(2) + : ngraph::op::Constant::create(element::i64, Shape{}, {0}); + const auto& arg3 = new_args.size() > 3 + ? new_args.at(3) + : ngraph::op::Constant::create(element::f32, Shape{}, {.0f}); + const auto& arg4 = new_args.size() > 4 + ? new_args.at(4) + : ngraph::op::Constant::create(element::f32, Shape{}, {.0f}); + const auto& arg5 = new_args.size() > 5 + ? new_args.at(5) + : ngraph::op::Constant::create(element::f32, Shape{}, {.0f}); + + return std::make_shared(new_args.at(0), + new_args.at(1), + arg2, + arg3, + arg4, + arg5, + m_box_encoding, + m_sort_result_descending, + m_output_type); +} + +void op::v5::NonMaxSuppression::validate() +{ + const auto boxes_ps = get_input_partial_shape(0); + const auto scores_ps = get_input_partial_shape(1); + + NODE_VALIDATION_CHECK(this, + m_output_type == element::i64 || m_output_type == element::i32, + "Output type must be i32 or i64"); + + if (boxes_ps.is_dynamic() || scores_ps.is_dynamic()) + { + return; + } + + NODE_VALIDATION_CHECK(this, + boxes_ps.rank().is_static() && boxes_ps.rank().get_length() == 3, + "Expected a 3D tensor for the 'boxes' input. Got: ", + boxes_ps); + + NODE_VALIDATION_CHECK(this, + scores_ps.rank().is_static() && scores_ps.rank().get_length() == 3, + "Expected a 3D tensor for the 'scores' input. Got: ", + scores_ps); + + if (inputs().size() >= 3) + { + const auto max_boxes_ps = get_input_partial_shape(2); + NODE_VALIDATION_CHECK(this, + max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()), + "Expected a scalar for the 'max_output_boxes_per_class' input. Got: ", + max_boxes_ps); + } + + if (inputs().size() >= 4) + { + const auto iou_threshold_ps = get_input_partial_shape(3); + NODE_VALIDATION_CHECK(this, + iou_threshold_ps.is_dynamic() || + is_scalar(iou_threshold_ps.to_shape()), + "Expected a scalar for the 'iou_threshold' input. Got: ", + iou_threshold_ps); + } + + if (inputs().size() >= 5) + { + const auto score_threshold_ps = get_input_partial_shape(4); + NODE_VALIDATION_CHECK(this, + score_threshold_ps.is_dynamic() || + is_scalar(score_threshold_ps.to_shape()), + "Expected a scalar for the 'score_threshold' input. Got: ", + score_threshold_ps); + } + + if (inputs().size() >= 6) + { + const auto soft_nms_sigma = get_input_partial_shape(5); + NODE_VALIDATION_CHECK(this, + soft_nms_sigma.is_dynamic() || is_scalar(soft_nms_sigma.to_shape()), + "Expected a scalar for the 'soft_nms_sigma' input. Got: ", + soft_nms_sigma); + } + + const auto num_batches_boxes = boxes_ps[0]; + const auto num_batches_scores = scores_ps[0]; + NODE_VALIDATION_CHECK(this, + num_batches_boxes.same_scheme(num_batches_scores), + "The first dimension of both 'boxes' and 'scores' must match. Boxes: ", + num_batches_boxes, + "; Scores: ", + num_batches_scores); + + const auto num_boxes_boxes = boxes_ps[1]; + const auto num_boxes_scores = scores_ps[2]; + NODE_VALIDATION_CHECK(this, + num_boxes_boxes.same_scheme(num_boxes_scores), + "'boxes' and 'scores' input shapes must match at the second and third " + "dimension respectively. Boxes: ", + num_boxes_boxes, + "; Scores: ", + num_boxes_scores); + + NODE_VALIDATION_CHECK(this, + boxes_ps[2].is_static() && boxes_ps[2].get_length() == 4u, + "The last dimension of the 'boxes' input must be equal to 4. Got:", + boxes_ps[2]); +} + +int64_t op::v5::NonMaxSuppression::max_boxes_output_from_input() const +{ + int64_t max_output_boxes{0}; + + const auto max_output_boxes_input = + as_type_ptr(input_value(2).get_node_shared_ptr()); + max_output_boxes = max_output_boxes_input->cast_vector().at(0); + + return max_output_boxes; +} + +static constexpr size_t boxes_port = 0; +static constexpr size_t scores_port = 1; +static constexpr size_t iou_threshold_port = 3; +static constexpr size_t score_threshold_port = 4; +static constexpr size_t soft_nms_sigma_port = 5; + +float op::v5::NonMaxSuppression::iou_threshold_from_input() const +{ + float iou_threshold = 0.0f; + + const auto iou_threshold_input = + as_type_ptr(input_value(iou_threshold_port).get_node_shared_ptr()); + iou_threshold = iou_threshold_input->cast_vector().at(0); + + return iou_threshold; +} + +float op::v5::NonMaxSuppression::score_threshold_from_input() const +{ + float score_threshold = 0.0f; + + const auto score_threshold_input = + as_type_ptr(input_value(score_threshold_port).get_node_shared_ptr()); + score_threshold = score_threshold_input->cast_vector().at(0); + + return score_threshold; +} + +float op::v5::NonMaxSuppression::soft_nms_sigma_from_input() const +{ + float soft_nms_sigma = 0.0f; + + const auto soft_nms_sigma_input = + as_type_ptr(input_value(soft_nms_sigma_port).get_node_shared_ptr()); + soft_nms_sigma = soft_nms_sigma_input->cast_vector().at(0); + + return soft_nms_sigma; +} + +bool ngraph::op::v5::NonMaxSuppression::visit_attributes(AttributeVisitor& visitor) +{ + visitor.on_attribute("box_encoding", m_box_encoding); + visitor.on_attribute("sort_result_descending", m_sort_result_descending); + visitor.on_attribute("output_type", m_output_type); + return true; +} + +void op::v5::NonMaxSuppression::validate_and_infer_types() +{ + const auto boxes_ps = get_input_partial_shape(0); + const auto scores_ps = get_input_partial_shape(1); + + // NonMaxSuppression produces triplets + // that have the following format: [batch_index, class_index, box_index] + PartialShape out_shape = {Dimension::dynamic(), 3}; + + validate(); + + set_output_type(0, m_output_type, out_shape); + set_output_type(1, element::f32, out_shape); + set_output_type(2, m_output_type, Shape{1}); +} + +namespace ngraph +{ + template <> + EnumNames& + EnumNames::get() + { + static auto enum_names = EnumNames( + "op::v5::NonMaxSuppression::BoxEncodingType", + {{"corner", op::v5::NonMaxSuppression::BoxEncodingType::CORNER}, + {"center", op::v5::NonMaxSuppression::BoxEncodingType::CENTER}}); + return enum_names; + } + + constexpr DiscreteTypeInfo + AttributeAdapter::type_info; + + std::ostream& operator<<(std::ostream& s, + const op::v5::NonMaxSuppression::BoxEncodingType& type) + { + return s << as_string(type); + } +} // namespace ngraph diff --git a/ngraph/test/type_prop/non_max_suppression.cpp b/ngraph/test/type_prop/non_max_suppression.cpp index 405a384..df8bf1a 100644 --- a/ngraph/test/type_prop/non_max_suppression.cpp +++ b/ngraph/test/type_prop/non_max_suppression.cpp @@ -547,3 +547,224 @@ TEST(type_prop, nms_v4_dynamic_boxes_and_scores) ASSERT_TRUE( nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3})); } + +// ------------------------------ V5 ------------------------------ + +TEST(type_prop, nms_v5_incorrect_boxes_rank) +{ + try + { + const auto boxes = make_shared(element::f32, Shape{1, 2, 3, 4}); + const auto scores = make_shared(element::f32, Shape{1, 2, 3}); + + make_shared(boxes, scores); + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), "Expected a 3D tensor for the 'boxes' input"); + } +} + +TEST(type_prop, nms_v5_incorrect_scores_rank) +{ + try + { + const auto boxes = make_shared(element::f32, Shape{1, 2, 3}); + const auto scores = make_shared(element::f32, Shape{1, 2}); + + make_shared(boxes, scores); + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), "Expected a 3D tensor for the 'scores' input"); + } +} + +TEST(type_prop, nms_v5_incorrect_scheme_num_batches) +{ + try + { + const auto boxes = make_shared(element::f32, Shape{1, 2, 3}); + const auto scores = make_shared(element::f32, Shape{2, 2, 3}); + + make_shared(boxes, scores); + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + "The first dimension of both 'boxes' and 'scores' must match"); + } +} + +TEST(type_prop, nms_v5_incorrect_scheme_num_boxes) +{ + try + { + const auto boxes = make_shared(element::f32, Shape{1, 2, 3}); + const auto scores = make_shared(element::f32, Shape{1, 2, 3}); + + make_shared(boxes, scores); + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + "'boxes' and 'scores' input shapes must match at the second and third " + "dimension respectively"); + } +} + +TEST(type_prop, nms_v5_scalar_inputs_check) +{ + const auto boxes = make_shared(element::f32, Shape{1, 2, 4}); + const auto scores = make_shared(element::f32, Shape{1, 2, 2}); + + const auto scalar = make_shared(element::f32, Shape{}); + const auto non_scalar = make_shared(element::f32, Shape{1}); + + try + { + make_shared(boxes, scores, non_scalar, scalar, scalar); + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + "Expected a scalar for the 'max_output_boxes_per_class' input"); + } + + try + { + make_shared(boxes, scores, scalar, non_scalar, scalar); + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'iou_threshold' input"); + } + + try + { + make_shared(boxes, scores, scalar, scalar, non_scalar); + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'score_threshold' input"); + } + + try + { + make_shared(boxes, scores, scalar, scalar, scalar, non_scalar); + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'soft_nms_sigma' input"); + } +} + +TEST(type_prop, nms_v5_output_shape) +{ + const auto boxes = make_shared(element::f32, Shape{5, 2, 4}); + const auto scores = make_shared(element::f32, Shape{5, 3, 2}); + + const auto nms = make_shared(boxes, scores); + + ASSERT_TRUE( + nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3})); + ASSERT_TRUE( + nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3})); + + EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); +} + +TEST(type_prop, nms_v5_output_shape_2) +{ + const auto boxes = make_shared(element::f32, Shape{2, 7, 4}); + const auto scores = make_shared(element::f32, Shape{2, 5, 7}); + const auto max_output_boxes_per_class = op::Constant::create(element::i32, Shape{}, {3}); + const auto iou_threshold = make_shared(element::f32, Shape{}); + const auto score_threshold = make_shared(element::f32, Shape{}); + + const auto nms = make_shared( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold); + + ASSERT_EQ(nms->get_output_element_type(0), element::i64); + ASSERT_EQ(nms->get_output_element_type(1), element::f32); + ASSERT_EQ(nms->get_output_element_type(2), element::i64); + ASSERT_TRUE( + nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3})); + ASSERT_TRUE( + nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3})); + + EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); +} + +TEST(type_prop, nms_v5_output_shape_3) +{ + const auto boxes = make_shared(element::f32, Shape{2, 7, 4}); + const auto scores = make_shared(element::f32, Shape{2, 5, 7}); + const auto max_output_boxes_per_class = op::Constant::create(element::i16, Shape{}, {1000}); + const auto iou_threshold = make_shared(element::f32, Shape{}); + const auto score_threshold = make_shared(element::f32, Shape{}); + + const auto nms = make_shared( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold); + + ASSERT_EQ(nms->get_output_element_type(0), element::i64); + ASSERT_EQ(nms->get_output_element_type(1), element::f32); + ASSERT_EQ(nms->get_output_element_type(2), element::i64); + ASSERT_TRUE( + nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3})); + ASSERT_TRUE( + nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3})); + + EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); +} + +TEST(type_prop, nms_v5_output_shape_i32) +{ + const auto boxes = make_shared(element::f32, Shape{2, 7, 4}); + const auto scores = make_shared(element::f32, Shape{2, 5, 7}); + const auto max_output_boxes_per_class = op::Constant::create(element::i16, Shape{}, {3}); + const auto iou_threshold = make_shared(element::f32, Shape{}); + const auto score_threshold = make_shared(element::f32, Shape{}); + + const auto nms = + make_shared(boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + op::v5::NonMaxSuppression::BoxEncodingType::CORNER, + true, + element::i32); + + ASSERT_EQ(nms->get_output_element_type(0), element::i32); + ASSERT_EQ(nms->get_output_element_type(1), element::f32); + ASSERT_EQ(nms->get_output_element_type(2), element::i32); + ASSERT_TRUE( + nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3})); + ASSERT_TRUE( + nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3})); + + EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); +} + +TEST(type_prop, nms_v5_dynamic_boxes_and_scores) +{ + const auto boxes = make_shared(element::f32, PartialShape::dynamic()); + const auto scores = make_shared(element::f32, PartialShape::dynamic()); + const auto max_output_boxes_per_class = op::Constant::create(element::i16, Shape{}, {3}); + const auto iou_threshold = make_shared(element::f32, Shape{}); + const auto score_threshold = make_shared(element::f32, Shape{}); + + const auto nms = make_shared( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold); + + ASSERT_EQ(nms->get_output_element_type(0), element::i64); + ASSERT_EQ(nms->get_output_element_type(1), element::f32); + ASSERT_EQ(nms->get_output_element_type(2), element::i64); + ASSERT_TRUE( + nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3})); + ASSERT_TRUE( + nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3})); + + EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); +} \ No newline at end of file -- 2.7.4