From 3d85187dd62b30c64ab4e8ac3a03ecb7d90d43b6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Vladimir=20Plazun/AI=20Tools=20Lab/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 29 May 2018 17:24:26 +0400 Subject: [PATCH] Add IRNode class (#256) Add IRNode class This class used to represent single operation in computation graph along with its connections Signed-off-by: Vladimir Plazun --- .../core/include/nnc/core/IR/model/graph/ir_node.h | 134 +++++++++++++++++++++ .../libs/core/src/core/IR/model/graph/ir_node.cpp | 38 ++++++ .../core/src/core/IR/model/graph/ir_node.test.cpp | 23 ++++ 3 files changed, 195 insertions(+) create mode 100644 contrib/nnc/libs/core/include/nnc/core/IR/model/graph/ir_node.h create mode 100644 contrib/nnc/libs/core/src/core/IR/model/graph/ir_node.cpp create mode 100644 contrib/nnc/libs/core/src/core/IR/model/graph/ir_node.test.cpp diff --git a/contrib/nnc/libs/core/include/nnc/core/IR/model/graph/ir_node.h b/contrib/nnc/libs/core/include/nnc/core/IR/model/graph/ir_node.h new file mode 100644 index 0000000..34822ef --- /dev/null +++ b/contrib/nnc/libs/core/include/nnc/core/IR/model/graph/ir_node.h @@ -0,0 +1,134 @@ +#ifndef _NNC_CORE_IR_MODEL_NODE_H_ +#define _NNC_CORE_IR_MODEL_NODE_H_ + +#include +#include +#include + +#include "nnc/core/IR/model/operations/operation.h" +#include "nnc/core/IR/model/visitor/visitor.h" + +namespace nncc +{ +namespace contrib +{ +namespace core +{ +namespace IR +{ +namespace model +{ + +namespace ADT +{ + +class INode +{ +public: + using Ref = INode *; + + struct IODescriptor + { + ADT::INode* const node; // Data source + const size_t index; // Output id + }; + + virtual const std::vector &getPrevNodes() const = 0; + virtual const std::vector &getNextNodes() const = 0; + + virtual size_t getId() const = 0; + + virtual OpDescription *getOperation() = 0; + + virtual const std::string &getName() const = 0; + virtual void setName(const std::string &name) = 0; + + virtual void accept(Visitor *v) = 0; + + virtual const IODescriptor getOutput(const size_t index) = 0; + virtual void connectInputTo(const int inputIndex, const IODescriptor &descriptor) = 0; + + virtual ~INode() = default; + +protected: + virtual void addNextNode(const INode::Ref) = 0; +}; + +class AbstractNode : public INode +{ +public: + const std::vector &getPrevNodes() const override; + const std::vector &getNextNodes() const override; + void connectInputTo(const int inputIndex, const IODescriptor &descriptor) override; + const IODescriptor getOutput(const size_t index) override; + + protected: + virtual void addNextNode(ADT::INode::Ref const node) override; + +private: + std::vector _inputs; + std::vector _outputs; +}; + +} // namespace ADT + +struct NodeProperties +{ + explicit NodeProperties(std::string name, const size_t id, OpDescription *op = nullptr) + : name(std::move(name)), id(id), op(op) + { + } + + std::string name; + OpDescription *op; + const size_t id; + + NodeProperties(NodeProperties &&nodeProps) noexcept : op(nodeProps.op), id(nodeProps.id), name(std::move(nodeProps.name)) + { + nodeProps.op = nullptr; + } +}; + +template +class Node : public ADT::AbstractNode +{ +public: + OpType *getOperation() override { return static_cast(_props.op); } + + template + static Node *createNode(const std::string &nodeName, size_t id, Args &&... args) + { + auto node = + new Node(NodeProperties(nodeName, id, new OpType(std::forward(args)...))); + return node; + }; + + size_t getId() const override { return _props.id; }; + + const std::string &getName() const override { return _props.name; }; + + void setName(const std::string &name) override { _props.name = name; } + + void accept(Visitor *v) override + { + //TODO: enable this when at least one Visitor operation declared + // v->visit(this, *static_cast(_props.op)); + } + + ~Node() override { + delete _props.op; + } + +private: + explicit Node(NodeProperties &&properties) : _props(std::move(properties)) {}; + + NodeProperties _props; +}; + +} // namespace model +} // namespace IR +} // namespace core +} // namespace contrib +} // namespace nncc + +#endif //_NNC_CORE_IR_MODEL_NODE_H_ diff --git a/contrib/nnc/libs/core/src/core/IR/model/graph/ir_node.cpp b/contrib/nnc/libs/core/src/core/IR/model/graph/ir_node.cpp new file mode 100644 index 0000000..f2dbd42 --- /dev/null +++ b/contrib/nnc/libs/core/src/core/IR/model/graph/ir_node.cpp @@ -0,0 +1,38 @@ +#include "nnc/core/IR/model/graph/ir_node.h" + +namespace nncc +{ +namespace contrib +{ +namespace core +{ +namespace IR +{ +namespace model +{ + +const std::vector &ADT::AbstractNode::getNextNodes() const { return _outputs; } + +const std::vector &ADT::AbstractNode::getPrevNodes() const +{ + return _inputs; +} + +void ADT::AbstractNode::connectInputTo(const int inputIndex, const IODescriptor &descriptor) +{ + dynamic_cast(descriptor.node)->addNextNode(this); + _inputs.push_back(descriptor); +} + +void ADT::AbstractNode::addNextNode(ADT::INode::Ref const node) { _outputs.emplace_back(node); } + +const ADT::INode::IODescriptor ADT::AbstractNode::getOutput(size_t index) +{ + return IODescriptor{.node = this, .index = index}; +} + +} // namespace model +} // namespace IR +} // namespace core +} // namespace contrib +} // namespace nncc diff --git a/contrib/nnc/libs/core/src/core/IR/model/graph/ir_node.test.cpp b/contrib/nnc/libs/core/src/core/IR/model/graph/ir_node.test.cpp new file mode 100644 index 0000000..cd0f857 --- /dev/null +++ b/contrib/nnc/libs/core/src/core/IR/model/graph/ir_node.test.cpp @@ -0,0 +1,23 @@ +#include "nnc/core/IR/model/operations/operation.h" +#include "nnc/core/IR/model/graph/ir_node.h" +#include "nncc/core/ADT/feature/Shape.h" + + +#include + + +class DummyOperation : public nncc::contrib::core::IR::model::OpDescription { + public: + DummyOperation() : OpDescription(1, 1) {} +}; + +TEST(IRNode, ConnectionTest) { + using namespace nncc::contrib::core::IR::model; + + auto node1 = Node::createNode("node1", 0); + auto node2 = Node::createNode("node2", 1); + + node2->connectInputTo(0, node1->getOutput(0)); + + ASSERT_EQ(node1->getId(), node2->getPrevNodes()[0].node->getId()); +} -- 2.7.4