"pyngraph/axis_vector.cpp",
"pyngraph/coordinate.cpp",
"pyngraph/coordinate_diff.cpp",
+ "pyngraph/dict_attribute_visitor.cpp",
"pyngraph/dimension.cpp",
"pyngraph/function.cpp",
"pyngraph/node.cpp",
+from functools import partial
from typing import Any, Dict, List, Optional
from _pyngraph import NodeFactory as _NodeFactory
) -> Node:
"""Create node object from provided description.
+ The user does not have to provide all node's attributes, but only required ones.
+
:param op_type_name: The operator type name.
:param arguments: The operator arguments.
:param attributes: The operator attributes.
if attributes is None:
attributes = {}
node = self.factory.create(op_type_name, arguments, attributes)
+
+ # Currently we don't support any attribute getters & setters for TensorIterator node.
+ if node.get_type_name() == "TensorIterator":
+ return node
+
+ # Set getters and setters for each node's attribute.
+ # node.get_attribute_name()
+ # node.set_attribute_name()
+ # For compound (with more than one level of nesting) attributes of form ie.:
+ # node.class_member_name.some_metric.attr_name:
+ # node.get_some_metric_attr_name()
+ # node.set_some_metric_attr_name()
+ # Please see test_dyn_attributes.py for more usage examples.
+ all_attributes = node._get_attributes()
+ for attr_name in all_attributes.keys():
+ setattr(node,
+ self._normalize_attr_name_getter(attr_name),
+ partial(NodeFactory._get_node_attr_value, node, attr_name))
+ setattr(node,
+ self._normalize_attr_name_setter(attr_name),
+ partial(NodeFactory._set_node_attr_value, node, attr_name))
+
+ # Setup helper members for caching attribute values.
+ # The cache would be lazily populated at first access attempt.
+ setattr(node, "_attr_cache", dict())
+ setattr(node, "_attr_cache_valid", bool(False))
+
return node
+
+ @staticmethod
+ def _normalize_attr_name(attr_name: str, prefix: str) -> str:
+ """Normalizes attribute name.
+
+ :param attr_name: The attribute name.
+ :param prefix: The prefix to attach to attribute name.
+
+ :returns: The modified attribute name.
+ """
+ # Trim first part of the name if there is only one level of attribute hierarchy.
+ if attr_name.count(".") == 1:
+ attr_name = attr_name[attr_name.find(".") + 1:]
+ return prefix + attr_name.replace(".", "_")
+
+ @classmethod
+ def _normalize_attr_name_getter(cls, attr_name: str) -> str:
+ """Normalizes atr name to be suitable for getter function name.
+
+ :param attr_name: The attribute name to normalize
+
+ :returns: The appropriate getter function name.
+ """
+ return cls._normalize_attr_name(attr_name, "get_")
+
+ @classmethod
+ def _normalize_attr_name_setter(cls, attr_name: str) -> str:
+ """Normalizes atr name to be suitable for setter function name.
+
+ :param attr_name: The attribute name to normalize
+
+ :returns: The appropriate setter function name.
+ """
+ return cls._normalize_attr_name(attr_name, "set_")
+
+ @staticmethod
+ def _get_node_attr_value(node: Node, attr_name: str) -> Any:
+ """Gets provided node attribute value.
+
+ :param node: The node we retrieve attribute value from.
+ :param attr_name: The attribute name.
+
+ :returns: The node attribute value.
+ """
+ if not node._attr_cache_valid:
+ node._attr_cache = node._get_attributes()
+ node._attr_cache_valid = True
+ return node._attr_cache[attr_name]
+
+ @staticmethod
+ def _set_node_attr_value(node: Node, attr_name: str, value: Any) -> None:
+ """Sets the node attribute value.
+
+ :param node: The node we change attribute value for.
+ :param attr_name: The attribute name.
+ :param value: The new attribute value.
+ """
+ node._set_attribute(attr_name, value)
+ node._attr_cache[attr_name] = value
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+// These are not used here, but needed in order to not violate ODR, since
+// these are included in other translation units, and specialize some types.
+// Related: https://github.com/pybind/pybind11/issues/1055
+#include <pybind11/numpy.h>
+#include <pybind11/stl.h>
+
+#include "dict_attribute_visitor.hpp"
+
+namespace py = pybind11;
+
+util::DictAttributeDeserializer::DictAttributeDeserializer(const py::dict& attributes)
+ : m_attributes(attributes)
+{
+}
+
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<void>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ NGRAPH_CHECK(false, "No AttributeVisitor support for accessing attribute named: ", name);
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<bool>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<bool>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::string>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<std::string>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int8_t>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<int8_t>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int16_t>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<int16_t>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int32_t>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<int32_t>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int64_t>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<int64_t>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint8_t>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<uint8_t>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint16_t>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<uint16_t>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint32_t>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<uint32_t>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint64_t>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<uint64_t>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<float>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<float>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<double>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<double>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<std::vector<std::string>>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<int8_t>>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<std::vector<int8_t>>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<int16_t>>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<std::vector<int16_t>>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<int32_t>>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<std::vector<int32_t>>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<int64_t>>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<std::vector<int64_t>>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<uint8_t>>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<std::vector<uint8_t>>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<uint16_t>>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<std::vector<uint16_t>>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<uint32_t>>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<std::vector<uint32_t>>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<uint64_t>>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<std::vector<uint64_t>>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<float>>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<std::vector<float>>());
+ }
+}
+void util::DictAttributeDeserializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<double>>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ adapter.set(m_attributes[name.c_str()].cast<std::vector<double>>());
+ }
+}
+
+util::DictAttributeSerializer::DictAttributeSerializer(const std::shared_ptr<ngraph::Node>& node)
+{
+ node->visit_attributes(*this);
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<void>& adapter)
+{
+ if (m_attributes.contains(name))
+ {
+ NGRAPH_CHECK(false, "No AttributeVisitor support for accessing attribute named: ", name);
+ }
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<bool>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::string>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int8_t>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int16_t>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int32_t>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int64_t>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint8_t>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint16_t>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint32_t>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint64_t>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<float>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<double>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<int8_t>>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<int16_t>>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<int32_t>>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<int64_t>>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<uint8_t>>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<uint16_t>>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<uint32_t>>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(
+ const std::string& name, ngraph::ValueAccessor<std::vector<uint64_t>>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<float>>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
+void util::DictAttributeSerializer::on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<double>>& adapter)
+{
+ m_attributes[name.c_str()] = adapter.get();
+}
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "ngraph/attribute_visitor.hpp"
+#include "ngraph/node.hpp"
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace util
+{
+ class DictAttributeDeserializer : public ngraph::AttributeVisitor
+ {
+ public:
+ DictAttributeDeserializer(const py::dict& attributes);
+
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<void>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<bool>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::string>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int8_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int16_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int32_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int64_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint8_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint16_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint32_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint64_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<float>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<double>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<std::string>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<float>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<double>>& adapter) override;
+
+ protected:
+ const py::dict& m_attributes;
+ };
+
+ class DictAttributeSerializer : public ngraph::AttributeVisitor
+ {
+ public:
+ DictAttributeSerializer(const std::shared_ptr<ngraph::Node>& node);
+
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<void>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<bool>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::string>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int8_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int16_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int32_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<int64_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint8_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint16_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint32_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<uint64_t>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<float>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<double>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<std::string>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<float>>& adapter) override;
+ virtual void on_adapter(const std::string& name,
+ ngraph::ValueAccessor<std::vector<double>>& adapter) override;
+
+ template <typename T>
+ T get_attribute(const std::string& name)
+ {
+ NGRAPH_CHECK(m_attributes.contains(name),
+ "Couldn't find attribute \"",
+ name,
+ "\" in serialized node attribute dictionary.");
+ return m_attributes[name.c_str()].cast<T>();
+ }
+
+ py::dict get_attributes() const { return m_attributes; }
+ protected:
+ py::dict m_attributes;
+ };
+}
// limitations under the License.
//*****************************************************************************
-#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
-#include "ngraph/node.hpp" // ngraph::Node
-#include "ngraph/op/add.hpp" // ngraph::op::Add
-#include "ngraph/op/divide.hpp" // ngraph::op::Divide
-#include "ngraph/op/multiply.hpp" // ngraph::op::Multiply
-#include "ngraph/op/subtract.hpp" // ngraph::op::Subtract
+#include "dict_attribute_visitor.hpp"
+#include "ngraph/node.hpp"
+#include "ngraph/op/add.hpp"
+#include "ngraph/op/divide.hpp"
+#include "ngraph/op/multiply.hpp"
+#include "ngraph/op/subtract.hpp"
#include "pyngraph/node.hpp"
namespace py = pybind11;
void regclass_pyngraph_Node(py::module m)
{
- py::class_<ngraph::Node, std::shared_ptr<ngraph::Node>> node(m, "Node");
+ py::class_<ngraph::Node, std::shared_ptr<ngraph::Node>> node(m, "Node", py::dynamic_attr());
node.doc() = "ngraph.impl.Node wraps ngraph::Node";
node.def("__add__",
[](const std::shared_ptr<ngraph::Node>& a, const std::shared_ptr<ngraph::Node> b) {
node.def("get_unique_name", &ngraph::Node::get_name);
node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
+ node.def_property_readonly("shape", &ngraph::Node::get_shape);
+
+ node.def("_get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
+ util::DictAttributeSerializer dict_serializer(self);
+ return dict_serializer.get_attributes();
+ });
+ node.def(
+ "_set_attribute",
+ [](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
+ py::dict attr_dict;
+ attr_dict[atr_name.c_str()] = value;
+ util::DictAttributeDeserializer dict_deserializer(attr_dict);
+ self->visit_attributes(dict_deserializer);
+ });
}
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
-#include "ngraph/attribute_visitor.hpp"
+#include "dict_attribute_visitor.hpp"
#include "ngraph/check.hpp"
-#include "ngraph/enum_names.hpp"
#include "ngraph/except.hpp"
#include "ngraph/node.hpp"
-#include "ngraph/op/constant.hpp"
#include "ngraph/opsets/opset.hpp"
-#include "ngraph/util.hpp"
#include "node_factory.hpp"
#include "tensor_iterator_builder.hpp"
namespace
{
- class DictAttributeDeserializer : public ngraph::AttributeVisitor
- {
- public:
- DictAttributeDeserializer(const py::dict& attributes)
- : m_attributes(attributes)
- {
- }
-
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<void>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- NGRAPH_CHECK(
- false, "No AttributeVisitor support for accessing attribute named: ", name);
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<bool>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<bool>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<std::string>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<std::string>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<int8_t>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<int8_t>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<int16_t>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<int16_t>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<int32_t>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<int32_t>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<int64_t>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<int64_t>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<uint8_t>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<uint8_t>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<uint16_t>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<uint16_t>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<uint32_t>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<uint32_t>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<uint64_t>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<uint64_t>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<float>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<float>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<double>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<double>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<std::vector<std::string>>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<std::vector<std::string>>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<std::vector<int8_t>>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<std::vector<int16_t>>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<std::vector<int32_t>>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<std::vector<int64_t>>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<std::vector<uint8_t>>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<std::vector<uint16_t>>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<std::vector<uint32_t>>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<std::vector<uint64_t>>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<std::vector<float>>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<std::vector<float>>());
- }
- }
- virtual void on_adapter(const std::string& name,
- ngraph::ValueAccessor<std::vector<double>>& adapter) override
- {
- if (m_attributes.contains(name))
- {
- adapter.set(m_attributes[name.c_str()].cast<std::vector<double>>());
- }
- }
-
- protected:
- const py::dict& m_attributes;
- };
-
class NodeFactory
{
public:
if (op_type_name == "TensorIterator")
{
- // TODO: how to differentiate opsets?
+ // XXX: How to differentiate opsets?
return util::TensorIteratorBuilder(arguments, attributes)
.configure(std::static_pointer_cast<ngraph::op::TensorIterator>(op_node));
}
- DictAttributeDeserializer visitor(attributes);
+ util::DictAttributeDeserializer visitor(attributes);
op_node->set_arguments(arguments);
op_node->visit_attributes(visitor);
node = ng.roi_pooling(inputs, coords, [6, 6], 0.0625, "Max")
assert node.get_type_name() == "ROIPooling"
- assert node.get_output_size() == 1
+ assert node.get_output_size() == [6, 6]
assert list(node.get_output_shape(0)) == [150, 3, 6, 6]
assert node.get_output_element_type(0) == Type.f32
--- /dev/null
+# ******************************************************************************
+# Copyright 2017-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import pytest
+
+import ngraph as ng
+
+
+@pytest.fixture()
+def _proposal_node():
+ attributes = {
+ "attrs.base_size": np.uint16(1),
+ "attrs.pre_nms_topn": np.uint16(20),
+ "attrs.post_nms_topn": np.uint16(64),
+ "attrs.nms_thresh": np.float64(0.34),
+ "attrs.feat_stride": np.uint16(16),
+ "attrs.min_size": np.uint16(32),
+ "attrs.ratio": np.array([0.1, 1.5, 2.0, 2.5], dtype=np.float64),
+ "attrs.scale": np.array([2, 3, 3, 4], dtype=np.float64),
+ }
+ batch_size = 7
+
+ class_probs = ng.parameter([batch_size, 12, 34, 62], np.float64, "class_probs")
+ class_logits = ng.parameter([batch_size, 24, 34, 62], np.float64, "class_logits")
+ image_shape = ng.parameter([3], np.float64, "image_shape")
+ return ng.proposal(class_probs, class_logits, image_shape, attributes)
+
+
+def test_dynamic_attributes_softmax():
+ axis = 2
+ data = ng.parameter([1, 2, 3, 4], np.float32, "data_in")
+ node = ng.softmax(data, axis)
+
+ assert node.get_axis() == axis
+ node.set_axis(3)
+ assert node.get_axis() == 3
+
+
+@pytest.mark.parametrize(
+ "int_dtype, fp_dtype",
+ [
+ (np.int8, np.float32),
+ (np.int16, np.float32),
+ (np.int32, np.float32),
+ (np.int64, np.float32),
+ (np.uint8, np.float32),
+ (np.uint16, np.float32),
+ (np.uint32, np.float32),
+ (np.uint64, np.float32),
+ (np.int32, np.float16),
+ (np.int32, np.float64),
+ ],
+)
+def test_dynamic_get_attribute_value(int_dtype, fp_dtype):
+ attributes = {
+ "attrs.num_classes": int_dtype(85),
+ "attrs.background_label_id": int_dtype(13),
+ "attrs.top_k": int_dtype(16),
+ "attrs.variance_encoded_in_target": True,
+ "attrs.keep_top_k": np.array([64, 32, 16, 8], dtype=int_dtype),
+ "attrs.code_type": "pytorch.some_parameter_name",
+ "attrs.share_location": False,
+ "attrs.nms_threshold": fp_dtype(0.645),
+ "attrs.confidence_threshold": fp_dtype(0.111),
+ "attrs.clip_after_nms": True,
+ "attrs.clip_before_nms": False,
+ "attrs.decrease_label_id": True,
+ "attrs.normalized": True,
+ "attrs.input_height": int_dtype(86),
+ "attrs.input_width": int_dtype(79),
+ "attrs.objectness_score": fp_dtype(0.77),
+ }
+
+ box_logits = ng.parameter([4, 1, 5, 5], fp_dtype, "box_logits")
+ class_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "class_preds")
+ proposals = ng.parameter([2, 1, 4, 5], fp_dtype, "proposals")
+ aux_class_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "aux_class_preds")
+ aux_box_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "aux_box_preds")
+
+ node = ng.detection_output(
+ box_logits, class_preds, proposals, attributes, aux_class_preds, aux_box_preds
+ )
+
+ assert node.get_num_classes() == int_dtype(85)
+ assert node.get_background_label_id() == int_dtype(13)
+ assert node.get_top_k() == int_dtype(16)
+ assert node.get_variance_encoded_in_target() == True
+ assert np.all(np.equal(node.get_keep_top_k(), np.array([64, 32, 16, 8], dtype=int_dtype)))
+ assert node.get_code_type() == "pytorch.some_parameter_name"
+ assert node.get_share_location() == False
+ assert np.isclose(node.get_nms_threshold(), fp_dtype(0.645))
+ assert np.isclose(node.get_confidence_threshold(), fp_dtype(0.111))
+ assert node.get_clip_after_nms() == True
+ assert node.get_clip_before_nms() == False
+ assert node.get_decrease_label_id() == True
+ assert node.get_normalized() == True
+ assert node.get_input_height() == int_dtype(86)
+ assert node.get_input_width() == int_dtype(79)
+ assert np.isclose(node.get_objectness_score(), fp_dtype(0.77))
+ assert node.get_num_classes() == int_dtype(85)
+
+
+@pytest.mark.parametrize(
+ "int_dtype, fp_dtype",
+ [
+ (np.uint8, np.float32),
+ (np.uint16, np.float32),
+ (np.uint32, np.float32),
+ (np.uint64, np.float32),
+ (np.uint32, np.float16),
+ (np.uint32, np.float64),
+ ],
+)
+def test_dynamic_set_attribute_value(int_dtype, fp_dtype):
+ attributes = {
+ "attrs.base_size": int_dtype(1),
+ "attrs.pre_nms_topn": int_dtype(20),
+ "attrs.post_nms_topn": int_dtype(64),
+ "attrs.nms_thresh": fp_dtype(0.34),
+ "attrs.feat_stride": int_dtype(16),
+ "attrs.min_size": int_dtype(32),
+ "attrs.ratio": np.array([0.1, 1.5, 2.0, 2.5], dtype=fp_dtype),
+ "attrs.scale": np.array([2, 3, 3, 4], dtype=fp_dtype),
+ }
+ batch_size = 7
+
+ class_probs = ng.parameter([batch_size, 12, 34, 62], fp_dtype, "class_probs")
+ class_logits = ng.parameter([batch_size, 24, 34, 62], fp_dtype, "class_logits")
+ image_shape = ng.parameter([3], fp_dtype, "image_shape")
+ node = ng.proposal(class_probs, class_logits, image_shape, attributes)
+
+ node.set_base_size(int_dtype(15))
+ node.set_pre_nms_topn(int_dtype(7))
+ node.set_post_nms_topn(int_dtype(33))
+ node.set_nms_thresh(fp_dtype(1.55))
+ node.set_feat_stride(int_dtype(8))
+ node.set_min_size(int_dtype(123))
+ node.set_ratio(np.array([1.1, 2.5, 3.0, 4.5], dtype=fp_dtype))
+ node.set_scale(np.array([2.1, 3.2, 3.3, 4.4], dtype=fp_dtype))
+ node.set_clip_before_nms(True)
+ node.set_clip_after_nms(True)
+ node.set_normalize(True)
+ node.set_box_size_scale(fp_dtype(1.34))
+ node.set_box_coordinate_scale(fp_dtype(0.88))
+ node.set_framework("OpenVINO")
+
+ assert node.get_base_size() == int_dtype(15)
+ assert node.get_pre_nms_topn() == int_dtype(7)
+ assert node.get_post_nms_topn() == int_dtype(33)
+ assert np.isclose(node.get_nms_thresh(), fp_dtype(1.55))
+ assert node.get_feat_stride() == int_dtype(8)
+ assert node.get_min_size() == int_dtype(123)
+ assert np.allclose(node.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=fp_dtype))
+ assert np.allclose(node.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=fp_dtype))
+ assert node.get_clip_before_nms() == True
+ assert node.get_clip_after_nms() == True
+ assert node.get_normalize() == True
+ assert np.isclose(node.get_box_size_scale(), fp_dtype(1.34))
+ assert np.isclose(node.get_box_coordinate_scale(), fp_dtype(0.88))
+ assert node.get_framework() == "OpenVINO"
+
+
+def test_dynamic_attr_cache(_proposal_node):
+ node = _proposal_node
+
+ assert not node._attr_cache_valid
+ node.set_nms_thresh(1.3453678102)
+ assert not node._attr_cache_valid
+ assert np.isclose(node.get_nms_thresh(), np.float64(1.3453678102))
+ assert node._attr_cache_valid
+
+
+def test_dynamic_attr_transitivity(_proposal_node):
+ node = _proposal_node
+ node2 = node
+
+ node.set_ratio(np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64))
+ assert np.allclose(node.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64))
+ assert np.allclose(node2.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64))
+
+ node2.set_scale(np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64))
+ assert np.allclose(node2.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64))
+ assert np.allclose(node.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64))
+
+
+def test_dynamic_attributes_simple():
+ batch_size = 1
+ input_size = 16
+ hidden_size = 128
+
+ X_shape = [batch_size, input_size]
+ H_t_shape = [batch_size, hidden_size]
+ W_shape = [3 * hidden_size, input_size]
+ R_shape = [3 * hidden_size, hidden_size]
+ B_shape = [4 * hidden_size]
+
+ parameter_X = ng.parameter(X_shape, name="X", dtype=np.float32)
+ parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=np.float32)
+ parameter_W = ng.parameter(W_shape, name="W", dtype=np.float32)
+ parameter_R = ng.parameter(R_shape, name="R", dtype=np.float32)
+ parameter_B = ng.parameter(B_shape, name="B", dtype=np.float32)
+
+ activations = ["tanh", "relu"]
+ activations_alpha = [1.0, 2.0]
+ activations_beta = [1.0, 2.0]
+ clip = 0.5
+ linear_before_reset = True
+
+ node = ng.gru_cell(
+ parameter_X,
+ parameter_H_t,
+ parameter_W,
+ parameter_R,
+ parameter_B,
+ hidden_size,
+ activations,
+ activations_alpha,
+ activations_beta,
+ clip,
+ linear_before_reset,
+ )
+
+ assert node.get_hidden_size() == hidden_size
+ assert all(map(lambda x, y: x == y, node.get_activations(), activations))
+ assert all(np.equal(node.get_activations_alpha(), activations_alpha))
+ assert all(np.equal(node.get_activations_beta(), activations_beta))
+ assert node.get_linear_before_reset() == linear_before_reset
+ assert np.isclose(node.get_clip(), clip)