Round-5 nGraph implementation (#2652)
authorAnton Chetverikov <Anton.Chetverikov@intel.com>
Wed, 14 Oct 2020 05:10:05 +0000 (08:10 +0300)
committerGitHub <noreply@github.com>
Wed, 14 Oct 2020 05:10:05 +0000 (08:10 +0300)
* Implement nGraph Round-5 operation

* Remove reference implementation

* Add shape infer tests

* Fix codestyle

ngraph/core/include/ngraph/op/op_version_tbl.hpp
ngraph/core/include/ngraph/op/round.hpp
ngraph/core/include/ngraph/opsets/opset5_tbl.hpp
ngraph/core/src/op/round.cpp
ngraph/test/CMakeLists.txt
ngraph/test/type_prop/round.cpp [new file with mode: 0644]

index 94a6e71..decf321 100644 (file)
@@ -163,6 +163,7 @@ NGRAPH_OP(Reverse, ngraph::op::v0, 0)
 NGRAPH_OP(Reverse, ngraph::op::v1, 1)
 NGRAPH_OP(ReverseSequence, ngraph::op::v0, 0)
 NGRAPH_OP(Round, ngraph::op::v0, 0)
+NGRAPH_OP(Round, ngraph::op::v5, 5)
 NGRAPH_OP(ROIAlign, ngraph::op::v3, 3)
 NGRAPH_OP(ScatterElementsUpdate, ngraph::op::v3, 3)
 NGRAPH_OP(ScatterUpdate, ngraph::op::v3, 3)
index 8e37ae1..69026d8 100644 (file)
@@ -16,6 +16,8 @@
 
 #pragma once
 
+#include "ngraph/node.hpp"
+#include "ngraph/op/op.hpp"
 #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
 
 namespace ngraph
@@ -54,5 +56,61 @@ namespace ngraph
         NGRAPH_SUPPRESS_DEPRECATED_START
         using v0::Round;
         NGRAPH_SUPPRESS_DEPRECATED_END
+
+        namespace v5
+        {
+            /// \brief Elementwise round operation. The output is round to the nearest integer
+            /// for each value. In case of halfs, the rule is defined in attribute 'mode':
+            ///     'HALF_TO_EVEN' - round halfs to the nearest even integer.
+            ///     'HALF_AWAY_FROM_ZERO': - round in such a way that the result heads away from
+            /// zero.
+
+            class NGRAPH_API Round : public ngraph::op::Op
+            {
+            public:
+                enum class RoundMode
+                {
+                    HALF_TO_EVEN,
+                    HALF_AWAY_FROM_ZERO
+                };
+                NGRAPH_RTTI_DECLARATION;
+
+                /// \brief Constructs a round operation.
+                Round() = default;
+
+                /// \brief Constructs a round operation.
+                ///
+                /// \param arg Node that produces the input tensor.
+                /// \param mode Rule to resolve halfs
+                Round(const Output<Node>& arg, const RoundMode mode);
+
+                bool visit_attributes(AttributeVisitor& visitor) override;
+                void validate_and_infer_types() override;
+
+                virtual std::shared_ptr<Node>
+                    clone_with_new_inputs(const OutputVector& new_args) const override;
+
+                RoundMode get_mode() const { return m_mode; }
+            private:
+                RoundMode m_mode;
+            };
+        }
     }
+    NGRAPH_API
+    std::ostream& operator<<(std::ostream& s, const op::v5::Round::RoundMode& type);
+
+    template <>
+    class NGRAPH_API AttributeAdapter<op::v5::Round::RoundMode>
+        : public EnumAttributeAdapterBase<op::v5::Round::RoundMode>
+    {
+    public:
+        AttributeAdapter(op::v5::Round::RoundMode& value)
+            : EnumAttributeAdapterBase<op::v5::Round::RoundMode>(value)
+        {
+        }
+
+        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v5::Round::RoundMode>",
+                                                    5};
+        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
+    };
 }
index 6128d65..43c8d50 100644 (file)
@@ -167,4 +167,5 @@ NGRAPH_OP(Swish, ngraph::op::v4)
 NGRAPH_OP(LogSoftmax, ngraph::op::v5)
 NGRAPH_OP(LSTMSequence, ngraph::op::v5)
 NGRAPH_OP(GRUSequence, ngraph::op::v5)
-NGRAPH_OP(RNNSequence, ngraph::op::v5)
\ No newline at end of file
+NGRAPH_OP(RNNSequence, ngraph::op::v5)
+NGRAPH_OP(Round, ngraph::op::v5)
index 7bf920c..e296ba2 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "ngraph/op/round.hpp"
 #include "itt.hpp"
+#include "ngraph/attribute_visitor.hpp"
 #include "ngraph/op/util/eval_copy.hpp"
 #include "ngraph/runtime/host_tensor.hpp"
 #include "ngraph/runtime/reference/copy.hpp"
@@ -28,16 +29,16 @@ using namespace ngraph;
 
 constexpr NodeTypeInfo op::Round::type_info;
 
-op::Round::Round(const Output<Node>& arg)
+op::v0::Round::Round(const Output<Node>& arg)
     : UnaryElementwiseArithmetic(arg)
 {
     constructor_validate_and_infer_types();
 }
 
-shared_ptr<Node> op::Round::clone_with_new_inputs(const OutputVector& new_args) const
+shared_ptr<Node> op::v0::Round::clone_with_new_inputs(const OutputVector& new_args) const
 {
     check_new_args_count(this, new_args);
-    return make_shared<Round>(new_args.at(0));
+    return make_shared<v0::Round>(new_args.at(0));
 }
 
 namespace roundop
@@ -94,8 +95,56 @@ namespace roundop
     }
 }
 
