Add Round-5 operation (#2328)
authorAnton Chetverikov <Anton.Chetverikov@intel.com>
Tue, 20 Oct 2020 15:36:19 +0000 (18:36 +0300)
committerGitHub <noreply@github.com>
Tue, 20 Oct 2020 15:36:19 +0000 (18:36 +0300)
* Add Round-5 operation

* Add ONNX Round to supported operation list

* Add ngraph implementation for Round operation

* Update MO part

* Create UnaryElementwise class, update Round Operation

* Fix mode attr in mxnet extractor

* Add tests for Round shape infer

* Update 'enable' attr

* Update MO IR Reader to support UnaryElementwise operations

* Minor test refactor

* Update ngraph Round operation

* Add reference implementation

* Add test for reference implementation

* Add test for shape infer

* Add test for IE IR Reader

* AddRound operation to python api

* Fix missed mode attr

* Update Round operation version

* Fix codestyle

* Add MxNet Round to supported layers list

* Fix error in reference

* Fix comments style

* Update CMake file

* Update Ngraph reference test

* Update IE IR Reader tests

* Return v0::Round operation

* Update shape infer tests

* Fix v0::Round reference

* Fix codestyle

* Enum instead of string

* Fix codestyle

* Add Mode attribute adapter

* Update Mode attr

* Fix reference for v0::Round

* Fix codestyle

* Fix mode attr

* Fix get() method

* Fix codestyle in python api

* Update test info

* Fix ngraph api part

* Ad round v5 to interpreter tests

* Fix codestyle is ie reader test

* Update ngraph python api __init__.py file

* Adde opser5 to dafault opsets in ie_ir reader

* Add parser for Round layer

* Remove redundant spaces

* Add round creator to appropriate list

* Remove redundant import

* Commit to bump infrastructure version

I'm sorry for this, but this commit will be squashed on merge to master anyway and it is needed for your PR to correctly pass the pipeline

* Fix import

* fix codestyle

* Fix ngraph api part

* Add shape infer tests in python api

* Add .upper() for mode attr

* Refactor MO shape infer test for Round op

* Update tests and add comments

* Revert "Commit to bump infrastructure version"

This reverts commit 56e6ae1e4c31439ba0d4636fa76782c03bf30aca.

* remove parser for Round layer

* Update Ronund-5 evaluate test

* Resolve review comments

Co-authored-by: User <user@nnlvdp-achetver.inn.intel.com>
Co-authored-by: Andrey Babushkin <andrey.babushkin@intel.com>
Co-authored-by: Anton Chetverikov <anton.chetverikov@.intel.com>
20 files changed:
docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md
inference-engine/tests/functional/inference_engine/ngraph_reader/round_test.cpp [new file with mode: 0644]
model-optimizer/extensions/front/mxnet/elementwise_ext.py
model-optimizer/extensions/front/onnx/elementwise_ext.py
model-optimizer/extensions/front/tf/elementwise_ext.py
model-optimizer/extensions/ops/elementwise.py
model-optimizer/extensions/ops/elementwise_test.py [new file with mode: 0644]
model-optimizer/mo/utils/ir_reader/layer_to_class.py
ngraph/core/include/ngraph/op/round.hpp
ngraph/core/reference/include/ngraph/runtime/reference/round.hpp
ngraph/core/src/op/round.cpp
ngraph/python/src/ngraph/__init__.py
ngraph/python/src/ngraph/opset5/__init__.py
ngraph/python/src/ngraph/opset5/ops.py
ngraph/python/tests/test_ngraph/test_ops_unary.py
ngraph/test/CMakeLists.txt
ngraph/test/op_eval/round.cpp [new file with mode: 0644]
ngraph/test/runtime/interpreter/int_executable.hpp
ngraph/test/runtime/interpreter/opset_int_tbl.hpp
ngraph/test/type_prop/round.cpp

index 78b47d2..7989866 100644 (file)
@@ -70,6 +70,7 @@ Standard MXNet\* symbols:
 | repeat | No |
 | rnn | No |
 | rnn_param_concat | No |
+| round | No |
 | sigmoid | No |
 | slice | No |
 | slice_axis | No |
@@ -385,6 +386,7 @@ Standard ONNX\* operators:
 | Reshape | No |
 | Resize | Coordinate transformation mode `tf_crop_and_resize` is not supported, `nearest` mode is not supported for 5D+ inputs. |
 | ReverseSequence | No |
+| Round | No |
 | Scatter | Supported if fuse-able to ScatterUpdate. MYRIAD only |
 | ScatterND | No |
 | ScatterElements | Supported if fuse-able to ScatterUpdate. MYRIAD only |
diff --git a/inference-engine/tests/functional/inference_engine/ngraph_reader/round_test.cpp b/inference-engine/tests/functional/inference_engine/ngraph_reader/round_test.cpp
new file mode 100644 (file)
index 0000000..418501e
--- /dev/null
@@ -0,0 +1,191 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <string>
+#include "ngraph_reader_tests.hpp"
+TEST_F(NGraphReaderTests, ReadRoundEvenNetwork) {
+    std::string model = R"V0G0N(
+<net name="Network" version="10">
+    <layers>
+        <layer name="in1" type="Parameter" id="0" version="opset1">
+            <data element_type="f32" shape="1,3,22,22"/>
+            <output>
+                <port id="0" precision="FP32">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </output>
+        </layer>
+               <layer name="Round" id="1" type="Round" version="opset5">
+                       <data mode="half_to_even"/>
+            <input>
+                <port id="1" precision="FP32">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </input>
+            <output>
+                <port id="2" precision="FP32">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </output>
+        </layer>
+        <layer name="output" type="Result" id="2" version="opset1">
+            <input>
+                <port id="0" precision="FP32">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </input>
+        </layer>
+    </layers>
+    <edges>
+        <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
+        <edge from-layer="1" from-port="2" to-layer="2" to-port="0"/>
+    </edges>
+</net>
+)V0G0N";
+    std::string modelV5 = R"V0G0N(
+<net name="Network" version="5" precision="FP32" batch="1">
+    <layers>
+        <layer name="in1" type="Input" precision="FP32" id="0">
+            <output>
+                <port id="0">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </output>
+        </layer>
+               <layer name="Round" id="1" type="Round" version="opset5">
+                       <data mode="half_to_even"/>
+            <input>
+                <port id="1">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </input>
+            <output>
+                <port id="2">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </output>
+        </layer>
+    </layers>
+    <edges>
+        <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
+    </edges>
+</net>
+)V0G0N";
+
+    compareIRs(model, modelV5, 0);
+}
+
+TEST_F(NGraphReaderTests, ReadRoundAwayNetwork) {
+    std::string model = R"V0G0N(
+<net name="Network" version="10">
+    <layers>
+        <layer name="in1" type="Parameter" id="0" version="opset1">
+            <data element_type="f32" shape="1,3,22,22"/>
+            <output>
+                <port id="0" precision="FP32">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </output>
+        </layer>
+               <layer name="Round" id="1" type="Round" version="opset5">
+                       <data mode="half_away_from_zero"/>
+            <input>
+                <port id="1" precision="FP32">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </input>
+            <output>
+                <port id="2" precision="FP32">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </output>
+        </layer>
+        <layer name="output" type="Result" id="2" version="opset1">
+            <input>
+                <port id="0" precision="FP32">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </input>
+        </layer>
+    </layers>
+    <edges>
+        <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
+        <edge from-layer="1" from-port="2" to-layer="2" to-port="0"/>
+    </edges>
+</net>
+)V0G0N";
+    std::string modelV5 = R"V0G0N(
+<net name="Network" version="5" precision="FP32" batch="1">
+    <layers>
+        <layer name="in1" type="Input" precision="FP32" id="0">
+            <output>
+                <port id="0">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </output>
+        </layer>
+               <layer name="Round" id="1" type="Round" version="opset5">
+                       <data mode="half_away_from_zero"/>
+            <input>
+                <port id="1">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </input>
+            <output>
+                <port id="2">
+                    <dim>1</dim>
+                    <dim>3</dim>
+                    <dim>22</dim>
+                    <dim>22</dim>
+                </port>
+            </output>
+        </layer>
+    </layers>
+    <edges>
+        <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
+    </edges>
+</net>
+)V0G0N";
+
+compareIRs(model, modelV5, 0);
+}
index ee43443..56909de 100644 (file)
@@ -16,7 +16,7 @@
 import numpy as np
 
 from extensions.ops.elementwise import Mul, Sub, Add, Maximum, Minimum, Div, Greater, GreaterEqual, Equal, Less, \
