Provides replaceNode which is used to replace a node in graph preserving all edges
Provides replaceNodeWithInput and replaceInputNodes methods, used to replace any node with input node
Provides replaceOutputNodes, used to replace graph output nodes with new list
Adds unit tests for all implementations
Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
#include <deque>
#include <set>
+#include <algorithm>
#include "core/modelIR/graph.h"
namespace nnc {
namespace mir {
+/**
+ * @brief replace all usages of node `node` with node `with`
+ * (i.e. all references in previous/next nodes )
+ * @param inode a node to replace
+ * @param with a node to use as a replacement
+ */
+static void replaceUsages(const INode* node, INode* with) {
+ auto with_node = dynamic_cast<AbstractNode*>(with);
+ assert(with_node);
+
+ //For each output replace prev references to `node` by `with`
+ for (auto out : node->getNextNodes()) {
+ auto anode = dynamic_cast<AbstractNode*>(out);
+ assert(anode && "Unexpected node type");
+
+ for (auto& prev : anode->getMutablePrevNodes()) {
+ if (prev.node == node)
+ prev.node = with;
+ }
+ }
+
+ with_node->getMutableNextNodes() = node->getNextNodes();
+
+ //For each input replace next references to `node` by `with`
+ for (auto& in : node->getPrevNodes()) {
+ auto anode = dynamic_cast<AbstractNode*>(in.node);
+ assert(anode && "Unexpected node type");
+
+ for (auto& next : anode->getMutableNextNodes()) {
+ if (next == node)
+ next = with;
+ }
+ }
+
+ with_node->getMutablePrevNodes() = node->getPrevNodes();
+}
+
INode::Ref Graph::getInput(const std::string& name) {
auto it = _inputs.find(name);
if (it == _inputs.end())
return res;
}
+void Graph::replaceNode(const INode* node, INode* with) {
+ auto in = _inputs.find(node->getName());
+ if (in != _inputs.end()) {
+ (*in).second = with;
+ }
+
+ auto out_it = _outputs.find(node->getName());
+ if (out_it != _outputs.end()) {
+ (*out_it).second = with;
+ }
+
+ replaceUsages(node, with);
+
+ _nodes.erase(std::remove_if(_nodes.begin(), _nodes.end(), [node] (INode::Ref n) {
+ return n == node;
+ }), _nodes.end());
+}
+
+Node<ops::VariableOp>* Graph::replaceWithInputNode(const INode* node) {
+ auto in = create<ops::VariableOp>(node->getName());
+ assert(node->getOperation()->getNumOutputs() <= 1
+ && "Only operations with single output value can be replaced with input node");
+ assert(node->getNextNodes().size() <= 1
+ && "Node with multiple outputs cannot be changed into input");
+
+ replaceNode(node, in);
+
+ //replaceNode adds all connections of original node,
+ //but for input node we don't need input connections
+ //
+ //cast is safe since we know graph creates only AbstractNode(s)
+ static_cast<AbstractNode*>(in)->getMutablePrevNodes().clear();
+
+ delete node;
+
+ return static_cast<Node<ops::VariableOp>*>(in);
+}
+
+void Graph::replaceInputNodes(const std::vector<std::string>& new_inputs) {
+ std::vector<INode::Ref> nodes_to_replace;
+
+ std::set<std::string> new_input_set(new_inputs.begin(), new_inputs.end());
+
+ for (auto& n : _nodes) {
+ if (new_input_set.count(n->getName()) != 0) {
+ nodes_to_replace.push_back(n);
+ }
+ }
+
+ _inputs.clear();
+
+ for (auto& n : nodes_to_replace) {
+ replaceWithInputNode(n);
+ }
+}
+
+void Graph::replaceOutputNodes(const std::vector<std::string>& new_outputs) {
+ _outputs.clear();
+
+ std::set<std::string> new_outputs_set(new_outputs.begin(), new_outputs.end());
+
+ for (auto& n : _nodes) {
+ if (new_outputs_set.count(n->getName()) != 0) {
+ markOutput(n);
+ }
+ }
+}
+
} // namespace mir
} // namespace nnc
* limitations under the License.
*/
+#include <core/modelIR/ir_node.h>
+
#include "core/modelIR/ir_node.h"
namespace nnc
_inputs.resize(num_inputs);
}
+std::vector<INode::IODescriptor>& AbstractNode::getMutablePrevNodes() {
+ return _inputs;
+}
+
+std::vector<INode::Ref>& AbstractNode::getMutableNextNodes() {
+ return _outputs;
+}
+
} // namespace mir
} // namespace nnc
* @returns vector containing all graph outputs nodes
*/
std::vector<INode::Ref> collectOutputs();
+
+
+ /**
+ * @brief Subsitude node in graph with another keeping all edges
+ * @param node Node to subsitude
+ * @param with Node to place instead
+ */
+ void replaceNode(const INode* node, INode* with);
+
+ /**
+ * @brief Replaces referenced node with input(VariableOp) node
+ * @param node Node to replace
+ * @return Input node which is placed in graph instead of passed node
+ * @warning deletes passed node
+ */
+ Node<ops::VariableOp>* replaceWithInputNode(const INode* node);
+
+ /**
+ * @brief Change graph inputs to nodes with names in newInputs
+ * @param new_inputs names of nodes to be made into input nodes
+ * @warning Input node order is not preserved and may differ from newInputs vector
+ */
+ void replaceInputNodes(const std::vector<std::string>& new_inputs);
+
+ /**
+ * @brief Change graph outputs to nodes with names in newOutputs
+ * @param new_outputs names of nodes to be marked as output nodes
+ * @warning Output node order is not preserved and may differ from newOutputs vector
+ * @note Does essentially the same as markOutput() does, but takes node names
+ */
+ void replaceOutputNodes(const std::vector<std::string>& new_outputs);
+
private:
void registerNode(INode::Ref node) {
_nodes.push_back(node);
virtual size_t getId() const = 0;
- virtual OpDescription *getOperation() = 0;
+ virtual OpDescription *getOperation() const = 0;
virtual const std::string &getName() const = 0;
virtual void setName(const std::string &name) = 0;
void connectInputTo(const int inputIndex, const IODescriptor &descriptor) override;
const IODescriptor getOutput(const size_t index) override;
+ std::vector<IODescriptor>& getMutablePrevNodes();
+ std::vector<INode::Ref>& getMutableNextNodes();
+
protected:
virtual void addNextNode(INode::Ref const node) override;
class Node : public AbstractNode
{
public:
- OpType *getOperation() override { return static_cast<OpType*>(_props.op); }
+ OpType *getOperation() const override { return static_cast<OpType*>(_props.op); }
template <typename... Args>
static Node<OpType> *createNode(const std::string &nodeName, size_t id, Args &&... args)
"ShapeIndex.cpp"
"ShapeInference.cpp"
"ShapeRange.cpp"
- "TensorVariant.cpp")
+ "TensorVariant.cpp"
+ "NodeReplacer.cpp"
+ "Graph.cpp")
nncc_find_package(Protobuf QUIET)
if(Protobuf_FOUND)
--- /dev/null
+#include <gtest/gtest.h>
+
+#include "core/modelIR/graph.h"
+#include "core/modelIR/operations/relu_op.h"
+
+#include "core/modelIR/operations/concat_op.h"
+
+namespace {
+
+using namespace nnc;
+using namespace nnc::mir;
+
+class DumpVisitor : public Visitor {
+public:
+ DumpVisitor(std::ostream& s) : _s(s) {}
+
+ void visit(INode* node, ops::VariableOp& op) override {
+ _s << "i" << node->getName();
+ };
+
+ void visit(INode* node, ops::ReluOp& op) override {
+ _s << "r" << node->getName();
+ }
+
+ void visit(INode* node, ops::ConcatOp& op) override {
+ _s << "c" << node->getName();
+ }
+
+ std::ostream& _s;
+};
+
+TEST(Graph, ReplaceInputs) {
+ auto g = new Graph;
+
+ auto n1 = g->create<ops::VariableOp>("op1");
+ auto n2 = g->create<ops::ReluOp>("op2");
+ auto n3 = g->create<ops::ReluOp>("op3");
+ auto n4 = g->create<ops::ReluOp>("op4");
+ auto n5 = g->create<ops::ConcatOp>("op5", 2, 0);
+
+ n2->connectInputTo(0, n1->getOutput(0));
+ n3->connectInputTo(0, n2->getOutput(0));
+ n4->connectInputTo(0, n2->getOutput(0));
+
+ n5->connectInputTo(0, n3->getOutput(0));
+ n5->connectInputTo(1, n4->getOutput(0));
+
+ g->replaceInputNodes({"op1", "op4"});
+
+ std::stringstream ss;
+ DumpVisitor d(ss);
+ g->accept(&d);
+
+ auto str = ss.str();
+ ASSERT_EQ(str, "iop4iop1rop2rop3cop5");
+ delete g;
+};
+
+TEST(Graph, ReplaceOutputs) {
+ //There is not much to test here as Graph::replaceOutputNodes simply calls Graph::markOutput
+ // multiple times ( Graph::markOutput just places passed node into Graph::_outputs map )
+
+ auto g = new Graph;
+
+ auto n1 = g->create<ops::VariableOp>("op1");
+ auto n2 = g->create<ops::ReluOp>("op2");
+ auto n3 = g->create<ops::ReluOp>("op3");
+ auto n4 = g->create<ops::ReluOp>("op4");
+ auto n5 = g->create<ops::ConcatOp>("op5", 2, 0);
+
+ n2->connectInputTo(0, n1->getOutput(0));
+ n3->connectInputTo(0, n2->getOutput(0));
+ n4->connectInputTo(0, n2->getOutput(0));
+
+ n5->connectInputTo(0, n3->getOutput(0));
+ n5->connectInputTo(1, n4->getOutput(0));
+
+ g->replaceOutputNodes({"op3"});
+
+ std::vector<INode::Ref> expectedOutputs{n3};
+ ASSERT_EQ(g->collectOutputs(), expectedOutputs);
+ delete g;
+};
+
+TEST(Graph, ReplaceOutputNodeWithInput) {
+ auto g = new Graph;
+
+ auto n1 = g->create<ops::VariableOp>("op1");
+ auto n2 = g->create<ops::ReluOp>("op2");
+
+ n2->connectInputTo(0, n1->getOutput(0));
+
+ g->markOutput(n2);
+
+ auto in2 = g->replaceWithInputNode(n2);
+
+ std::vector<INode::Ref> expectedInputs{in2, n1};
+ ASSERT_EQ(g->collectInputs(), expectedInputs);
+ delete g;
+}
+
+}
--- /dev/null
+#include <gtest/gtest.h>
+
+#include "core/modelIR/graph.h"
+#include "core/modelIR/operations/relu_op.h"
+
+namespace {
+
+using namespace nnc;
+using namespace nnc::mir;
+
+class DumpVisitor : public Visitor {
+public:
+ DumpVisitor(std::ostream& s) : _s(s) {}
+
+ void visit(INode* node, ops::VariableOp& op) override {
+ _s << "i" << node->getName();
+ };
+
+ void visit(INode* node, ops::ReluOp& op) override {
+ _s << "r" << node->getName();
+ }
+
+ void visit(INode* node, ops::ConcatOp& op) override {
+ _s << "c" << node->getName();
+ }
+
+ std::ostream& _s;
+};
+
+TEST(NodeMutatorTest, SimpleChainTest) {
+ auto g = new Graph;
+ auto n1 = g->create<ops::VariableOp>("op1");
+ auto n2 = g->create<ops::ReluOp>("op2");
+ auto n3 = g->create<ops::ReluOp>("op3");
+ auto n4 = g->create<ops::ReluOp>("op4");
+ auto n5 = g->create<ops::ReluOp>("op5");
+
+ n2->connectInputTo(0, n1->getOutput(0));
+ n3->connectInputTo(0, n2->getOutput(0));
+ n4->connectInputTo(0, n2->getOutput(0));
+
+ g->replaceNode(n2, n5);
+ delete n2;
+
+ std::stringstream ss;
+ DumpVisitor d(ss);
+ g->accept(&d);
+
+ auto str = ss.str();
+ ASSERT_EQ(str, "iop1rop5rop3rop4");
+ delete g;
+}
+
+}