-bool op::Round::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
+bool op::v0::Round::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
 {
-    OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::Round::evaluate");
+    OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v0::Round::evaluate");
     return roundop::evaluate_round(inputs[0], outputs[0], shape_size(get_output_shape(0)));
 }
+NGRAPH_SUPPRESS_DEPRECATED_END
+
+NGRAPH_RTTI_DEFINITION(op::v5::Round, "Round", 5);
+
+op::v5::Round::Round(const Output<Node>& arg, RoundMode mode)
+    : Op({arg})
+    , m_mode(mode)
+{
+    constructor_validate_and_infer_types();
+}
+
+bool ngraph::op::v5::Round::visit_attributes(AttributeVisitor& visitor)
+{
+    visitor.on_attribute("mode", m_mode);
+    return true;
+}
+
+void op::v5::Round::validate_and_infer_types()
+{
+    set_output_size(1);
+    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
+}
+
+shared_ptr<Node> op::v5::Round::clone_with_new_inputs(const OutputVector& new_args) const
+{
+    check_new_args_count(this, new_args);
+    return make_shared<v5::Round>(new_args.at(0), m_mode);
+}
+
+namespace ngraph
+{
+    template <>
+    EnumNames<op::v5::Round::RoundMode>& EnumNames<op::v5::Round::RoundMode>::get()
+    {
+        static auto enum_names = EnumNames<op::v5::Round::RoundMode>(
+            "op::v5::Round::RoundMode",
+            {{"half_to_even", op::v5::Round::RoundMode::HALF_TO_EVEN},
+             {"half_away_from_zero", op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO}});
+        return enum_names;
+    }
+
+    constexpr DiscreteTypeInfo AttributeAdapter<op::v5::Round::RoundMode>::type_info;
+
+    std::ostream& operator<<(std::ostream& s, const op::v5::Round::RoundMode& type)
+    {
+        return s << as_string(type);
+    }
+} // namespace ngraph
index 253c1e9..5f3702a 100644 (file)
@@ -161,6 +161,7 @@ set(SRC
     type_prop/reverse.cpp
     type_prop/reverse_sequence.cpp
     type_prop/roi_align.cpp
+    type_prop/round.cpp
     type_prop/rnn_cell.cpp
     type_prop/rnn_sequence.cpp
     type_prop/scatter_elements_update.cpp
diff --git a/ngraph/test/type_prop/round.cpp b/ngraph/test/type_prop/round.cpp
new file mode 100644 (file)
index 0000000..c7a16ec
--- /dev/null
@@ -0,0 +1,91 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "gtest/gtest.h"
+#include "ngraph/ngraph.hpp"
+#include "util/type_prop.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+TEST(type_prop, rounding_to_even)
+{
+    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
+    auto round_func = make_shared<op::v5::Round>(data, op::v5::Round::RoundMode::HALF_TO_EVEN);
+    EXPECT_EQ(round_func->get_element_type(), element::f32);
+    EXPECT_EQ(round_func->get_shape(), (Shape{1, 3, 6}));
+}
+
+TEST(type_prop, rounding_away)
+{
+    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
+    auto round_func =
+        make_shared<op::v5::Round>(data, op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
+    EXPECT_EQ(round_func->get_element_type(), element::f32);
+    EXPECT_EQ(round_func->get_shape(), (Shape{1, 3, 6}));
+}
+
+TEST(type_prop, rounding_to_even_partial)
+{
+    auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
+    auto softplus_func = make_shared<op::v5::Round>(data, op::v5::Round::RoundMode::HALF_TO_EVEN);
+    EXPECT_EQ(softplus_func->get_element_type(), element::f32);
+    ASSERT_TRUE(softplus_func->get_output_partial_shape(0).same_scheme(
+        (PartialShape{1, Dimension::dynamic(), 6})));
+
+    // rank unknown
+    auto softplus_partial = make_shared<op::v5::Round>(
+        make_shared<op::Parameter>(element::f32, PartialShape::dynamic()),
+        op::v5::Round::RoundMode::HALF_TO_EVEN);
+    ASSERT_TRUE(softplus_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
+}
+
+TEST(type_prop, rounding_away_partial)
+{
+    auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
+    auto softplus_func =
+        make_shared<op::v5::Round>(data, op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
+    EXPECT_EQ(softplus_func->get_element_type(), element::f32);
+    ASSERT_TRUE(softplus_func->get_output_partial_shape(0).same_scheme(
+        (PartialShape{1, Dimension::dynamic(), 6})));
+
+    // rank unknown
+    auto softplus_partial = make_shared<op::v5::Round>(
+        make_shared<op::Parameter>(element::f32, PartialShape::dynamic()),
+        op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
+    ASSERT_TRUE(softplus_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
+}
+
+TEST(type_prop, rounding_to_even_partial_static_rank)
+{
+    auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
+    auto softplus_func = make_shared<op::v5::Round>(data, op::v5::Round::RoundMode::HALF_TO_EVEN);
+    EXPECT_EQ(softplus_func->get_element_type(), element::f32);
+    ASSERT_TRUE(softplus_func->get_output_partial_shape(0).same_scheme(
+        (PartialShape{1, Dimension::dynamic(), 6})));
+    ASSERT_TRUE(softplus_func->get_output_partial_shape(0).rank().is_static());
+}
+
+TEST(type_prop, rounding_away_partial_static_rank)
+{
+    auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
+    auto softplus_func =
+        make_shared<op::v5::Round>(data, op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
+    EXPECT_EQ(softplus_func->get_element_type(), element::f32);
+    ASSERT_TRUE(softplus_func->get_output_partial_shape(0).same_scheme(
+        (PartialShape{1, Dimension::dynamic(), 6})));
+    ASSERT_TRUE(softplus_func->get_output_partial_shape(0).rank().is_static());
+}