-    LessEqual, Pow, NotEqual, LogicalAnd, LogicalOr
+    LessEqual, Pow, NotEqual, LogicalAnd, LogicalOr, Round
 from mo.front.extractor import FrontExtractorOp
 from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
 from mo.graph.graph import Node
@@ -414,3 +414,13 @@ class OnesFrontExtractor(FrontExtractorOp):
     def extract(cls, node):
         AttributedPower.update_node_stat(node, {'scale': 0, 'shift': 1})
         return cls.enabled
+
+
+class RoundExtractor(FrontExtractorOp):
+    op = 'round'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        Round.update_node_stat(node, {'mode': 'half_away_from_zero'})
+        return cls.enabled
index 9c1442c..1ed5cb7 100644 (file)
@@ -16,7 +16,7 @@
 import numpy as np
 
 from extensions.ops.elementwise import Add, Sub, Mul, Div, Pow, Less, Equal, Greater, \
-    LogicalAnd, LogicalOr, LogicalXor
+    LogicalAnd, LogicalOr, LogicalXor, Round
 from mo.front.extractor import FrontExtractorOp
 from mo.front.onnx.extractors.utils import onnx_attr
 from mo.graph.graph import Node
@@ -188,3 +188,13 @@ class XorExtractor(FrontExtractorOp):
     def extract(cls, node):
         LogicalXor.update_node_stat(node)
         return cls.enabled
