[nnc] Impement node replacement facility in Graph class (#1951)
authorVladimir Plazun/AI Tools Lab /SRR/Engineer/삼성전자 <v.plazun@partner.samsung.com>
Thu, 25 Oct 2018 17:51:17 +0000 (20:51 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Thu, 25 Oct 2018 17:51:17 +0000 (20:51 +0300)
Provides replaceNode which is used to replace a node in graph preserving all edges
Provides replaceNodeWithInput and replaceInputNodes methods, used to replace any node with input node
Provides replaceOutputNodes,  used to replace graph output nodes with new list

Adds unit tests for all implementations

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
contrib/nnc/core/modelIR/graph.cpp
contrib/nnc/core/modelIR/ir_node.cpp
contrib/nnc/include/core/modelIR/graph.h
contrib/nnc/include/core/modelIR/ir_node.h
contrib/nnc/unittests/core/CMakeLists.txt
contrib/nnc/unittests/core/Graph.cpp [new file with mode: 0644]
contrib/nnc/unittests/core/NodeReplacer.cpp [new file with mode: 0644]

index 6d9ee69..e785ea5 100644 (file)
 
 #include <deque>
 #include <set>
+#include <algorithm>
 
 #include "core/modelIR/graph.h"
 
 namespace nnc {
 namespace mir {
 
+/**
+ * @brief replace all usages of node `node` with node `with`
+ * (i.e. all references in previous/next nodes )
+ * @param inode a node to replace
+ * @param with a node to use as a replacement
+ */
+static void replaceUsages(const INode* node, INode* with) {
+  auto with_node = dynamic_cast<AbstractNode*>(with);
+  assert(with_node);
+
+  //For each output replace prev references to `node` by `with`
+  for (auto out : node->getNextNodes()) {
+    auto anode = dynamic_cast<AbstractNode*>(out);
+    assert(anode && "Unexpected node type");
+
+    for (auto& prev : anode->getMutablePrevNodes()) {
+      if (prev.node == node)
+        prev.node = with;
+    }
+  }
+
+  with_node->getMutableNextNodes() = node->getNextNodes();
+
+  //For each input replace next references to `node` by `with`
+  for (auto& in : node->getPrevNodes()) {
+    auto anode = dynamic_cast<AbstractNode*>(in.node);
+    assert(anode && "Unexpected node type");
+
+    for (auto& next : anode->getMutableNextNodes()) {
+      if (next == node)
+        next = with;
+    }
+  }
+
+  with_node->getMutablePrevNodes() = node->getPrevNodes();
+}
+
 INode::Ref Graph::getInput(const std::string& name) {
   auto it = _inputs.find(name);
   if (it == _inputs.end())
@@ -101,5 +139,73 @@ std::vector<INode::Ref> Graph::collectOutputs() {
   return res;
 }
 
+void Graph::replaceNode(const INode* node, INode* with) {
+  auto in = _inputs.find(node->getName());
+  if (in != _inputs.end()) {
+    (*in).second = with;
+  }
+
+  auto out_it = _outputs.find(node->getName());
+  if (out_it != _outputs.end()) {
+    (*out_it).second = with;
+  }
+
+  replaceUsages(node, with);
+
+  _nodes.erase(std::remove_if(_nodes.begin(), _nodes.end(), [node] (INode::Ref n) {
+    return n == node;
+  }), _nodes.end());
+}
+
+Node<ops::VariableOp>* Graph::replaceWithInputNode(const INode* node) {
+  auto in = create<ops::VariableOp>(node->getName());
+  assert(node->getOperation()->getNumOutputs() <= 1
+         && "Only operations with single output value can be replaced with input node");
+  assert(node->getNextNodes().size() <= 1
+         && "Node with multiple outputs cannot be changed into input");
+
+  replaceNode(node, in);
+
+  //replaceNode adds all connections of original node,
+  //but for input node we don't need input connections
+  //
+  //cast is safe since we know graph creates only AbstractNode(s)
+  static_cast<AbstractNode*>(in)->getMutablePrevNodes().clear();
+
+  delete node;
+
+  return static_cast<Node<ops::VariableOp>*>(in);
+}
+
+void Graph::replaceInputNodes(const std::vector<std::string>& new_inputs) {
+  std::vector<INode::Ref> nodes_to_replace;
+
+  std::set<std::string> new_input_set(new_inputs.begin(), new_inputs.end());
+
+  for (auto& n : _nodes) {
+    if (new_input_set.count(n->getName()) != 0) {
+      nodes_to_replace.push_back(n);
+    }
+  }
+
+  _inputs.clear();
+
+  for (auto& n : nodes_to_replace) {
+    replaceWithInputNode(n);
+  }
+}
+
+void Graph::replaceOutputNodes(const std::vector<std::string>& new_outputs) {
+  _outputs.clear();
+
+  std::set<std::string> new_outputs_set(new_outputs.begin(), new_outputs.end());
+
+  for (auto& n : _nodes) {
+    if (new_outputs_set.count(n->getName()) != 0) {
+      markOutput(n);
+    }
+  }
+}
+
 } // namespace mir
 } // namespace nnc
index 13fa93d..98fa2a5 100644 (file)
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 
+#include <core/modelIR/ir_node.h>
+
 #include "core/modelIR/ir_node.h"
 
 namespace nnc
@@ -47,5 +49,13 @@ AbstractNode::AbstractNode(size_t num_inputs) {
   _inputs.resize(num_inputs);
 }
 
+std::vector<INode::IODescriptor>& AbstractNode::getMutablePrevNodes() {
+  return _inputs;
+}
+
+std::vector<INode::Ref>& AbstractNode::getMutableNextNodes() {
+  return _outputs;
+}
+
 } // namespace mir
 } // namespace nnc
index 1dc143f..15d71ff 100644 (file)
@@ -63,6 +63,38 @@ class Graph {
    * @returns vector containing all graph outputs nodes
    */
   std::vector<INode::Ref> collectOutputs();
+
+
+  /**
+   * @brief Subsitude node in graph with another keeping all edges
+   * @param node Node to subsitude
+   * @param with Node to place instead
+   */
+  void replaceNode(const INode* node, INode* with);
+
+  /**
+   * @brief Replaces referenced node with input(VariableOp) node
+   * @param node Node to replace
+   * @return Input node which is placed in graph instead of passed node
+   * @warning deletes passed node
+   */
+  Node<ops::VariableOp>* replaceWithInputNode(const INode* node);
+
+  /**
+   * @brief Change graph inputs to nodes with names in newInputs
+   * @param new_inputs names of nodes to be made into input nodes
+   * @warning Input node order is not preserved and may differ from newInputs vector
+   */
+  void replaceInputNodes(const std::vector<std::string>& new_inputs);
+
+  /**
+   * @brief Change graph outputs to nodes with names in newOutputs
+   * @param new_outputs names of nodes to be marked as output nodes
+   * @warning Output node order is not preserved and may differ from newOutputs vector
+   * @note Does essentially the same as markOutput() does, but takes node names
+   */
+  void replaceOutputNodes(const std::vector<std::string>& new_outputs);
+
  private:
   void registerNode(INode::Ref node) {
     _nodes.push_back(node);
index 137fbfd..12a4ddc 100644 (file)
@@ -45,7 +45,7 @@ public:
 
   virtual size_t getId() const = 0;
 
-  virtual OpDescription *getOperation() = 0;
+  virtual OpDescription *getOperation() const = 0;
 
   virtual const std::string &getName() const = 0;
   virtual void setName(const std::string &name) = 0;
@@ -70,6 +70,9 @@ public:
   void connectInputTo(const int inputIndex, const IODescriptor &descriptor) override;
   const IODescriptor getOutput(const size_t index) override;
 
+  std::vector<IODescriptor>& getMutablePrevNodes();
+  std::vector<INode::Ref>& getMutableNextNodes();
+
  protected:
   virtual void addNextNode(INode::Ref const node) override;
 
@@ -100,7 +103,7 @@ template <typename OpType>
 class Node : public AbstractNode
 {
 public:
-  OpType *getOperation() override { return static_cast<OpType*>(_props.op); }
+  OpType *getOperation() const override { return static_cast<OpType*>(_props.op); }
 
   template <typename... Args>
   static Node<OpType> *createNode(const std::string &nodeName, size_t id, Args &&... args)
index 87806b4..fa3af8a 100644 (file)
@@ -3,7 +3,9 @@ set(TESTS "ir_node.cpp"
           "ShapeIndex.cpp"
           "ShapeInference.cpp"
           "ShapeRange.cpp"
-          "TensorVariant.cpp")
+          "TensorVariant.cpp"
+          "NodeReplacer.cpp"
+          "Graph.cpp")
 
 nncc_find_package(Protobuf QUIET)
 if(Protobuf_FOUND)
diff --git a/contrib/nnc/unittests/core/Graph.cpp b/contrib/nnc/unittests/core/Graph.cpp
new file mode 100644 (file)
index 0000000..7d9dc23
--- /dev/null
@@ -0,0 +1,102 @@
+#include <gtest/gtest.h>
+
+#include "core/modelIR/graph.h"
+#include "core/modelIR/operations/relu_op.h"
+
+#include "core/modelIR/operations/concat_op.h"
+
+namespace {
+
+using namespace nnc;
+using namespace nnc::mir;
+
+class DumpVisitor : public Visitor {
+public:
+  DumpVisitor(std::ostream& s) : _s(s) {}
+
+  void visit(INode* node, ops::VariableOp& op) override {
+    _s << "i" << node->getName();
+  };
+
+  void visit(INode* node, ops::ReluOp& op) override {
+    _s << "r" << node->getName();
+  }
+
+  void visit(INode* node, ops::ConcatOp& op) override {
+    _s << "c" << node->getName();
+  }
+
+  std::ostream& _s;
+};
+
+TEST(Graph, ReplaceInputs) {
+  auto g = new Graph;
+
+  auto n1 = g->create<ops::VariableOp>("op1");
+  auto n2 = g->create<ops::ReluOp>("op2");
+  auto n3 = g->create<ops::ReluOp>("op3");
+  auto n4 = g->create<ops::ReluOp>("op4");
+  auto n5 = g->create<ops::ConcatOp>("op5", 2, 0);
+
+  n2->connectInputTo(0, n1->getOutput(0));
+  n3->connectInputTo(0, n2->getOutput(0));
+  n4->connectInputTo(0, n2->getOutput(0));
+
+  n5->connectInputTo(0, n3->getOutput(0));
+  n5->connectInputTo(1, n4->getOutput(0));
+
+  g->replaceInputNodes({"op1", "op4"});
+
+  std::stringstream ss;
+  DumpVisitor d(ss);
+  g->accept(&d);
+
+  auto str = ss.str();
+  ASSERT_EQ(str, "iop4iop1rop2rop3cop5");
+  delete g;
+};
+
+TEST(Graph, ReplaceOutputs) {
+  //There is not much to test here as Graph::replaceOutputNodes simply calls Graph::markOutput
+  // multiple times ( Graph::markOutput just places passed node into Graph::_outputs map )
+
+  auto g = new Graph;
+
+  auto n1 = g->create<ops::VariableOp>("op1");
+  auto n2 = g->create<ops::ReluOp>("op2");
+  auto n3 = g->create<ops::ReluOp>("op3");
+  auto n4 = g->create<ops::ReluOp>("op4");
+  auto n5 = g->create<ops::ConcatOp>("op5", 2, 0);
+
+  n2->connectInputTo(0, n1->getOutput(0));
+  n3->connectInputTo(0, n2->getOutput(0));
+  n4->connectInputTo(0, n2->getOutput(0));
+
+  n5->connectInputTo(0, n3->getOutput(0));
+  n5->connectInputTo(1, n4->getOutput(0));
+
+  g->replaceOutputNodes({"op3"});
+
+  std::vector<INode::Ref> expectedOutputs{n3};
+  ASSERT_EQ(g->collectOutputs(), expectedOutputs);
+  delete g;
+};
+
+TEST(Graph, ReplaceOutputNodeWithInput) {
+  auto g = new Graph;
+
+  auto n1 = g->create<ops::VariableOp>("op1");
+  auto n2 = g->create<ops::ReluOp>("op2");
+
+  n2->connectInputTo(0, n1->getOutput(0));
+
+  g->markOutput(n2);
+
+  auto in2 = g->replaceWithInputNode(n2);
+
+  std::vector<INode::Ref> expectedInputs{in2, n1};
+  ASSERT_EQ(g->collectInputs(), expectedInputs);
+  delete g;
+}
+
+}
diff --git a/contrib/nnc/unittests/core/NodeReplacer.cpp b/contrib/nnc/unittests/core/NodeReplacer.cpp
new file mode 100644 (file)
index 0000000..fbc2014
--- /dev/null
@@ -0,0 +1,54 @@
+#include <gtest/gtest.h>
+
+#include "core/modelIR/graph.h"
+#include "core/modelIR/operations/relu_op.h"
+
+namespace {
+
+using namespace nnc;
+using namespace nnc::mir;
+
+class DumpVisitor : public Visitor {
+public:
+  DumpVisitor(std::ostream& s) : _s(s) {}
+
+  void visit(INode* node, ops::VariableOp& op) override {
+    _s << "i" << node->getName();
+  };
+
+  void visit(INode* node, ops::ReluOp& op) override {
+    _s << "r" << node->getName();
+  }
+
+  void visit(INode* node, ops::ConcatOp& op) override {
+    _s << "c" << node->getName();
+  }
+
+  std::ostream& _s;
+};
+
+TEST(NodeMutatorTest, SimpleChainTest) {
+  auto g = new Graph;
+  auto n1 = g->create<ops::VariableOp>("op1");
+  auto n2 = g->create<ops::ReluOp>("op2");
+  auto n3 = g->create<ops::ReluOp>("op3");
+  auto n4 = g->create<ops::ReluOp>("op4");
+  auto n5 = g->create<ops::ReluOp>("op5");
+
+  n2->connectInputTo(0, n1->getOutput(0));
+  n3->connectInputTo(0, n2->getOutput(0));
+  n4->connectInputTo(0, n2->getOutput(0));
+
+  g->replaceNode(n2, n5);
+  delete n2;
+
+  std::stringstream ss;
+  DumpVisitor d(ss);
+  g->accept(&d);
+
+  auto str = ss.str();
+  ASSERT_EQ(str, "iop1rop5rop3rop4");
+  delete g;
+}
+
+}