[Py] Ngraph Py API TensorIterator (#718)
authorAdam Osewski <adam.osewski@intel.com>
Mon, 22 Jun 2020 09:40:58 +0000 (11:40 +0200)
committerGitHub <noreply@github.com>
Mon, 22 Jun 2020 09:40:58 +0000 (11:40 +0200)
19 files changed:
ngraph/python/setup.py
ngraph/python/src/ngraph/__init__.py
ngraph/python/src/ngraph/ops.py
ngraph/python/src/ngraph/utils/tensor_iterator_types.py [new file with mode: 0644]
ngraph/python/src/pyngraph/node_factory.cpp
ngraph/python/src/pyngraph/tensor_iterator_builder.cpp [new file with mode: 0644]
ngraph/python/src/pyngraph/tensor_iterator_builder.hpp [new file with mode: 0644]
ngraph/python/test/ngraph/test_create_op.py
ngraph/python/test/ngraph/test_sequence_processing.py
ngraph/src/ngraph/op/detection_output.cpp
ngraph/src/ngraph/op/detection_output.hpp
ngraph/src/ngraph/op/interpolate.cpp
ngraph/src/ngraph/op/interpolate.hpp
ngraph/src/ngraph/op/prior_box.cpp
ngraph/src/ngraph/op/prior_box.hpp
ngraph/src/ngraph/op/prior_box_clustered.cpp
ngraph/src/ngraph/op/prior_box_clustered.hpp
ngraph/src/ngraph/op/proposal.cpp
ngraph/src/ngraph/op/proposal.hpp

index 07f4e7b..f340abb 100644 (file)
@@ -211,6 +211,7 @@ sources = [
     "pyngraph/serializer.cpp",
     "pyngraph/shape.cpp",
     "pyngraph/strides.cpp",
+    "pyngraph/tensor_iterator_builder.cpp",
     "pyngraph/types/element_type.cpp",
     "pyngraph/types/regmodule_pyngraph_types.cpp",
     "pyngraph/util.cpp",
index d644fda..3f0a797 100644 (file)
@@ -149,6 +149,7 @@ from ngraph.ops import strided_slice
 from ngraph.ops import subtract
 from ngraph.ops import tan
 from ngraph.ops import tanh
+from ngraph.ops import tensor_iterator
 from ngraph.ops import tile
 from ngraph.ops import topk
 from ngraph.ops import transpose
index 8d5fe41..9067ebc 100644 (file)
@@ -29,6 +29,14 @@ from ngraph.utils.input_validation import (
     is_positive_value,
 )
 from ngraph.utils.node_factory import NodeFactory
+from ngraph.utils.tensor_iterator_types import (
+    GraphBody,
+    TensorIteratorSliceInputDesc,
+    TensorIteratorMergedInputDesc,
+    TensorIteratorInvariantInputDesc,
+    TensorIteratorBodyOutputDesc,
+    TensorIteratorConcatOutputDesc,
+)
 from ngraph.utils.types import (
     NodeInput,
     NumericData,
@@ -3441,6 +3449,52 @@ def proposal(
 
 
 @nameable_op
+def tensor_iterator(
+    inputs: List[Node],
+    graph_body: GraphBody,
+    slice_input_desc: List[TensorIteratorSliceInputDesc],
+    merged_input_desc: List[TensorIteratorMergedInputDesc],
+    invariant_input_desc: List[TensorIteratorInvariantInputDesc],
+    body_output_desc: List[TensorIteratorBodyOutputDesc],
+    concat_output_desc: List[TensorIteratorConcatOutputDesc],
+    name: Optional[str] = None,
+) -> Node:
+    """
+    Perform recurrent execution of the network described in the body, iterating through the data.
+
+    :param      inputs:                The provided to TensorIterator operator.
+    :param      graph_body:            The graph representing the body we execute.
+    :param      slice_input_desc:      The descriptors describing sliced inputs, that is nodes
+                                       representing tensors we iterate through, processing single
+                                       data slice in one iteration.
+    :param      merged_input_desc:     The descriptors describing merged inputs, that is nodes
+                                       representing variables with initial value at first iteration,
+                                       which may be changing through iterations.
+    :param      invariant_input_desc:  The descriptors describing invariant inputs, that is nodes
+                                       representing variable with persistent value through all
+                                       iterations.
+    :param      body_output_desc:      The descriptors describing body outputs from specified
+                                       iteration.
+    :param      concat_output_desc:    The descriptors describing specified output values through
+                                       all the iterations concatenated into one node.
+    :param      name:                  The optional name for output node.
+
+    :returns:   Node representing TensorIterator operation.
+    """
+
+    attributes = {
+        "body": graph_body.serialize(),
+        "slice_input_desc": [desc.serialize() for desc in slice_input_desc],
+        "merged_input_desc": [desc.serialize() for desc in merged_input_desc],
+        "invariant_input_desc": [desc.serialize() for desc in invariant_input_desc],
+        "body_output_desc": [desc.serialize() for desc in body_output_desc],
+        "concat_output_desc": [desc.serialize() for desc in concat_output_desc],
+    }
+
+    return _get_node_factory().create('TensorIterator', as_nodes(*inputs), attributes)
+
+
+@nameable_op
 def assign(new_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node:
     """Return a node which produces the Assign operation.
 
diff --git a/ngraph/python/src/ngraph/utils/tensor_iterator_types.py b/ngraph/python/src/ngraph/utils/tensor_iterator_types.py
new file mode 100644 (file)
index 0000000..cae8721
--- /dev/null
@@ -0,0 +1,159 @@
+# ******************************************************************************
+# Copyright 2017-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+"""Helper classes for aggregating TensorIterator input/output desciptor attributes."""
+
+from typing import List
+
+from ngraph.impl import Node
+from ngraph.impl.op import Parameter
+
+
+class GraphBody(object):
+    """Class containing graph parameters and results."""
+
+    def __init__(self, parameters: List[Parameter], results: List[Node],) -> None:
+        self.parameters = parameters
+        self.results = results
+
+    def serialize(self) -> dict:
+        return {
+            "parameters": self.parameters,
+            "results": self.results,
+        }
+
+
+class TensorIteratorInputDesc(object):
+    """Represents a generic input descriptor for TensorIterator operator."""
+
+    def __init__(self, input_idx: int, body_parameter_idx: int,) -> None:
+        self.input_idx = input_idx
+        self.body_parameter_idx = body_parameter_idx
+
+    def serialize(self) -> dict:
+        return {
+            "input_idx": self.input_idx,
+            "body_parameter_idx": self.body_parameter_idx,
+        }
+
+
+class TensorIteratorSliceInputDesc(TensorIteratorInputDesc):
+    """Represents a TensorIterator graph body input formed from slices of TensorIterator input."""
+
+    def __init__(
+        self,
+        input_idx: int,
+        body_parameter_idx: int,
+        start: int,
+        stride: int,
+        part_size: int,
+        end: int,
+        axis: int,
+    ) -> None:
+        super().__init__(input_idx, body_parameter_idx)
+        self.start = start
+        self.stride = stride
+        self.part_size = part_size
+        self.end = end
+        self.axis = axis
+
+    def serialize(self) -> dict:
+        output = super().serialize()
+        output["start"] = self.start
+        output["stride"] = self.stride
+        output["part_size"] = self.part_size
+        output["end"] = self.end
+        output["axis"] = self.axis
+        return output
+
+
+class TensorIteratorMergedInputDesc(TensorIteratorInputDesc):
+    """Represents a TensorIterator graph body input with initial value in the first iteration.
+
+    Later on, this input value is computed inside graph body.
+    """
+
+    def __init__(self, input_idx: int, body_parameter_idx: int, body_value_idx: int,) -> None:
+        super().__init__(input_idx, body_parameter_idx)
+        self.body_value_idx = body_value_idx
+
+    def serialize(self) -> dict:
+        output = super().serialize()
+        output["body_value_idx"] = self.body_value_idx
+        return output
+
+
+class TensorIteratorInvariantInputDesc(TensorIteratorInputDesc):
+    """Represents a TensorIterator graph body input that has invariant value during iteration."""
+
+    def __init__(self, input_idx: int, body_parameter_idx: int,) -> None:
+        super().__init__(input_idx, body_parameter_idx)
+
+
+class TensorIteratorOutputDesc(object):
+    """Represents a generic output descriptor for TensorIterator operator."""
+
+    def __init__(self, body_value_idx: int, output_idx: int,) -> None:
+        self.body_value_idx = body_value_idx
+        self.output_idx = output_idx
+
+    def serialize(self) -> dict:
+        return {
+            "body_value_idx": self.body_value_idx,
+            "output_idx": self.output_idx,
+        }
+
+
+class TensorIteratorBodyOutputDesc(TensorIteratorOutputDesc):
+    """Represents an output from a specific iteration."""
+
+    def __init__(self, body_value_idx: int, output_idx: int, iteration: int,) -> None:
+        super().__init__(body_value_idx, output_idx)
+        self.iteration = iteration
+
+    def serialize(self) -> dict:
+        output = super().serialize()
+        output["iteration"] = self.iteration
+        return output
+
+
+class TensorIteratorConcatOutputDesc(TensorIteratorOutputDesc):
+    """Represents an output produced by concatenation of output from each iteration."""
+
+    def __init__(
+        self,
+        body_value_idx: int,
+        output_idx: int,
+        start: int,
+        stride: int,
+        part_size: int,
+        end: int,
+        axis: int,
+    ) -> None:
+        super().__init__(body_value_idx, output_idx)
+        self.start = start
+        self.stride = stride
+        self.part_size = part_size
+        self.end = end
+        self.axis = axis
+
+    def serialize(self) -> dict:
+        output = super().serialize()
+        output["start"] = self.start
+        output["stride"] = self.stride
+        output["part_size"] = self.part_size
+        output["end"] = self.end
+        output["axis"] = self.axis
+        return output
index 09597a7..ea54e00 100644 (file)
@@ -35,6 +35,9 @@
 #include "ngraph/opsets/opset.hpp"
 #include "ngraph/util.hpp"
 #include "node_factory.hpp"
+#include "tensor_iterator_builder.hpp"
+
+namespace py = pybind11;
 
 namespace
 {
@@ -265,6 +268,13 @@ namespace
                          "Currently NodeFactory doesn't support Constant node: ",
                          op_type_name);
 
+            if (op_type_name == "TensorIterator")
+            {
+                // TODO: how to differentiate opsets?
+                return util::TensorIteratorBuilder(arguments, attributes)
+                    .configure(std::static_pointer_cast<ngraph::op::TensorIterator>(op_node));
+            }
+
             DictAttributeDeserializer visitor(attributes);
 
             op_node->set_arguments(arguments);
@@ -303,8 +313,6 @@ namespace
     };
 }
 
-namespace py = pybind11;
-
 void regclass_pyngraph_NodeFactory(py::module m)
 {
     py::class_<NodeFactory> node_factory(m, "NodeFactory");
diff --git a/ngraph/python/src/pyngraph/tensor_iterator_builder.cpp b/ngraph/python/src/pyngraph/tensor_iterator_builder.cpp
new file mode 100644 (file)
index 0000000..66f5ee4
--- /dev/null
@@ -0,0 +1,227 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include <string>
+
+#include "ngraph/check.hpp"
+#include "ngraph/except.hpp"
+#include "tensor_iterator_builder.hpp"
+
+util::TensorIteratorBuilder::TensorIteratorBuilder(const ngraph::NodeVector& arguments,
+                                                   const py::dict& attributes)
+    : m_arguments(arguments)
+    , m_attributes(attributes)
+{
+    get_graph_body();
+    // Set-up TI inputs.
+    NGRAPH_CHECK(m_attributes.contains("slice_input_desc"),
+                 "The required \"slice_input_desc\" attribute is missing. Can't build "
+                 "TensorIterator operator.");
+    m_slice_input_desc = m_attributes["slice_input_desc"].cast<py::list>();
+
+    if (m_attributes.contains("merged_input_desc"))
+    {
+        m_merged_input_desc = m_attributes["merged_input_desc"].cast<py::list>();
+    }
+
+    if (m_attributes.contains("invariant_input_desc"))
+    {
+        m_invariant_input_desc = m_attributes["invariant_input_desc"].cast<py::list>();
+    }
+
+    if (m_attributes.contains("body_output_desc"))
+    {
+        py::list body_output_desc = m_attributes["body_output_desc"].cast<py::list>();
+        for (py::handle h : body_output_desc)
+        {
+            py::dict desc = h.cast<py::dict>();
+            desc["type"] = "BodyOutputDesc";
+            check_attribute(desc, "output_idx", "BodyOutputDesc");
+            m_outputs.emplace(desc["output_idx"].cast<int64_t>(), desc);
+        }
+    }
+    if (m_attributes.contains("concat_output_desc"))
+    {
+        py::list concat_output_desc = m_attributes["concat_output_desc"].cast<py::list>();
+        for (py::handle h : concat_output_desc)
+        {
+            py::dict desc = h.cast<py::dict>();
+            desc["type"] = "ConcatOutputDesc";
+            check_attribute(desc, "output_idx", "ConcatOutputDesc");
+            m_outputs.emplace(desc["output_idx"].cast<int64_t>(), desc);
+        }
+    }
+}
+
+std::shared_ptr<ngraph::op::TensorIterator>
+    util::TensorIteratorBuilder::configure(std::shared_ptr<ngraph::op::TensorIterator>&& ti_node)
+{
+    ti_node->set_body(m_body);
+    set_tensor_iterator_sliced_inputs(ti_node);
+    set_tensor_iterator_merged_inputs(ti_node);
+    set_tensor_iterator_invariant_inputs(ti_node);
+    set_tensor_iterator_outputs(ti_node);
+    ti_node->constructor_validate_and_infer_types();
+
+    return ti_node;
+}
+
+void util::TensorIteratorBuilder::check_attribute(const py::dict& attrs,
+                                                  std::string attr_name,
+                                                  std::string desc_name) const
+{
+    NGRAPH_CHECK(attrs.contains(attr_name),
+                 "The required \"",
+                 attr_name,
+                 "\" attribute is missing. Can't build TensorIterator's ",
+                 desc_name,
+                 ".");
+}
+
+void util::TensorIteratorBuilder::get_graph_body()
+{
+    NGRAPH_CHECK(m_attributes.contains("body"),
+                 "The required \"body\" attribute is missing. Can't build TensorIterator "
+                 "operator.");
+
+    const py::dict& body_attrs = m_attributes["body"].cast<py::dict>();
+
+    NGRAPH_CHECK(body_attrs.contains("parameters"),
+                 "The required body's \"parameters\" "
+                 "attribute is missing. Can't build TensorIterator's body.");
+    NGRAPH_CHECK(body_attrs.contains("results"),
+                 "The required body's \"results\" "
+                 "attribute is missing. Can't build TensorIterator's body.");
+
+    m_body_outputs = as_output_vector(body_attrs["results"].cast<ngraph::NodeVector>());
+    m_body_parameters = body_attrs["parameters"].cast<ngraph::ParameterVector>();
+    m_body =
+        std::make_shared<ngraph::op::TensorIterator::BodyLambda>(m_body_outputs, m_body_parameters);
+}
+
+void util::TensorIteratorBuilder::set_tensor_iterator_sliced_inputs(
+    std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const
+{
+    for (py::handle h : m_slice_input_desc)
+    {
+        const py::dict& desc = h.cast<py::dict>();
+        check_attribute(desc, "input_idx", "SliceInputDesc");
+        check_attribute(desc, "body_parameter_idx", "SliceInputDesc");
+        check_attribute(desc, "start", "SliceInputDesc");
+        check_attribute(desc, "stride", "SliceInputDesc");
+        check_attribute(desc, "part_size", "SliceInputDesc");
+        check_attribute(desc, "end", "SliceInputDesc");
+        check_attribute(desc, "axis", "SliceInputDesc");
+
+        ti_node->set_sliced_input(m_body_parameters.at(desc["body_parameter_idx"].cast<int64_t>()),
+                                  m_arguments.at(desc["input_idx"].cast<int64_t>()),
+                                  desc["start"].cast<int64_t>(),
+                                  desc["stride"].cast<int64_t>(),
+                                  desc["part_size"].cast<int64_t>(),
+                                  desc["end"].cast<int64_t>(),
+                                  desc["axis"].cast<int64_t>());
+    }
+}
+
+void util::TensorIteratorBuilder::set_tensor_iterator_merged_inputs(
+    std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const
+{
+    for (py::handle h : m_merged_input_desc)
+    {
+        const py::dict& desc = h.cast<py::dict>();
+        check_attribute(desc, "input_idx", "MergedInputDesc");
+        check_attribute(desc, "body_parameter_idx", "MergedInputDesc");
+        check_attribute(desc, "body_value_idx", "MergedInputDesc");
+
+        ti_node->set_merged_input(m_body_parameters.at(desc["body_parameter_idx"].cast<int64_t>()),
+                                  m_arguments.at(desc["input_idx"].cast<int64_t>()),
+                                  m_body_outputs.at(desc["body_value_idx"].cast<int64_t>()));
+    }
+}
+
+void util::TensorIteratorBuilder::set_tensor_iterator_invariant_inputs(
+    std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const
+{
+    for (py::handle h : m_invariant_input_desc)
+    {
+        const py::dict& desc = h.cast<py::dict>();
+        check_attribute(desc, "input_idx", "InvariantInputDesc");
+        check_attribute(desc, "body_parameter_idx", "InvariantInputDesc");
+
+        ti_node->set_invariant_input(
+            m_body_parameters.at(desc["body_parameter_idx"].cast<int64_t>()),
+            m_arguments.at(desc["input_idx"].cast<int64_t>()));
+    }
+}
+
+void util::TensorIteratorBuilder::set_tensor_iterator_outputs(
+    std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const
+{
+    for (const auto& elem : m_outputs)
+    {
+        const py::dict& desc = elem.second.cast<py::dict>();
+        if (desc["type"].cast<std::string>() == "BodyOutputDesc")
+        {
+            set_tensor_iterator_body_output(desc, ti_node);
+        }
+        else if (desc["type"].cast<std::string>() == "ConcatOutputDesc")
+        {
+            set_tensor_iterator_concatenated_body_output(desc, ti_node);
+        }
+        else
+        {
+            throw ngraph::ngraph_error("Unrecognized TensorIterator output type.");
+        }
+    }
+}
+
+void util::TensorIteratorBuilder::set_tensor_iterator_body_output(
+    const py::dict& desc, std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const
+{
+    check_attribute(desc, "body_value_idx", "BodyOutputDesc");
+    check_attribute(desc, "iteration", "BodyOutputDesc");
+
+    NGRAPH_CHECK(desc["output_idx"].cast<size_t>() == ti_node->get_output_size(),
+                 "Descriptor output idx value is different from currently configured "
+                 "TensorIterator output.");
+
+    ti_node->get_iter_value(m_body_outputs.at(desc["body_value_idx"].cast<int64_t>()),
+                            desc["iteration"].cast<int64_t>());
+}
+
+void util::TensorIteratorBuilder::set_tensor_iterator_concatenated_body_output(
+    const py::dict& desc, std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const
+{
+    check_attribute(desc, "body_value_idx", "ConcatOutputDesc");
+    check_attribute(desc, "start", "ConcatOutputDesc");
+    check_attribute(desc, "stride", "ConcatOutputDesc");
+    check_attribute(desc, "part_size", "ConcatOutputDesc");
+    check_attribute(desc, "end", "ConcatOutputDesc");
+    check_attribute(desc, "axis", "ConcatOutputDesc");
+
+    NGRAPH_CHECK(desc["output_idx"].cast<size_t>() == ti_node->get_output_size(),
+                 "Descriptor output idx value is different from currently configured "
+                 "TensorIterator output.");
+
+    ti_node->get_concatenated_slices(m_body_outputs.at(desc["body_value_idx"].cast<int64_t>()),
+                                     desc["start"].cast<int64_t>(),
+                                     desc["stride"].cast<int64_t>(),
+                                     desc["part_size"].cast<int64_t>(),
+                                     desc["end"].cast<int64_t>(),
+                                     desc["axis"].cast<int64_t>());
+}
diff --git a/ngraph/python/src/pyngraph/tensor_iterator_builder.hpp b/ngraph/python/src/pyngraph/tensor_iterator_builder.hpp
new file mode 100644 (file)
index 0000000..dc15a3e
--- /dev/null
@@ -0,0 +1,135 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include <cctype>
+#include <map>
+#include <memory>
+
+#include <pybind11/numpy.h>
+#include <pybind11/stl.h>
+
+#include "ngraph/node.hpp"
+#include "ngraph/op/parameter.hpp"
+#include "ngraph/op/tensor_iterator.hpp"
+
+namespace py = pybind11;
+
+namespace util
+{
+    class TensorIteratorBuilder
+    {
+    public:
+        ///
+        /// \brief      Initialize TensorIterator node builder.
+        ///
+        /// \param[in]  arguments   The arguments passed to TensorIterator node.
+        /// \param[in]  attributes  The TensorIterator's attributes. This
+        ///                         py::dict contains all descriptors for
+        ///                         plethora of TensorIterator available inputs
+        ///                         and outputs.
+        ///
+        TensorIteratorBuilder(const ngraph::NodeVector& arguments, const py::dict& attributes);
+
+        ///
+        /// \brief      Configure instance of TensorIterator node with set-up parameters.
+        ///
+        /// \param      ti_node  The TensorIterator node instance to configure.
+        ///
+        /// \return     TensorIterator node.
+        ///
+        std::shared_ptr<ngraph::op::TensorIterator>
+            configure(std::shared_ptr<ngraph::op::TensorIterator>&& ti_node);
+
+    private:
+        ///
+        /// \brief      Helper to conduct attribute presence.
+        ///
+        /// \param[in]  attrs      The attributes
+        /// \param[in]  attr_name  The attribute name
+        /// \param[in]  desc_name  The description name
+        ///
+        inline void check_attribute(const py::dict& attrs,
+                                    std::string attr_name,
+                                    std::string desc_name) const;
+
+        ///
+        /// \brief      Retrieve the TI graph body.
+        ///
+        void get_graph_body();
+
+        ///
+        /// \brief      Sets the tensor iterator sliced inputs.
+        ///
+        /// \param      ti_node  The TI node we will set input to.
+        ///
+        void set_tensor_iterator_sliced_inputs(
+            std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const;
+
+        ///
+        /// \brief      Sets the tensor iterator merged inputs.
+        ///
+        /// \param      ti_node  The TI node we will set inputs to.
+        ///
+        void set_tensor_iterator_merged_inputs(
+            std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const;
+
+        ///
+        /// \brief      Sets the tensor iterator invariant inputs.
+        ///
+        /// \param      ti_node  The TI node we will set inputs to.
+        ///
+        void set_tensor_iterator_invariant_inputs(
+            std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const;
+
+        ///
+        /// \brief      Sets the tensor iterator outputs.
+        ///
+        /// \param      ti_node  The TI node we will set outputs to.
+        ///
+        void
+            set_tensor_iterator_outputs(std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const;
+
+        ///
+        /// \brief      Sets the tensor iterator body output.
+        ///
+        /// \param[in]  desc     The descriptor of the TI body output.
+        /// \param      ti_node  The TI node we will set output to.
+        ///
+        void set_tensor_iterator_body_output(
+            const py::dict& desc, std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const;
+
+        ///
+        /// \brief      Sets the tensor iterator concatenated body output.
+        ///
+        /// \param[in]  desc     The descriptor of the TI body output.
+        /// \param      ti_node  The TI node we will set output to.
+        ///
+        void set_tensor_iterator_concatenated_body_output(
+            const py::dict& desc, std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const;
+
+        const ngraph::NodeVector& m_arguments;
+        const py::dict& m_attributes;
+        ngraph::OutputVector m_body_outputs;
+        ngraph::ParameterVector m_body_parameters;
+        std::shared_ptr<ngraph::op::TensorIterator::BodyLambda> m_body;
+        py::list m_slice_input_desc;
+        py::list m_merged_input_desc;
+        py::list m_invariant_input_desc;
+        std::map<int64_t, const py::dict> m_outputs;
+    };
+} // namespace util
index 9b041d8..c815eaf 100644 (file)
@@ -19,7 +19,6 @@ import pytest
 import ngraph as ng
 from ngraph.impl import Type
 from _pyngraph import PartialShape
-import test
 
 np_types = [np.float32, np.int32]
 integral_np_types = [
@@ -537,18 +536,6 @@ def test_convert_like():
     assert node.get_output_element_type(0) == Type.i8
 
 
-def test_one_hot():
-    data = np.array([0, 1, 2], dtype=np.int32)
-    depth = 2
-    on_value = 5
-    off_value = 10
-    axis = -1
-    excepted = [[5, 10], [10, 5], [10, 10]]
-
-    result = test.ngraph.util.run_op_node([data, depth, on_value, off_value], ng.ops.one_hot, axis)
-    assert np.allclose(result, excepted)
-
-
 def test_reverse():
     parameter_data = ng.parameter([3, 10, 100, 200], name="data", dtype=np.float32)
     parameter_axis = ng.parameter([1], name="axis", dtype=np.int64)
@@ -562,23 +549,6 @@ def test_reverse():
     assert node.get_output_element_type(0) == Type.f32
 
 
-def test_select():
-    cond = [[False, False], [True, False], [True, True]]
-    then_node = [[-1, 0], [1, 2], [3, 4]]
-    else_node = [[11, 10], [9, 8], [7, 6]]
-    excepted = [[11, 10], [1, 8], [3, 4]]
-
-    result = test.ngraph.util.run_op_node([cond, then_node, else_node], ng.ops.select)
-    assert np.allclose(result, excepted)
-
-
-def test_result():
-    node = [[11, 10], [1, 8], [3, 4]]
-
-    result = test.ngraph.util.run_op_node([node], ng.ops.result)
-    assert np.allclose(result, node)
-
-
 def test_bucketize():
     data = ng.parameter([4, 3, 2, 1], name="data", dtype=np.float32)
     buckets = ng.parameter([5], name="buckets", dtype=np.int64)
@@ -591,15 +561,6 @@ def test_bucketize():
     assert node.get_output_element_type(0) == Type.i32
 
 
-def test_range():
-    start = 5
-    stop = 35
-    step = 5
-
-    result = test.ngraph.util.run_op_node([start, stop, step], ng.ops.range)
-    assert np.allclose(result, [5, 10, 15, 20, 25, 30])
-
-
 def test_region_yolo():
     data = ng.parameter([1, 125, 13, 13], name="input", dtype=np.float32)
     num_coords = 4
@@ -847,6 +808,82 @@ def test_proposal(int_dtype, fp_dtype):
     assert list(node.get_output_shape(0)) == [batch_size * attributes["attrs.post_nms_topn"], 5]
 
 
+def test_tensor_iterator():
+    from ngraph.utils.tensor_iterator_types import (
+        GraphBody,
+        TensorIteratorSliceInputDesc,
+        TensorIteratorMergedInputDesc,
+        TensorIteratorInvariantInputDesc,
+        TensorIteratorBodyOutputDesc,
+        TensorIteratorConcatOutputDesc,
+    )
+
+    #  Body parameters
+    body_timestep = ng.parameter([], np.int32, "timestep")
+    body_data_in = ng.parameter([1, 2, 2], np.float32, "body_in")
+    body_prev_cma = ng.parameter([2, 2], np.float32, "body_prev_cma")
+    body_const_one = ng.parameter([], np.int32, "body_const_one")
+
+    # CMA = cumulative moving average
+    prev_cum_sum = ng.multiply(ng.convert(body_timestep, "f32"), body_prev_cma)
+    curr_cum_sum = ng.add(prev_cum_sum, ng.squeeze(body_data_in, [0]))
+    elem_cnt = ng.add(body_const_one, body_timestep)
+    curr_cma = ng.divide(curr_cum_sum, ng.convert(elem_cnt, "f32"))
+    cma_hist = ng.unsqueeze(curr_cma, [0])
+
+    # TI inputs
+    data = ng.parameter([16, 2, 2], np.float32, "data")
+    # Iterations count
+    zero = ng.constant(0, dtype=np.int32)
+    one = ng.constant(1, dtype=np.int32)
+    initial_cma = ng.constant(np.zeros([2, 2], dtype=np.float32), dtype=np.float32)
+    iter_cnt = ng.ops.range(zero, np.int32(16), np.int32(1))
+    ti_inputs = [iter_cnt, data, initial_cma, one]
+
+    graph_body = GraphBody([body_timestep, body_data_in, body_prev_cma, body_const_one],
+                           [curr_cma, cma_hist])
+    ti_slice_input_desc = [
+        # timestep
+        # input_idx, body_param_idx, start, stride, part_size, end, axis
+        TensorIteratorSliceInputDesc(0, 0, 0, 1, 1, -1, 0),
+        # data
+        TensorIteratorSliceInputDesc(1, 1, 0, 1, 1, -1, 0),
+    ]
+    ti_merged_input_desc = [
+        # body prev/curr_cma
+        TensorIteratorMergedInputDesc(2, 2, 0),
+    ]
+    ti_invariant_input_desc = [
+        # body const one
+        TensorIteratorInvariantInputDesc(3, 3),
+    ]
+
+    # TI outputs
+    ti_body_output_desc = [
+        # final average
+        TensorIteratorBodyOutputDesc(0, 0, -1),
+    ]
+    ti_concat_output_desc = [
+        # history of cma
+        TensorIteratorConcatOutputDesc(1, 1, 0, 1, 1, -1, 0),
+    ]
+
+    node = ng.tensor_iterator(ti_inputs,
+                              graph_body,
+                              ti_slice_input_desc,
+                              ti_merged_input_desc,
+                              ti_invariant_input_desc,
+                              ti_body_output_desc,
+                              ti_concat_output_desc)
+
+    assert node.get_type_name() == "TensorIterator"
+    assert node.get_output_size() == 2
+    # final average
+    assert list(node.get_output_shape(0)) == [2, 2]
+    # cma history
+    assert list(node.get_output_shape(1)) == [16, 2, 2]
+
+
 def test_read_value():
     init_value = ng.parameter([2, 2], name="init_value", dtype=np.int32)
 
index 8b3b015..3b45193 100644 (file)
@@ -41,3 +41,12 @@ def test_one_hot():
 
     result = run_op_node([data, depth, on_value, off_value], ng.one_hot, axis)
     assert np.allclose(result, excepted)
+
+
+def test_range():
+    start = 5
+    stop = 35
+    step = 5
+
+    result = run_op_node([start, stop, step], ng.ops.range)
+    assert np.allclose(result, [5, 10, 15, 20, 25, 30])
index d536e21..22404a2 100644 (file)
@@ -84,21 +84,34 @@ shared_ptr<Node> op::DetectionOutput::clone_with_new_inputs(const OutputVector&
 
 bool op::DetectionOutput::visit_attributes(AttributeVisitor& visitor)
 {
-    visitor.on_attribute("attrs.num_classes", m_attrs.num_classes);
-    visitor.on_attribute("attrs.background_label_id", m_attrs.background_label_id);
-    visitor.on_attribute("attrs.top_k", m_attrs.top_k);
-    visitor.on_attribute("attrs.variance_encoded_in_target", m_attrs.variance_encoded_in_target);
-    visitor.on_attribute("attrs.keep_top_k", m_attrs.keep_top_k);
-    visitor.on_attribute("attrs.code_type", m_attrs.code_type);
-    visitor.on_attribute("attrs.share_location", m_attrs.share_location);
-    visitor.on_attribute("attrs.nms_threshold", m_attrs.nms_threshold);
-    visitor.on_attribute("attrs.confidence_threshold", m_attrs.confidence_threshold);
-    visitor.on_attribute("attrs.clip_after_nms", m_attrs.clip_after_nms);
-    visitor.on_attribute("attrs.clip_before_nms", m_attrs.clip_before_nms);
-    visitor.on_attribute("attrs.decrease_label_id", m_attrs.decrease_label_id);
-    visitor.on_attribute("attrs.normalized", m_attrs.normalized);
-    visitor.on_attribute("attrs.input_height", m_attrs.input_height);
-    visitor.on_attribute("attrs.input_width", m_attrs.input_width);
-    visitor.on_attribute("attrs.objectness_score", m_attrs.objectness_score);
+    visitor.on_attribute("attrs", m_attrs);
+    return true;
+}
+
+constexpr DiscreteTypeInfo AttributeAdapter<op::DetectionOutputAttrs>::type_info;
+
+AttributeAdapter<op::DetectionOutputAttrs>::AttributeAdapter(op::DetectionOutputAttrs& ref)
+    : m_ref(ref)
+{
+}
+
+bool AttributeAdapter<op::DetectionOutputAttrs>::visit_attributes(AttributeVisitor& visitor)
+{
+    visitor.on_attribute("num_classes", m_ref.num_classes);
+    visitor.on_attribute("background_label_id", m_ref.background_label_id);
+    visitor.on_attribute("top_k", m_ref.top_k);
+    visitor.on_attribute("variance_encoded_in_target", m_ref.variance_encoded_in_target);
+    visitor.on_attribute("keep_top_k", m_ref.keep_top_k);
+    visitor.on_attribute("code_type", m_ref.code_type);
+    visitor.on_attribute("share_location", m_ref.share_location);
+    visitor.on_attribute("nms_threshold", m_ref.nms_threshold);
+    visitor.on_attribute("confidence_threshold", m_ref.confidence_threshold);
+    visitor.on_attribute("clip_after_nms", m_ref.clip_after_nms);
+    visitor.on_attribute("clip_before_nms", m_ref.clip_before_nms);
+    visitor.on_attribute("decrease_label_id", m_ref.decrease_label_id);
+    visitor.on_attribute("normalized", m_ref.normalized);
+    visitor.on_attribute("input_height", m_ref.input_height);
+    visitor.on_attribute("input_width", m_ref.input_width);
+    visitor.on_attribute("objectness_score", m_ref.objectness_score);
     return true;
 }
index 54e9c4d..5b062a5 100644 (file)
@@ -92,4 +92,18 @@ namespace ngraph
         }
         using v0::DetectionOutput;
     }
+
+    template <>
+    class NGRAPH_API AttributeAdapter<op::DetectionOutputAttrs> : public VisitorAdapter
+    {
+    public:
+        AttributeAdapter(op::DetectionOutputAttrs& ref);
+
+        virtual bool visit_attributes(AttributeVisitor& visitor) override;
+        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::DetectionOutputAttrs>",
+                                                    0};
+        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
+    protected:
+        op::DetectionOutputAttrs& m_ref;
+    };
 }
index df7866f..a8e30b2 100644 (file)
@@ -33,12 +33,7 @@ op::v0::Interpolate::Interpolate(const Output<Node>& image,
 
 bool op::v0::Interpolate::visit_attributes(AttributeVisitor& visitor)
 {
-    visitor.on_attribute("attrs.axes", m_attrs.axes);
-    visitor.on_attribute("attrs.mode", m_attrs.mode);
-    visitor.on_attribute("attrs.align_corners", m_attrs.align_corners);
-    visitor.on_attribute("attrs.antialias", m_attrs.antialias);
-    visitor.on_attribute("attrs.pads_begin", m_attrs.pads_begin);
-    visitor.on_attribute("attrs.pads_end", m_attrs.pads_end);
+    visitor.on_attribute("attrs", m_attrs);
     return true;
 }
 
@@ -98,6 +93,24 @@ namespace ngraph
     {
         return s << as_string(type);
     }
+
+    constexpr DiscreteTypeInfo AttributeAdapter<op::v0::InterpolateAttrs>::type_info;
+
+    AttributeAdapter<op::v0::InterpolateAttrs>::AttributeAdapter(op::v0::InterpolateAttrs& ref)
+        : m_ref(ref)
+    {
+    }
+
+    bool AttributeAdapter<op::v0::InterpolateAttrs>::visit_attributes(AttributeVisitor& visitor)
+    {
+        visitor.on_attribute("axes", m_ref.axes);
+        visitor.on_attribute("mode", m_ref.mode);
+        visitor.on_attribute("align_corners", m_ref.align_corners);
+        visitor.on_attribute("antialias", m_ref.antialias);
+        visitor.on_attribute("pads_begin", m_ref.pads_begin);
+        visitor.on_attribute("pads_end", m_ref.pads_end);
+        return true;
+    }
 }
 
 // Interpolate v3
@@ -115,15 +128,7 @@ op::v3::Interpolate::Interpolate(const Output<Node>& image,
 
 bool op::v3::Interpolate::visit_attributes(AttributeVisitor& visitor)
 {
-    visitor.on_attribute("attrs.axes", m_attrs.axes);
-    visitor.on_attribute("attrs.mode", m_attrs.mode);
-    visitor.on_attribute("attrs.coordinate_transformation_mode",
-                         m_attrs.coordinate_transformation_mode);
-    visitor.on_attribute("attrs.nearest_mode", m_attrs.nearest_mode);
-    visitor.on_attribute("attrs.antialias", m_attrs.antialias);
-    visitor.on_attribute("attrs.pads_begin", m_attrs.pads_begin);
-    visitor.on_attribute("attrs.pads_end", m_attrs.pads_end);
-    visitor.on_attribute("attrs.cube_coeff", m_attrs.cube_coeff);
+    visitor.on_attribute("attrs", m_attrs);
     return true;
 }
 
@@ -229,4 +234,27 @@ namespace ngraph
     {
         return s << as_string(type);
     }
+
+    constexpr DiscreteTypeInfo AttributeAdapter<op::v3::Interpolate::InterpolateAttrs>::type_info;
+
+    AttributeAdapter<op::v3::Interpolate::InterpolateAttrs>::AttributeAdapter(
+        op::v3::Interpolate::InterpolateAttrs& ref)
+        : m_ref(ref)
+    {
+    }
+
+    bool AttributeAdapter<op::v3::Interpolate::InterpolateAttrs>::visit_attributes(
+        AttributeVisitor& visitor)
+    {
+        visitor.on_attribute("axes", m_ref.axes);
+        visitor.on_attribute("mode", m_ref.mode);
+        visitor.on_attribute("coordinate_transformation_mode",
+                             m_ref.coordinate_transformation_mode);
+        visitor.on_attribute("nearest_mode", m_ref.nearest_mode);
+        visitor.on_attribute("antialias", m_ref.antialias);
+        visitor.on_attribute("pads_begin", m_ref.pads_begin);
+        visitor.on_attribute("pads_end", m_ref.pads_end);
+        visitor.on_attribute("cube_coeff", m_ref.cube_coeff);
+        return true;
+    }
 }
index 4433feb..f83740f 100644 (file)
@@ -16,6 +16,7 @@
 
 #pragma once
 
+#include "ngraph/attribute_adapter.hpp"
 #include "ngraph/op/op.hpp"
 #include "ngraph/op/util/attr_types.hpp"
 
@@ -176,6 +177,22 @@ namespace ngraph
         using v0::Interpolate;
     }
 
+    //---------------------------------------- v0 --------------------------------------------------
+
+    template <>
+    class NGRAPH_API AttributeAdapter<op::v0::InterpolateAttrs> : public VisitorAdapter
+    {
+    public:
+        AttributeAdapter(op::v0::InterpolateAttrs& ref);
+
+        virtual bool visit_attributes(AttributeVisitor& visitor) override;
+        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v0::InterpolateAttrs>",
+                                                    0};
+        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
+    protected:
+        op::v0::InterpolateAttrs& m_ref;
+    };
+
     NGRAPH_API
     std::ostream& operator<<(std::ostream& s, const op::v0::Interpolate::InterpolateMode& type);
 
@@ -194,6 +211,8 @@ namespace ngraph
         const DiscreteTypeInfo& get_type_info() const override { return type_info; }
     };
 
+    //---------------------------------------- v3 --------------------------------------------------
+
     NGRAPH_API
     std::ostream& operator<<(std::ostream& s, const op::v3::Interpolate::InterpolateMode& type);
 
@@ -248,4 +267,18 @@ namespace ngraph
             "AttributeAdapter<op::v3::Interpolate::NearestMode>", 3};
         const DiscreteTypeInfo& get_type_info() const override { return type_info; }
     };
+
+    template <>
+    class NGRAPH_API AttributeAdapter<op::v3::Interpolate::InterpolateAttrs> : public VisitorAdapter
+    {
+    public:
+        AttributeAdapter(op::v3::Interpolate::InterpolateAttrs& ref);
+
+        virtual bool visit_attributes(AttributeVisitor& visitor) override;
+        static constexpr DiscreteTypeInfo type_info{
+            "AttributeAdapter<op::v3::Interpolate::InterpolateAttrs>", 3};
+        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
+    protected:
+        op::v3::Interpolate::InterpolateAttrs& m_ref;
+    };
 }
index 3e58a78..ff78d51 100644 (file)
@@ -132,18 +132,31 @@ std::vector<float> op::PriorBox::normalized_aspect_ratio(const std::vector<float
 
 bool op::PriorBox::visit_attributes(AttributeVisitor& visitor)
 {
-    visitor.on_attribute("attrs.min_size", m_attrs.min_size);
-    visitor.on_attribute("attrs.max_size", m_attrs.max_size);
-    visitor.on_attribute("attrs.aspect_ratio", m_attrs.aspect_ratio);
-    visitor.on_attribute("attrs.density", m_attrs.density);
-    visitor.on_attribute("attrs.fixed_ratio", m_attrs.fixed_ratio);
-    visitor.on_attribute("attrs.fixed_size", m_attrs.fixed_size);
-    visitor.on_attribute("attrs.clip", m_attrs.clip);
-    visitor.on_attribute("attrs.flip", m_attrs.flip);
-    visitor.on_attribute("attrs.step", m_attrs.step);
-    visitor.on_attribute("attrs.offset", m_attrs.offset);
-    visitor.on_attribute("attrs.variance", m_attrs.variance);
-    visitor.on_attribute("attrs.scale_all_sizes", m_attrs.scale_all_sizes);
+    visitor.on_attribute("attrs", m_attrs);
+    return true;
+}
+
+constexpr DiscreteTypeInfo AttributeAdapter<op::PriorBoxAttrs>::type_info;
+
+AttributeAdapter<op::PriorBoxAttrs>::AttributeAdapter(op::PriorBoxAttrs& ref)
+    : m_ref(ref)
+{
+}
+
+bool AttributeAdapter<op::PriorBoxAttrs>::visit_attributes(AttributeVisitor& visitor)
+{
+    visitor.on_attribute("min_size", m_ref.min_size);
+    visitor.on_attribute("max_size", m_ref.max_size);
+    visitor.on_attribute("aspect_ratio", m_ref.aspect_ratio);
+    visitor.on_attribute("density", m_ref.density);
+    visitor.on_attribute("fixed_ratio", m_ref.fixed_ratio);
+    visitor.on_attribute("fixed_size", m_ref.fixed_size);
+    visitor.on_attribute("clip", m_ref.clip);
+    visitor.on_attribute("flip", m_ref.flip);
+    visitor.on_attribute("step", m_ref.step);
+    visitor.on_attribute("offset", m_ref.offset);
+    visitor.on_attribute("variance", m_ref.variance);
+    visitor.on_attribute("scale_all_sizes", m_ref.scale_all_sizes);
     return true;
 }
 
index 1de5c24..fc9d325 100644 (file)
@@ -85,4 +85,17 @@ namespace ngraph
         }
         using v0::PriorBox;
     }
+
+    template <>
+    class NGRAPH_API AttributeAdapter<op::PriorBoxAttrs> : public VisitorAdapter
+    {
+    public:
+        AttributeAdapter(op::PriorBoxAttrs& ref);
+
+        virtual bool visit_attributes(AttributeVisitor& visitor) override;
+        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::PriorBoxAttrs>", 0};
+        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
+    protected:
+        op::PriorBoxAttrs& m_ref;
+    };
 }
index 21c02af..3fbb7f9 100644 (file)
@@ -94,13 +94,26 @@ shared_ptr<Node> op::PriorBoxClustered::clone_with_new_inputs(const OutputVector
 
 bool op::PriorBoxClustered::visit_attributes(AttributeVisitor& visitor)
 {
-    visitor.on_attribute("attrs.widths", m_attrs.widths);
-    visitor.on_attribute("attrs.heights", m_attrs.heights);
-    visitor.on_attribute("attrs.clip", m_attrs.clip);
-    visitor.on_attribute("attrs.step_widths", m_attrs.step_widths);
-    visitor.on_attribute("attrs.step_heights", m_attrs.step_heights);
-    visitor.on_attribute("attrs.offset", m_attrs.offset);
-    visitor.on_attribute("attrs.variances", m_attrs.variances);
+    visitor.on_attribute("attrs", m_attrs);
+    return true;
+}
+
+constexpr DiscreteTypeInfo AttributeAdapter<op::PriorBoxClusteredAttrs>::type_info;
+
+AttributeAdapter<op::PriorBoxClusteredAttrs>::AttributeAdapter(op::PriorBoxClusteredAttrs& ref)
+    : m_ref(ref)
+{
+}
+
+bool AttributeAdapter<op::PriorBoxClusteredAttrs>::visit_attributes(AttributeVisitor& visitor)
+{
+    visitor.on_attribute("widths", m_ref.widths);
+    visitor.on_attribute("heights", m_ref.heights);
+    visitor.on_attribute("clip", m_ref.clip);
+    visitor.on_attribute("step_widths", m_ref.step_widths);
+    visitor.on_attribute("step_heights", m_ref.step_heights);
+    visitor.on_attribute("offset", m_ref.offset);
+    visitor.on_attribute("variances", m_ref.variances);
     return true;
 }
 
index e9e9005..3b91a1e 100644 (file)
@@ -73,4 +73,18 @@ namespace ngraph
         }
         using v0::PriorBoxClustered;
     }
+
+    template <>
+    class NGRAPH_API AttributeAdapter<op::PriorBoxClusteredAttrs> : public VisitorAdapter
+    {
+    public:
+        AttributeAdapter(op::PriorBoxClusteredAttrs& ref);
+
+        virtual bool visit_attributes(AttributeVisitor& visitor) override;
+        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::PriorBoxClusteredAttrs>",
+                                                    0};
+        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
+    protected:
+        op::PriorBoxClusteredAttrs& m_ref;
+    };
 }
index a8d3085..ee45066 100644 (file)
@@ -92,19 +92,32 @@ shared_ptr<Node> op::Proposal::clone_with_new_inputs(const OutputVector& new_arg
 
 bool op::Proposal::visit_attributes(AttributeVisitor& visitor)
 {
-    visitor.on_attribute("attrs.base_size", m_attrs.base_size);
-    visitor.on_attribute("attrs.pre_nms_topn", m_attrs.pre_nms_topn);
-    visitor.on_attribute("attrs.post_nms_topn", m_attrs.post_nms_topn);
-    visitor.on_attribute("attrs.nms_thresh", m_attrs.nms_thresh);
-    visitor.on_attribute("attrs.feat_stride", m_attrs.feat_stride);
-    visitor.on_attribute("attrs.min_size", m_attrs.min_size);
-    visitor.on_attribute("attrs.ratio", m_attrs.ratio);
-    visitor.on_attribute("attrs.scale", m_attrs.scale);
-    visitor.on_attribute("attrs.clip_before_nms", m_attrs.clip_before_nms);
-    visitor.on_attribute("attrs.clip_after_nms", m_attrs.clip_after_nms);
-    visitor.on_attribute("attrs.normalize", m_attrs.normalize);
-    visitor.on_attribute("attrs.box_size_scale", m_attrs.box_size_scale);
-    visitor.on_attribute("attrs.box_coordinate_scale", m_attrs.box_coordinate_scale);
-    visitor.on_attribute("attrs.framework", m_attrs.framework);
+    visitor.on_attribute("attrs", m_attrs);
+    return true;
+}
+
+constexpr DiscreteTypeInfo AttributeAdapter<op::ProposalAttrs>::type_info;
+
+AttributeAdapter<op::ProposalAttrs>::AttributeAdapter(op::ProposalAttrs& ref)
+    : m_ref(ref)
+{
+}
+
+bool AttributeAdapter<op::ProposalAttrs>::visit_attributes(AttributeVisitor& visitor)
+{
+    visitor.on_attribute("base_size", m_ref.base_size);
+    visitor.on_attribute("pre_nms_topn", m_ref.pre_nms_topn);
+    visitor.on_attribute("post_nms_topn", m_ref.post_nms_topn);
+    visitor.on_attribute("nms_thresh", m_ref.nms_thresh);
+    visitor.on_attribute("feat_stride", m_ref.feat_stride);
+    visitor.on_attribute("min_size", m_ref.min_size);
+    visitor.on_attribute("ratio", m_ref.ratio);
+    visitor.on_attribute("scale", m_ref.scale);
+    visitor.on_attribute("clip_before_nms", m_ref.clip_before_nms);
+    visitor.on_attribute("clip_after_nms", m_ref.clip_after_nms);
+    visitor.on_attribute("normalize", m_ref.normalize);
+    visitor.on_attribute("box_size_scale", m_ref.box_size_scale);
+    visitor.on_attribute("box_coordinate_scale", m_ref.box_coordinate_scale);
+    visitor.on_attribute("framework", m_ref.framework);
     return true;
 }
index 32bca0b..51448ce 100644 (file)
@@ -85,4 +85,17 @@ namespace ngraph
         }
         using v0::Proposal;
     }
+
+    template <>
+    class NGRAPH_API AttributeAdapter<op::ProposalAttrs> : public VisitorAdapter
+    {
+    public:
+        AttributeAdapter(op::ProposalAttrs& ref);
+
+        virtual bool visit_attributes(AttributeVisitor& visitor) override;
+        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::ProposalAttrs>", 0};
+        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
+    protected:
+        op::ProposalAttrs& m_ref;
+    };
 }