+
+
+class RoundFrontExtractor(FrontExtractorOp):
+    op = 'Round'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node: Node):
+        Round.update_node_stat(node, {'mode': 'half_to_even'})
+        return cls.enabled
index c367c20..fe453b7 100644 (file)
@@ -14,7 +14,7 @@
  limitations under the License.
 """
 from extensions.ops.elementwise import Add, Mul, Sub, Div, Maximum, Minimum, Pow, LogicalAnd, LogicalOr, Equal, \
-    GreaterEqual, Greater, Less, LessEqual, NotEqual, FloorMod, BiasAdd, SquaredDifference
+    GreaterEqual, Greater, Less, LessEqual, NotEqual, FloorMod, BiasAdd, SquaredDifference, Round
 from mo.front.extractor import FrontExtractorOp
 from mo.front.tf.extractors.utils import tf_dtype_extractor
 from mo.ops.eltwise_n import EltwiseNAdd
@@ -271,3 +271,13 @@ class FloorModFrontExtractor(FrontExtractorOp):
     def extract(cls, node):
         FloorMod.update_node_stat(node)
         return cls.enabled
+
+
+class RoundExtractor(FrontExtractorOp):
+    op = 'Round'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        Round.update_node_stat(node, {'mode': 'half_to_even'})
+        return cls.enabled
index 2038ed7..65dc693 100644 (file)
@@ -49,12 +49,13 @@ class Elementwise(Op):
     operation = None
     op = None
     op_type = None
+    version = 'opset1'
 
     def __init__(self, graph: Graph, attrs: dict):
         super().__init__(graph, {
             'op': self.op,
             'type': self.op_type,
-            'version': 'opset1',
+            'version': self.version,
             'infer': lambda node: eltwise_infer(node, self.operation),
             'type_infer': self.type_infer,
             'can_be_bias': True,
@@ -71,8 +72,18 @@ class Elementwise(Op):
         node.out_port(0).set_data_type(node.in_port(0).get_data_type())
 
 
+class UnaryElementwise(Elementwise):
+    def __init__(self, graph: Graph, attrs: dict):
+        super().__init__(graph, {**{
+            'in_ports_count': 1,
+        }, **attrs})
+
+    @staticmethod
+    def type_infer(node):
+        copy_type_infer(node)
+
+
 class Add(Elementwise):
-    enabled = False
     op = 'Add'
     op_type = 'Add'
     operation = staticmethod(lambda a, b: a + b)
@@ -87,14 +98,12 @@ class BiasAdd(Add):
 
 
 class Sub(Elementwise):
-    enabled = False
     op = 'Sub'
     op_type = 'Subtract'
     operation = staticmethod(lambda a, b: a - b)
 
 
 class Mul(Elementwise):
-    enabled = False
     op = 'Mul'
     op_type = 'Multiply'
     operation = staticmethod(lambda a, b: a * b)
@@ -105,21 +114,18 @@ def both_types_are_integer(a, b):
 
 
 class Div(Elementwise):
-    enabled = False
     op = 'Div'
     op_type = 'Divide'
     operation = staticmethod(lambda a, b: a // b if both_types_are_integer(a, b) else a / b)
 
 
 class SquaredDifference(Elementwise):
-    enabled = False
     op = 'SquaredDifference'
     op_type = 'SquaredDifference'
     operation = staticmethod(lambda a, b: (a - b) * (a - b))
 
 
 class Pow(Elementwise):
-    enabled = False
     op = 'Pow'
     op_type = 'Power'
 
@@ -147,103 +153,112 @@ class LogicalElementwise(Elementwise):
 
 
 class Greater(LogicalElementwise):
-    enabled = False
     op = 'Greater'
     op_type = 'Greater'
     operation = staticmethod(lambda a, b: a > b)
 
 
 class GreaterEqual(LogicalElementwise):
-    enabled = False
     op = 'GreaterEqual'
     op_type = 'GreaterEqual'
     operation = staticmethod(lambda a, b: a >= b)
 
 
 class Less(LogicalElementwise):
-    enabled = False
     op = 'Less'
     op_type = 'Less'
     operation = staticmethod(lambda a, b: a < b)
 
 
 class LessEqual(LogicalElementwise):
-    enabled = False
     op = 'LessEqual'
     op_type = 'LessEqual'
     operation = staticmethod(lambda a, b: a <= b)
 
 
 class Equal(LogicalElementwise):
-    enabled = False
     op = 'Equal'
     op_type = 'Equal'
     operation = staticmethod(lambda a, b: a == b)
 
 
 class NotEqual(LogicalElementwise):
-    enabled = False
     op = 'NotEqual'
     op_type = 'NotEqual'
     operation = staticmethod(lambda a, b: a != b)
 
 
 class Maximum(Elementwise):
-    enabled = False
     op = 'Maximum'
     op_type = 'Maximum'
     operation = staticmethod(lambda a, b: np.maximum(a, b))
 
 
 class Minimum(Elementwise):
-    enabled = False
     op = 'Minimum'
     op_type = 'Minimum'
     operation = staticmethod(lambda a, b: np.minimum(a, b))
 
 
-class Round(Elementwise):
-    enabled = False
+class Round(UnaryElementwise):
     op = 'Round'
-    op_type = None
-    version = 'extension'
-    operation = staticmethod(lambda a: np.round(a))
+    op_type = 'Round'
+    version = 'opset5'
+
+    def __init__(self, graph: Graph, attrs):
+        round_attrs = {'mode': 'half_to_even',
+                       'infer': self.infer
+                       }
+        round_attrs.update(attrs)
+        super().__init__(graph, round_attrs)
+
+    def backend_attrs(self):
+        return ['mode']
+
+    @classmethod
+    def infer(cls, node: Node):
+        node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
+
+        a = node.in_port(0).data.get_value()
+        if a is not None:
+            assert node.soft_get('mode') in ['half_to_even', 'half_away_from_zero'], \
+                'Round node {} has unsupported "mode" attribute value: {}'.format(node.soft_get('name', node.id),
+                                                                                  node.soft_get('mode'))
+            if node.mode == 'half_away_from_zero':
+                mask = (a >= 0)
+                out = np.empty_like(a)
+                out[mask] = np.floor(a[mask] + 0.5)
+                out[~mask] = np.ceil(a[~mask] - 0.5)
+            else:
+                out = np.round(a)
+            node.out_port(0).data.set_value(out)
 
 
 class LogicalOr(LogicalElementwise):
-    enabled = False
     op = 'LogicalOr'
     op_type = 'LogicalOr'
     operation = staticmethod(lambda a, b: np.logical_or(a, b))
 
 
 class LogicalXor(Elementwise):
-    enabled = False
     op = 'LogicalXor'
     op_type = 'LogicalXor'
     operation = staticmethod(lambda a, b: np.logical_xor(a, b))
 
 
 class LogicalAnd(LogicalElementwise):
-    enabled = False
     op = 'LogicalAnd'
     op_type = 'LogicalAnd'
     operation = staticmethod(lambda a, b: np.logical_and(a, b))
 
 
 class FloorMod(Elementwise):
-    enabled = False
     op = 'FloorMod'
     op_type = 'FloorMod'
     operation = staticmethod(lambda a, b: a % b)
 
 
-class Negative(Elementwise):
-    enabled = False
+class Negative(UnaryElementwise):
     op = 'Negative'
     op_type = 'Negative'
     operation = staticmethod(lambda a: -a)
-
-    @staticmethod
-    def type_infer(node):
-        copy_type_infer(node)
diff --git a/model-optimizer/extensions/ops/elementwise_test.py b/model-optimizer/extensions/ops/elementwise_test.py
new file mode 100644 (file)
index 0000000..990a89c
--- /dev/null
@@ -0,0 +1,92 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+import numpy as np
+
+from extensions.ops.elementwise import Round
+from mo.graph.graph import Node
+from mo.utils.unittest.graph import build_graph
+
+def round_test_graph(nodes_attributes, value, mode: str):
+    graph = build_graph(nodes_attributes,
+                        [
+                            ('node_1', 'elementwise_node'),
+                            ('elementwise_node', 'node_3')
+                        ],
+                        {
+                            'node_1': {
+                                'value': value
+                            },
+                            'elementwise_node': {
+                                'op': 'Round',
+                                'mode': mode,
+                            },
+                            'node_3': {
+                                'value': None
+                            }
+                        })
+    return graph
+
+
+class TestElementwiseOp(unittest.TestCase):
+    nodes_attributes = {
+        'node_1': {
+            'shape': np.array([13]),
+            'value': None
+        },
+        'elementwise_node': {
+            'op': None,
+            'kind': 'op',
+            'operation': None
+        },
+        'node_3': {
+            'shape': None
+        }
+    }
+
+    value = np.array([-23.5, -22.5, -2.5, -1.5, -0.5, 0.5, 0.9, 1.5, 2.3, 2.5, 3.5, 22.5, 23.5])
+
+    def test_elementwise_round_even_infer(self):
+        graph = round_test_graph(self.nodes_attributes, self.value, 'half_to_even')
+
+        graph.graph['layout'] = 'NCHW'
+        elementwise_node = Node(graph, 'elementwise_node')
+        Round.infer(elementwise_node)
+        exp_shape = np.array([13])
+        res_shape = graph.node['node_3']['shape']
+        res_value = graph.node['node_3']['value']
+        exp_value = np.array([-24., -22., -2., -2., -0., 0., 1., 2., 2., 2., 4., 22., 24.,])
+        for i, value in enumerate(exp_shape):
+            self.assertEqual(res_shape[i], value)
+        for i, value in enumerate(exp_value):
+            self.assertAlmostEqual(res_value[i], value)
+
+    def test_elementwise_round_away_infer(self):
+        graph = round_test_graph(self.nodes_attributes, self.value, 'half_away_from_zero')
+
+        graph.graph['layout'] = 'NCHW'
+        elementwise_node = Node(graph, 'elementwise_node')
+        Round.infer(elementwise_node)
+        exp_shape = np.array([13])
+        res_shape = graph.node['node_3']['shape']
+        res_value = graph.node['node_3']['value']
+        exp_value = np.array([-24., -23., -3., -2., -1., 1., 1., 2., 2., 3., 4., 23., 24.])
+        for i, value in enumerate(exp_shape):
+            self.assertEqual(res_shape[i], value)
+        for i, value in enumerate(exp_value):
+            self.assertAlmostEqual(res_value[i], value)
index 4694b28..5e237a8 100644 (file)
@@ -23,7 +23,7 @@ from extensions.back.TopKNormalizer import TopKNormalizer
 from extensions.ops.Cast import Cast
 from extensions.ops.ReduceOps import ReduceOp
 from extensions.ops.activation_ops import Activation
-from extensions.ops.elementwise import Elementwise, LogicalElementwise, BiasAdd, Div, Mul, Pow, Sub
+from extensions.ops.elementwise import Elementwise, UnaryElementwise, LogicalElementwise, BiasAdd, Div, Mul, Pow, Sub
 from extensions.ops.embedding_bag import EmbeddingBagBase
 from extensions.ops.psroipooling import DeformablePSROIPoolingOp
 from extensions.ops.scatter import Scatter
@@ -69,8 +69,8 @@ def collect_ops(path: str):
     """
     import_by_path(os.path.join(path, 'mo', 'ops'), ['mo', 'ops'])
     import_by_path(os.path.join(path, 'extensions', 'ops'), ['extensions', 'ops'])
