Dynamic attribute getters and setters. (#964)
authorAdam Osewski <adam.osewski@intel.com>
Fri, 26 Jun 2020 14:35:00 +0000 (16:35 +0200)
committerGitHub <noreply@github.com>
Fri, 26 Jun 2020 14:35:00 +0000 (16:35 +0200)
ngraph/python/setup.py
ngraph/python/src/ngraph/utils/node_factory.py
ngraph/python/src/pyngraph/dict_attribute_visitor.cpp [new file with mode: 0644]
ngraph/python/src/pyngraph/dict_attribute_visitor.hpp [new file with mode: 0644]
ngraph/python/src/pyngraph/node.cpp
ngraph/python/src/pyngraph/node_factory.cpp
ngraph/python/test/ngraph/test_create_op.py
ngraph/python/test/ngraph/test_dyn_attributes.py [new file with mode: 0644]

index 288fe4d..8c0d589 100644 (file)
@@ -182,6 +182,7 @@ sources = [
     "pyngraph/axis_vector.cpp",
     "pyngraph/coordinate.cpp",
     "pyngraph/coordinate_diff.cpp",
+    "pyngraph/dict_attribute_visitor.cpp",
     "pyngraph/dimension.cpp",
     "pyngraph/function.cpp",
     "pyngraph/node.cpp",
index 70750ce..d07ac3d 100644 (file)
@@ -1,3 +1,4 @@
+from functools import partial
 from typing import Any, Dict, List, Optional
 
 from _pyngraph import NodeFactory as _NodeFactory
@@ -21,6 +22,8 @@ class NodeFactory(object):
     ) -> Node:
         """Create node object from provided description.
 
+        The user does not have to provide all node's attributes, but only required ones.
+
         :param      op_type_name:  The operator type name.
         :param      arguments:     The operator arguments.
         :param      attributes:    The operator attributes.
@@ -30,4 +33,90 @@ class NodeFactory(object):
         if attributes is None:
             attributes = {}
         node = self.factory.create(op_type_name, arguments, attributes)
+
+        # Currently we don't support any attribute getters & setters for TensorIterator node.
+        if node.get_type_name() == "TensorIterator":
+            return node
+
+        # Set getters and setters for each node's attribute.
+        #   node.get_attribute_name()
+        #   node.set_attribute_name()
+        # For compound (with more than one level of nesting) attributes of form ie.:
+        # node.class_member_name.some_metric.attr_name:
+        #   node.get_some_metric_attr_name()
+        #   node.set_some_metric_attr_name()
+        # Please see test_dyn_attributes.py for more usage examples.
+        all_attributes = node._get_attributes()
+        for attr_name in all_attributes.keys():
+            setattr(node,
+                    self._normalize_attr_name_getter(attr_name),
+                    partial(NodeFactory._get_node_attr_value, node, attr_name))
+            setattr(node,
+                    self._normalize_attr_name_setter(attr_name),
+                    partial(NodeFactory._set_node_attr_value, node, attr_name))
+
+        # Setup helper members for caching attribute values.
+        # The cache would be lazily populated at first access attempt.
+        setattr(node, "_attr_cache", dict())
+        setattr(node, "_attr_cache_valid", bool(False))
+
         return node
+
+    @staticmethod
+    def _normalize_attr_name(attr_name: str, prefix: str) -> str:
+        """Normalizes attribute name.
+
+        :param      attr_name:  The attribute name.
+        :param      prefix:     The prefix to attach to attribute name.
+
+        :returns:   The modified attribute name.
+        """
+        # Trim first part of the name if there is only one level of attribute hierarchy.
+        if attr_name.count(".") == 1:
+            attr_name = attr_name[attr_name.find(".") + 1:]
+        return prefix + attr_name.replace(".", "_")
+
+    @classmethod
+    def _normalize_attr_name_getter(cls, attr_name: str) -> str:
+        """Normalizes atr name to be suitable for getter function name.
+
+        :param      attr_name:  The attribute name to normalize
+
+        :returns:   The appropriate getter function name.
+        """
+        return cls._normalize_attr_name(attr_name, "get_")
+
+    @classmethod
+    def _normalize_attr_name_setter(cls, attr_name: str) -> str:
+        """Normalizes atr name to be suitable for setter function name.
+
+        :param      attr_name:  The attribute name to normalize
+
+        :returns:   The appropriate setter function name.
+        """
+        return cls._normalize_attr_name(attr_name, "set_")
+
+    @staticmethod
+    def _get_node_attr_value(node: Node, attr_name: str) -> Any:
+        """Gets provided node attribute value.
+
+        :param      node:       The node we retrieve attribute value from.
+        :param      attr_name:  The attribute name.
+
+        :returns:   The node attribute value.
+        """
+        if not node._attr_cache_valid:
+            node._attr_cache = node._get_attributes()
+            node._attr_cache_valid = True
+        return node._attr_cache[attr_name]
+
+    @staticmethod
+    def _set_node_attr_value(node: Node, attr_name: str, value: Any) -> None:
+        """Sets the node attribute value.
+
+        :param      node:       The node we change attribute value for.
+        :param      attr_name:  The attribute name.
+        :param      value:      The new attribute value.
+        """
+        node._set_attribute(attr_name, value)
+        node._attr_cache[attr_name] = value
diff --git a/ngraph/python/src/pyngraph/dict_attribute_visitor.cpp b/ngraph/python/src/pyngraph/dict_attribute_visitor.cpp
new file mode 100644 (file)
index 0000000..246ca80
--- /dev/null
@@ -0,0 +1,351 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+// These are not used here, but needed in order to not violate ODR, since
+// these are included in other translation units, and specialize some types.
+// Related: https://github.com/pybind/pybind11/issues/1055
+#include <pybind11/numpy.h>
+#include <pybind11/stl.h>
+
+#include "dict_attribute_visitor.hpp"
+
+namespace py = pybind11;
+
+util::DictAttributeDeserializer::DictAttributeDeserializer(const py::dict& attributes)
+    : m_attributes(attributes)
+{
+}
+
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<void>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        NGRAPH_CHECK(false, "No AttributeVisitor support for accessing attribute named: ", name);
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<bool>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<bool>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<std::string>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<std::string>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<int8_t>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<int8_t>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<int16_t>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<int16_t>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<int32_t>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<int32_t>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<int64_t>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<int64_t>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<uint8_t>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<uint8_t>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<uint16_t>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<uint16_t>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<uint32_t>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<uint32_t>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<uint64_t>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<uint64_t>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<float>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<float>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<double>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<double>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<std::vector<std::string>>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<int8_t>>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<std::vector<int8_t>>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<int16_t>>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<std::vector<int16_t>>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<int32_t>>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<std::vector<int32_t>>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<int64_t>>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<std::vector<int64_t>>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<uint8_t>>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<std::vector<uint8_t>>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<uint16_t>>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<std::vector<uint16_t>>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<uint32_t>>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<std::vector<uint32_t>>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<uint64_t>>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<std::vector<uint64_t>>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+                                                 ngraph::ValueAccessor<std::vector<float>>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<std::vector<float>>());
+    }
+}
+void util::DictAttributeDeserializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<double>>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        adapter.set(m_attributes[name.c_str()].cast<std::vector<double>>());
+    }
+}
+
+util::DictAttributeSerializer::DictAttributeSerializer(const std::shared_ptr<ngraph::Node>& node)
+{
+    node->visit_attributes(*this);
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<void>& adapter)
+{
+    if (m_attributes.contains(name))
+    {
+        NGRAPH_CHECK(false, "No AttributeVisitor support for accessing attribute named: ", name);
+    }
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<bool>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<std::string>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<int8_t>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<int16_t>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<int32_t>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<int64_t>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<uint8_t>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<uint16_t>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<uint32_t>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<uint64_t>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<float>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<double>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<std::vector<int8_t>>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<std::vector<int16_t>>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<std::vector<int32_t>>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<std::vector<int64_t>>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<std::vector<uint8_t>>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<uint16_t>>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<uint32_t>>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(
+    const std::string& name, ngraph::ValueAccessor<std::vector<uint64_t>>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<std::vector<float>>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+                                               ngraph::ValueAccessor<std::vector<double>>& adapter)
+{
+    m_attributes[name.c_str()] = adapter.get();
+}
diff --git a/ngraph/python/src/pyngraph/dict_attribute_visitor.hpp b/ngraph/python/src/pyngraph/dict_attribute_visitor.hpp
new file mode 100644 (file)
index 0000000..21978cc
--- /dev/null
@@ -0,0 +1,158 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "ngraph/attribute_visitor.hpp"
+#include "ngraph/node.hpp"
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace util
+{
+    class DictAttributeDeserializer : public ngraph::AttributeVisitor
+    {
+    public:
+        DictAttributeDeserializer(const py::dict& attributes);
+
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<void>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<bool>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::string>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<int8_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<int16_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<int32_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<int64_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<uint8_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<uint16_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<uint32_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<uint64_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<float>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<double>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<std::string>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<float>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<double>>& adapter) override;
+
+    protected:
+        const py::dict& m_attributes;
+    };
+
+    class DictAttributeSerializer : public ngraph::AttributeVisitor
+    {
+    public:
+        DictAttributeSerializer(const std::shared_ptr<ngraph::Node>& node);
+
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<void>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<bool>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::string>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<int8_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<int16_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<int32_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<int64_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<uint8_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<uint16_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<uint32_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<uint64_t>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<float>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<double>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<std::string>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<float>>& adapter) override;
+        virtual void on_adapter(const std::string& name,
+                                ngraph::ValueAccessor<std::vector<double>>& adapter) override;
+
+        template <typename T>
+        T get_attribute(const std::string& name)
+        {
+            NGRAPH_CHECK(m_attributes.contains(name),
+                         "Couldn't find attribute \"",
+                         name,
+                         "\" in serialized node attribute dictionary.");
+            return m_attributes[name.c_str()].cast<T>();
+        }
+
+        py::dict get_attributes() const { return m_attributes; }
+    protected:
+        py::dict m_attributes;
+    };
+}
index f00205f..9db7b4d 100644 (file)
 // limitations under the License.
 //*****************************************************************************
 
-#include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 
-#include "ngraph/node.hpp"        // ngraph::Node
-#include "ngraph/op/add.hpp"      // ngraph::op::Add
-#include "ngraph/op/divide.hpp"   // ngraph::op::Divide
-#include "ngraph/op/multiply.hpp" // ngraph::op::Multiply
-#include "ngraph/op/subtract.hpp" // ngraph::op::Subtract
+#include "dict_attribute_visitor.hpp"
+#include "ngraph/node.hpp"
+#include "ngraph/op/add.hpp"
+#include "ngraph/op/divide.hpp"
+#include "ngraph/op/multiply.hpp"
+#include "ngraph/op/subtract.hpp"
 #include "pyngraph/node.hpp"
 
 namespace py = pybind11;
 
 void regclass_pyngraph_Node(py::module m)
 {
-    py::class_<ngraph::Node, std::shared_ptr<ngraph::Node>> node(m, "Node");
+    py::class_<ngraph::Node, std::shared_ptr<ngraph::Node>> node(m, "Node", py::dynamic_attr());
     node.doc() = "ngraph.impl.Node wraps ngraph::Node";
     node.def("__add__",
              [](const std::shared_ptr<ngraph::Node>& a, const std::shared_ptr<ngraph::Node> b) {
@@ -79,4 +79,18 @@ void regclass_pyngraph_Node(py::module m)
     node.def("get_unique_name", &ngraph::Node::get_name);
 
     node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
+    node.def_property_readonly("shape", &ngraph::Node::get_shape);
+
+    node.def("_get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
+        util::DictAttributeSerializer dict_serializer(self);
+        return dict_serializer.get_attributes();
+    });
+    node.def(
+        "_set_attribute",
+        [](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
+            py::dict attr_dict;
+            attr_dict[atr_name.c_str()] = value;
+            util::DictAttributeDeserializer dict_deserializer(attr_dict);
+            self->visit_attributes(dict_deserializer);
+        });
 }
index ea54e00..d7cfca7 100644 (file)
 #include <pybind11/numpy.h>
 #include <pybind11/stl.h>
 
-#include "ngraph/attribute_visitor.hpp"
+#include "dict_attribute_visitor.hpp"
 #include "ngraph/check.hpp"
-#include "ngraph/enum_names.hpp"
 #include "ngraph/except.hpp"
 #include "ngraph/node.hpp"
-#include "ngraph/op/constant.hpp"
 #include "ngraph/opsets/opset.hpp"
-#include "ngraph/util.hpp"
 #include "node_factory.hpp"
 #include "tensor_iterator_builder.hpp"
 
@@ -41,212 +38,6 @@ namespace py = pybind11;
 
 namespace
 {
-    class DictAttributeDeserializer : public ngraph::AttributeVisitor
-    {
-    public:
-        DictAttributeDeserializer(const py::dict& attributes)
-            : m_attributes(attributes)
-        {
-        }
-
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<void>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                NGRAPH_CHECK(
-                    false, "No AttributeVisitor support for accessing attribute named: ", name);
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<bool>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<bool>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<std::string>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<std::string>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<int8_t>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<int8_t>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<int16_t>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<int16_t>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<int32_t>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<int32_t>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<int64_t>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<int64_t>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<uint8_t>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<uint8_t>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<uint16_t>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<uint16_t>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<uint32_t>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<uint32_t>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<uint64_t>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<uint64_t>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<float>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<float>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<double>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<double>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<std::vector<std::string>>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<std::vector<std::string>>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<std::vector<int8_t>>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<std::vector<int16_t>>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<std::vector<int32_t>>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<std::vector<int64_t>>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<std::vector<uint8_t>>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<std::vector<uint16_t>>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<std::vector<uint32_t>>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<std::vector<uint64_t>>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<std::vector<float>>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<std::vector<float>>());
-            }
-        }
-        virtual void on_adapter(const std::string& name,
-                                ngraph::ValueAccessor<std::vector<double>>& adapter) override
-        {
-            if (m_attributes.contains(name))
-            {
-                adapter.set(m_attributes[name.c_str()].cast<std::vector<double>>());
-            }
-        }
-
-    protected:
-        const py::dict& m_attributes;
-    };
-
     class NodeFactory
     {
     public:
@@ -270,12 +61,12 @@ namespace
 
             if (op_type_name == "TensorIterator")
             {
-                // TODO: how to differentiate opsets?
+                // XXX: How to differentiate opsets?
                 return util::TensorIteratorBuilder(arguments, attributes)
                     .configure(std::static_pointer_cast<ngraph::op::TensorIterator>(op_node));
             }
 
-            DictAttributeDeserializer visitor(attributes);
+            util::DictAttributeDeserializer visitor(attributes);
 
             op_node->set_arguments(arguments);
             op_node->visit_attributes(visitor);
index c815eaf..24b7cca 100644 (file)
@@ -508,7 +508,7 @@ def test_roi_pooling():
     node = ng.roi_pooling(inputs, coords, [6, 6], 0.0625, "Max")
 
     assert node.get_type_name() == "ROIPooling"
-    assert node.get_output_size() == 1
+    assert node.get_output_size() == [6, 6]
     assert list(node.get_output_shape(0)) == [150, 3, 6, 6]
     assert node.get_output_element_type(0) == Type.f32
 
diff --git a/ngraph/python/test/ngraph/test_dyn_attributes.py b/ngraph/python/test/ngraph/test_dyn_attributes.py
new file mode 100644 (file)
index 0000000..8b6fb8a
--- /dev/null
@@ -0,0 +1,241 @@
+# ******************************************************************************
+# Copyright 2017-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import pytest
+
+import ngraph as ng
+
+
+@pytest.fixture()
+def _proposal_node():
+    attributes = {
+        "attrs.base_size": np.uint16(1),
+        "attrs.pre_nms_topn": np.uint16(20),
+        "attrs.post_nms_topn": np.uint16(64),
+        "attrs.nms_thresh": np.float64(0.34),
+        "attrs.feat_stride": np.uint16(16),
+        "attrs.min_size": np.uint16(32),
+        "attrs.ratio": np.array([0.1, 1.5, 2.0, 2.5], dtype=np.float64),
+        "attrs.scale": np.array([2, 3, 3, 4], dtype=np.float64),
+    }
+    batch_size = 7
+
+    class_probs = ng.parameter([batch_size, 12, 34, 62], np.float64, "class_probs")
+    class_logits = ng.parameter([batch_size, 24, 34, 62], np.float64, "class_logits")
+    image_shape = ng.parameter([3], np.float64, "image_shape")
+    return ng.proposal(class_probs, class_logits, image_shape, attributes)
+
+
+def test_dynamic_attributes_softmax():
+    axis = 2
+    data = ng.parameter([1, 2, 3, 4], np.float32, "data_in")
+    node = ng.softmax(data, axis)
+
+    assert node.get_axis() == axis
+    node.set_axis(3)
+    assert node.get_axis() == 3
+
+
+@pytest.mark.parametrize(
+    "int_dtype, fp_dtype",
+    [
+        (np.int8, np.float32),
+        (np.int16, np.float32),
+        (np.int32, np.float32),
+        (np.int64, np.float32),
+        (np.uint8, np.float32),
+        (np.uint16, np.float32),
+        (np.uint32, np.float32),
+        (np.uint64, np.float32),
+        (np.int32, np.float16),
+        (np.int32, np.float64),
+    ],
+)
+def test_dynamic_get_attribute_value(int_dtype, fp_dtype):
+    attributes = {
+        "attrs.num_classes": int_dtype(85),
+        "attrs.background_label_id": int_dtype(13),
+        "attrs.top_k": int_dtype(16),
+        "attrs.variance_encoded_in_target": True,
+        "attrs.keep_top_k": np.array([64, 32, 16, 8], dtype=int_dtype),
+        "attrs.code_type": "pytorch.some_parameter_name",
+        "attrs.share_location": False,
+        "attrs.nms_threshold": fp_dtype(0.645),
+        "attrs.confidence_threshold": fp_dtype(0.111),
+        "attrs.clip_after_nms": True,
+        "attrs.clip_before_nms": False,
+        "attrs.decrease_label_id": True,
+        "attrs.normalized": True,
+        "attrs.input_height": int_dtype(86),
+        "attrs.input_width": int_dtype(79),
+        "attrs.objectness_score": fp_dtype(0.77),
+    }
+
+    box_logits = ng.parameter([4, 1, 5, 5], fp_dtype, "box_logits")
+    class_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "class_preds")
+    proposals = ng.parameter([2, 1, 4, 5], fp_dtype, "proposals")
+    aux_class_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "aux_class_preds")
+    aux_box_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "aux_box_preds")
+
+    node = ng.detection_output(
+        box_logits, class_preds, proposals, attributes, aux_class_preds, aux_box_preds
+    )
+
+    assert node.get_num_classes() == int_dtype(85)
+    assert node.get_background_label_id() == int_dtype(13)
+    assert node.get_top_k() == int_dtype(16)
+    assert node.get_variance_encoded_in_target() == True
+    assert np.all(np.equal(node.get_keep_top_k(), np.array([64, 32, 16, 8], dtype=int_dtype)))
+    assert node.get_code_type() == "pytorch.some_parameter_name"
+    assert node.get_share_location() == False
+    assert np.isclose(node.get_nms_threshold(), fp_dtype(0.645))
+    assert np.isclose(node.get_confidence_threshold(), fp_dtype(0.111))
+    assert node.get_clip_after_nms() == True
+    assert node.get_clip_before_nms() == False
+    assert node.get_decrease_label_id() == True
+    assert node.get_normalized() == True
+    assert node.get_input_height() == int_dtype(86)
+    assert node.get_input_width() == int_dtype(79)
+    assert np.isclose(node.get_objectness_score(), fp_dtype(0.77))
+    assert node.get_num_classes() == int_dtype(85)
+
+
+@pytest.mark.parametrize(
+    "int_dtype, fp_dtype",
+    [
+        (np.uint8, np.float32),
+        (np.uint16, np.float32),
+        (np.uint32, np.float32),
+        (np.uint64, np.float32),
+        (np.uint32, np.float16),
+        (np.uint32, np.float64),
+    ],
+)
+def test_dynamic_set_attribute_value(int_dtype, fp_dtype):
+    attributes = {
+        "attrs.base_size": int_dtype(1),
+        "attrs.pre_nms_topn": int_dtype(20),
+        "attrs.post_nms_topn": int_dtype(64),
+        "attrs.nms_thresh": fp_dtype(0.34),
+        "attrs.feat_stride": int_dtype(16),
+        "attrs.min_size": int_dtype(32),
+        "attrs.ratio": np.array([0.1, 1.5, 2.0, 2.5], dtype=fp_dtype),
+        "attrs.scale": np.array([2, 3, 3, 4], dtype=fp_dtype),
+    }
+    batch_size = 7
+
+    class_probs = ng.parameter([batch_size, 12, 34, 62], fp_dtype, "class_probs")
+    class_logits = ng.parameter([batch_size, 24, 34, 62], fp_dtype, "class_logits")
+    image_shape = ng.parameter([3], fp_dtype, "image_shape")
+    node = ng.proposal(class_probs, class_logits, image_shape, attributes)
+
+    node.set_base_size(int_dtype(15))
+    node.set_pre_nms_topn(int_dtype(7))
+    node.set_post_nms_topn(int_dtype(33))
+    node.set_nms_thresh(fp_dtype(1.55))
+    node.set_feat_stride(int_dtype(8))
+    node.set_min_size(int_dtype(123))
+    node.set_ratio(np.array([1.1, 2.5, 3.0, 4.5], dtype=fp_dtype))
+    node.set_scale(np.array([2.1, 3.2, 3.3, 4.4], dtype=fp_dtype))
+    node.set_clip_before_nms(True)
+    node.set_clip_after_nms(True)
+    node.set_normalize(True)
+    node.set_box_size_scale(fp_dtype(1.34))
+    node.set_box_coordinate_scale(fp_dtype(0.88))
+    node.set_framework("OpenVINO")
+
+    assert node.get_base_size() == int_dtype(15)
+    assert node.get_pre_nms_topn() == int_dtype(7)
+    assert node.get_post_nms_topn() == int_dtype(33)
+    assert np.isclose(node.get_nms_thresh(), fp_dtype(1.55))
+    assert node.get_feat_stride() == int_dtype(8)
+    assert node.get_min_size() == int_dtype(123)
+    assert np.allclose(node.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=fp_dtype))
+    assert np.allclose(node.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=fp_dtype))
+    assert node.get_clip_before_nms() == True
+    assert node.get_clip_after_nms() == True
+    assert node.get_normalize() == True
+    assert np.isclose(node.get_box_size_scale(), fp_dtype(1.34))
+    assert np.isclose(node.get_box_coordinate_scale(), fp_dtype(0.88))
+    assert node.get_framework() == "OpenVINO"
+
+
+def test_dynamic_attr_cache(_proposal_node):
+    node = _proposal_node
+
+    assert not node._attr_cache_valid
+    node.set_nms_thresh(1.3453678102)
+    assert not node._attr_cache_valid
+    assert np.isclose(node.get_nms_thresh(), np.float64(1.3453678102))
+    assert node._attr_cache_valid
+
+
+def test_dynamic_attr_transitivity(_proposal_node):
+    node = _proposal_node
+    node2 = node
+
+    node.set_ratio(np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64))
+    assert np.allclose(node.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64))
+    assert np.allclose(node2.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64))
+
+    node2.set_scale(np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64))
+    assert np.allclose(node2.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64))
+    assert np.allclose(node.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64))
+
+
+def test_dynamic_attributes_simple():
+    batch_size = 1
+    input_size = 16
+    hidden_size = 128
+
+    X_shape = [batch_size, input_size]
+    H_t_shape = [batch_size, hidden_size]
+    W_shape = [3 * hidden_size, input_size]
+    R_shape = [3 * hidden_size, hidden_size]
+    B_shape = [4 * hidden_size]
+
+    parameter_X = ng.parameter(X_shape, name="X", dtype=np.float32)
+    parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=np.float32)
+    parameter_W = ng.parameter(W_shape, name="W", dtype=np.float32)
+    parameter_R = ng.parameter(R_shape, name="R", dtype=np.float32)
+    parameter_B = ng.parameter(B_shape, name="B", dtype=np.float32)
+
+    activations = ["tanh", "relu"]
+    activations_alpha = [1.0, 2.0]
+    activations_beta = [1.0, 2.0]
+    clip = 0.5
+    linear_before_reset = True
+
+    node = ng.gru_cell(
+        parameter_X,
+        parameter_H_t,
+        parameter_W,
+        parameter_R,
+        parameter_B,
+        hidden_size,
+        activations,
+        activations_alpha,
+        activations_beta,
+        clip,
+        linear_before_reset,
+    )
+
+    assert node.get_hidden_size() == hidden_size
+    assert all(map(lambda x, y: x == y, node.get_activations(), activations))
+    assert all(np.equal(node.get_activations_alpha(), activations_alpha))
+    assert all(np.equal(node.get_activations_beta(), activations_beta))
+    assert node.get_linear_before_reset() == linear_before_reset
+    assert np.isclose(node.get_clip(), clip)