-    update_registration(classes=[Op, Activation, Elementwise, EmbeddingBagBase,
-                                 LogicalElementwise, ReduceOp, Scatter, ScatterNDBase],
+    update_registration(classes=[Op, Activation, Elementwise, UnaryElementwise, LogicalElementwise,
+                                 EmbeddingBagBase, ReduceOp, Scatter, ScatterNDBase],
                         enabled_transforms=[], disabled_transforms=[])
 
 
index 69026d8..ebd4e26 100644 (file)
@@ -90,6 +90,9 @@ namespace ngraph
                 virtual std::shared_ptr<Node>
                     clone_with_new_inputs(const OutputVector& new_args) const override;
 
+                bool evaluate(const HostTensorVector& outputs,
+                              const HostTensorVector& inputs) const override;
+
                 RoundMode get_mode() const { return m_mode; }
             private:
                 RoundMode m_mode;
index 50949ab..9c4913d 100644 (file)
@@ -40,11 +40,18 @@ namespace ngraph
             }
 
             template <typename T>
-            void round(const T* arg, T* out, size_t count)
+            void round(const T* arg, T* out, size_t count, const op::v5::Round::RoundMode mode)
             {
                 for (size_t i = 0; i < count; ++i)
                 {
-                    out[i] = round_to_nearest_even(arg[i]);
+                    if (mode == op::v5::Round::RoundMode::HALF_TO_EVEN)
+                    {
+                        out[i] = round_to_nearest_even(arg[i]);
+                    }
+                    else
+                    {
+                        out[i] = std::round(arg[i]);
+                    }
                 }
             }
         }
index e296ba2..f78f932 100644 (file)
@@ -45,10 +45,14 @@ namespace roundop
 {
     // function used by TYPE_CASE
     template <element::Type_t ET>
-    inline bool evaluate(const HostTensorPtr& arg0, const HostTensorPtr& out, const size_t count)
+    inline bool evaluate(const HostTensorPtr& arg0,
+                         const HostTensorPtr& out,
+                         const size_t count,
+                         const op::v5::Round::RoundMode mode)
     {
         using T = typename element_type_traits<ET>::value_type;
-        runtime::reference::round<T>(arg0->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
+        runtime::reference::round<T>(
+            arg0->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count, mode);
         return true;
     }
 
@@ -60,7 +64,10 @@ namespace roundop
         return true;
     }
 
-    bool evaluate_round(const HostTensorPtr& arg0, const HostTensorPtr& out, const size_t count)
+    bool evaluate_round(const HostTensorPtr& arg0,
+                        const HostTensorPtr& out,
+                        const size_t count,
+                        const op::v5::Round::RoundMode mode)
     {
         bool rc = true;
         out->set_unary(arg0);
@@ -85,9 +92,11 @@ namespace roundop
             break;
             COPY_TENSOR(u64)(arg0, out, count);
             break;
-            TYPE_CASE(f16)(arg0, out, count);
+            TYPE_CASE(f16)(arg0, out, count, mode);
             break;
-            TYPE_CASE(f32)(arg0, out, count);
+            TYPE_CASE(f32)(arg0, out, count, mode);
+            break;
+            TYPE_CASE(bf16)(arg0, out, count, mode);
             break;
         default: rc = false; break;
         }
@@ -98,7 +107,10 @@ namespace roundop
 bool op::v0::Round::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
 {
     OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v0::Round::evaluate");
-    return roundop::evaluate_round(inputs[0], outputs[0], shape_size(get_output_shape(0)));
+    return roundop::evaluate_round(inputs[0],
+                                   outputs[0],
+                                   shape_size(get_output_shape(0)),
+                                   op::v5::Round::RoundMode::HALF_TO_EVEN);
 }
 NGRAPH_SUPPRESS_DEPRECATED_END
 
@@ -129,6 +141,13 @@ shared_ptr<Node> op::v5::Round::clone_with_new_inputs(const OutputVector& new_ar
     return make_shared<v5::Round>(new_args.at(0), m_mode);
 }
 
+bool op::v5::Round::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
+{
+    OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v5::Round::evaluate");
+    return roundop::evaluate_round(
+        inputs[0], outputs[0], shape_size(get_output_shape(0)), get_mode());
+}
+
 namespace ngraph
 {
     template <>
index 40572c3..03f5043 100644 (file)
@@ -141,6 +141,7 @@ from ngraph.opset5 import reverse_sequence
 from ngraph.opset5 import rnn_cell
 from ngraph.opset5 import roi_align
 from ngraph.opset5 import roi_pooling
+from ngraph.opset5 import round
 from ngraph.opset5 import scatter_elements_update
 from ngraph.opset5 import scatter_update
 from ngraph.opset5 import select
index e2b4a83..8c115b0 100644 (file)
@@ -128,6 +128,7 @@ from ngraph.opset1.ops import reverse_sequence
 from ngraph.opset3.ops import rnn_cell
 from ngraph.opset3.ops import roi_align
 from ngraph.opset2.ops import roi_pooling
+from ngraph.opset5.ops import round
 from ngraph.opset3.ops import scatter_elements_update
 from ngraph.opset3.ops import scatter_update
 from ngraph.opset1.ops import select
index 8c1950e..0c84162 100644 (file)
@@ -90,3 +90,17 @@ def log_softmax(data: NodeInput, axis: int, name: Optional[str] = None) -> Node:
     :return: The new node with LogSoftmax operation applied on each element.
     """
     return _get_node_factory_opset5().create("LogSoftmax", [as_node(data)], {"axis": axis})
+
+
+@nameable_op
+def round(data: NodeInput, mode: str = "half_to_even", name: Optional[str] = None) -> Node:
+    """Apply Round operation on each element of input tensor.
+
+    :param data: The tensor providing input data.
+    :param mode: Rule to round halfway cases. If set to 'half_to_even' then halfs round to the nearest even
+        integer or rounding in such a way that the result heads away from zero if `mode` attribute is
+        'half_away_from_zero`.
+    :param name: An optional name of the output node.
+    :return: The new node with Round operation applied on each element.
+    """
+    return _get_node_factory_opset5().create("Round", as_nodes(data), {"mode": mode.upper()})
index 951bb7d..7594425 100644 (file)
@@ -141,3 +141,41 @@ def test_hswish():
     assert node.get_output_size() == 1
     assert list(node.get_output_shape(0)) == [3, 10]
     assert node.get_output_element_type(0) == Type.f32
+
+
+def test_round_even():
+    float_dtype = np.float32
+    data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data")
+
+    node = ng.round(data, "HALF_TO_EVEN")
+    assert node.get_type_name() == "Round"
+    assert node.get_output_size() == 1
+    assert list(node.get_output_shape(0)) == [3, 10]
+    assert node.get_output_element_type(0) == Type.f32
+
+    # Excluded because this part needs mklddn implementation of Round operation
+    # Need to uncomment and check when 37651 will be done.
+    # input_tensor = np.array([-2.5, -1.5, -0.5, 0.5, 0.9, 1.5, 2.3, 2.5, 3.5], dtype=np.float32)
+    # expected = [-2.0, -2.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0, 4.0]
+
+    # result = run_op_node([input_tensor], ng.round, "HALF_TO_EVEN")
+    # assert np.allclose(result, expected)
+
+
+def test_round_away():
+    float_dtype = np.float32
+    data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data")
+
+    node = ng.round(data, "HALF_AWAY_FROM_ZERO")
+    assert node.get_type_name() == "Round"
+    assert node.get_output_size() == 1
+    assert list(node.get_output_shape(0)) == [3, 10]
+    assert node.get_output_element_type(0) == Type.f32
+
+    # Excluded because this part needs mklddn implementation of Round operation
+    # Need to uncomment and check when 37651 will be done.
+    # input_tensor = np.array([-2.5, -1.5, -0.5, 0.5, 0.9, 1.5, 2.3, 2.5, 3.5], dtype=np.float32)
+    # expected = [-3.0, -2.0, -1.0, 1.0, 1.0, 2.0, 2.0, 3.0, 4.0]
+
+    # result = run_op_node([input_tensor], ng.round, "HALF_AWAY_FROM_ZERO")
+    # assert np.allclose(result, expected)
index 706690a..67cdbe3 100644 (file)
@@ -79,6 +79,7 @@ set(SRC
     op_eval/reduce_l1.cpp
     op_eval/reduce_l2.cpp
     op_eval/roi_align.cpp
+    op_eval/round.cpp
     op_eval/softplus.cpp
     op_eval/split.cpp
     op_eval/swish.cpp
@@ -166,6 +167,7 @@ set(SRC
     type_prop/round.cpp
     type_prop/rnn_cell.cpp
     type_prop/rnn_sequence.cpp
+    type_prop/round.cpp
     type_prop/scatter_elements_update.cpp
     type_prop/scatter_nd_update.cpp
     type_prop/scatter_update.cpp
diff --git a/ngraph/test/op_eval/round.cpp b/ngraph/test/op_eval/round.cpp
new file mode 100644 (file)
index 0000000..e9807aa
--- /dev/null
@@ -0,0 +1,67 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "ngraph/op/round.hpp"
+#include "ngraph/runtime/host_tensor.hpp"
+#include "ngraph/validation_util.hpp"
+#include "runtime/backend.hpp"
+#include "util/test_tools.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+TEST(op_eval, rounding_to_even)
+{
+    auto p = make_shared<op::Parameter>(element::f32, Shape{9});
+    auto round = make_shared<op::v5::Round>(p, op::v5::Round::RoundMode::HALF_TO_EVEN);
+    auto fun = make_shared<Function>(OutputVector{round}, ParameterVector{p});
+
+    std::vector<float> inputs{-2.5f, -1.5f, -0.5f, 0.5f, 0.9f, 1.5f, 2.3f, 2.5f, 3.5f};
+    std::vector<float> expected_result{-2.f, -2.f, -0.f, 0.f, 1.f, 2.f, 2.f, 2.f, 4.f};
+
+    auto result = make_shared<HostTensor>();
+    ASSERT_TRUE(
+        fun->evaluate({result}, {make_host_tensor<element::Type_t::f32>(Shape{9}, inputs)}));
+    EXPECT_EQ(result->get_element_type(), element::f32);
+    EXPECT_EQ(result->get_shape(), Shape{9});
+    auto result_data = read_vector<float>(result);
+    for (auto i = 0; i < inputs.size(); i++)
+        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
+}
+
+TEST(op_eval, rounding_away)
+{
+    auto p = make_shared<op::Parameter>(element::f32, Shape{9});
+    auto round = make_shared<op::v5::Round>(p, op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
+    auto fun = make_shared<Function>(OutputVector{round}, ParameterVector{p});
+
+    std::vector<float> inputs{-2.5f, -1.5f, -0.5f, 0.5f, 0.9f, 1.5f, 2.3f, 2.5f, 3.5f};
+    std::vector<float> expected_result{-3.f, -2.f, -1.f, 1.f, 1.f, 2.f, 2.f, 3.f, 4.f};
+
+    auto result = make_shared<HostTensor>();
+    ASSERT_TRUE(
+        fun->evaluate({result}, {make_host_tensor<element::Type_t::f32>(Shape{9}, inputs)}));
+    EXPECT_EQ(result->get_element_type(), element::f32);
+    EXPECT_EQ(result->get_shape(), Shape{9});
+    auto result_data = read_vector<float>(result);
+    for (auto i = 0; i < inputs.size(); i++)
+        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
+}
index bd5db5e..be993f6 100644 (file)
@@ -1258,8 +1258,10 @@ protected:
         case OP_TYPEID::Round:
         {
             size_t element_count = shape_size(node.get_output_shape(0));
-            reference::round<T>(
-                args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
+            reference::round<T>(args[0]->get_data_ptr<const T>(),
+                                out[0]->get_data_ptr<T>(),
+                                element_count,
+                                op::v5::Round::RoundMode::HALF_TO_EVEN);
             break;
         }
         case OP_TYPEID::Select:
@@ -1495,6 +1497,7 @@ protected:
         case OP_TYPEID::Range:
         case OP_TYPEID::Reshape:
         case OP_TYPEID::Result:
+        case OP_TYPEID::Round_v5:
         case OP_TYPEID::ShapeOf_v3:
         case OP_TYPEID::ShapeOf:
         case OP_TYPEID::Softmax:
index 61fa35d..f9e1ee4 100644 (file)
@@ -57,5 +57,6 @@ NGRAPH_OP(GatherND, op::v5)
 NGRAPH_OP(LSTMSequence, op::v5)
 NGRAPH_OP(GRUSequence, op::v5)
 NGRAPH_OP(RNNSequence, op::v5)
+NGRAPH_OP(Round, op::v5)
 NGRAPH_OP(LogSoftmax, op::v5)
 #undef ID_SUFFIX
index c7a16ec..dde3c7a 100644 (file)
@@ -41,51 +41,51 @@ TEST(type_prop, rounding_away)
 TEST(type_prop, rounding_to_even_partial)
 {
     auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
-    auto softplus_func = make_shared<op::v5::Round>(data, op::v5::Round::RoundMode::HALF_TO_EVEN);
-    EXPECT_EQ(softplus_func->get_element_type(), element::f32);
-    ASSERT_TRUE(softplus_func->get_output_partial_shape(0).same_scheme(
+    auto round_func = make_shared<op::v5::Round>(data, op::v5::Round::RoundMode::HALF_TO_EVEN);
+    EXPECT_EQ(round_func->get_element_type(), element::f32);
+    ASSERT_TRUE(round_func->get_output_partial_shape(0).same_scheme(
         (PartialShape{1, Dimension::dynamic(), 6})));
 
     // rank unknown
-    auto softplus_partial = make_shared<op::v5::Round>(
+    auto round_partial = make_shared<op::v5::Round>(
         make_shared<op::Parameter>(element::f32, PartialShape::dynamic()),
         op::v5::Round::RoundMode::HALF_TO_EVEN);
-    ASSERT_TRUE(softplus_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
+    ASSERT_TRUE(round_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
 }
 
 TEST(type_prop, rounding_away_partial)
 {
     auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
-    auto softplus_func =
+    auto round_func =
         make_shared<op::v5::Round>(data, op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
-    EXPECT_EQ(softplus_func->get_element_type(), element::f32);
-    ASSERT_TRUE(softplus_func->get_output_partial_shape(0).same_scheme(
+    EXPECT_EQ(round_func->get_element_type(), element::f32);
+    ASSERT_TRUE(round_func->get_output_partial_shape(0).same_scheme(
         (PartialShape{1, Dimension::dynamic(), 6})));
 
     // rank unknown
-    auto softplus_partial = make_shared<op::v5::Round>(
+    auto round_partial = make_shared<op::v5::Round>(
         make_shared<op::Parameter>(element::f32, PartialShape::dynamic()),
         op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
-    ASSERT_TRUE(softplus_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
+    ASSERT_TRUE(round_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
 }
 
 TEST(type_prop, rounding_to_even_partial_static_rank)
 {
     auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
-    auto softplus_func = make_shared<op::v5::Round>(data, op::v5::Round::RoundMode::HALF_TO_EVEN);
-    EXPECT_EQ(softplus_func->get_element_type(), element::f32);
-    ASSERT_TRUE(softplus_func->get_output_partial_shape(0).same_scheme(
+    auto round_func = make_shared<op::v5::Round>(data, op::v5::Round::RoundMode::HALF_TO_EVEN);
+    EXPECT_EQ(round_func->get_element_type(), element::f32);
+    ASSERT_TRUE(round_func->get_output_partial_shape(0).same_scheme(
         (PartialShape{1, Dimension::dynamic(), 6})));
-    ASSERT_TRUE(softplus_func->get_output_partial_shape(0).rank().is_static());
+    ASSERT_TRUE(round_func->get_output_partial_shape(0).rank().is_static());
 }
 
 TEST(type_prop, rounding_away_partial_static_rank)
 {
     auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
-    auto softplus_func =
+    auto round_func =
         make_shared<op::v5::Round>(data, op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
-    EXPECT_EQ(softplus_func->get_element_type(), element::f32);
-    ASSERT_TRUE(softplus_func->get_output_partial_shape(0).same_scheme(
+    EXPECT_EQ(round_func->get_element_type(), element::f32);
+    ASSERT_TRUE(round_func->get_output_partial_shape(0).same_scheme(
         (PartialShape{1, Dimension::dynamic(), 6})));
-    ASSERT_TRUE(softplus_func->get_output_partial_shape(0).rank().is_static());
+    ASSERT_TRUE(round_func->get_output_partial_shape(0).rank().is_static());
 }