[loco] Introduce mutable canonical node visitor (#3979)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 26 Jun 2019 06:05:01 +0000 (15:05 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 26 Jun 2019 06:05:01 +0000 (15:05 +0900)
This commit allows users to define a visitor that may update each node,
and use it.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/loco/include/loco/IR/CanonicalNode.h
contrib/loco/include/loco/IR/CanonicalNodeImpl.h
contrib/loco/include/loco/IR/CanonicalNodeVisitor.forward.h
contrib/loco/include/loco/IR/CanonicalNodeVisitor.h
contrib/loco/src/IR/CanonicalNode.test.cpp

index afd5f15..58e89c1 100644 (file)
@@ -33,6 +33,7 @@ struct CanonicalNode : public Node
   virtual CanonicalOpcode opcode(void) const = 0;
 
   template <typename T> T accept(CanonicalNodeVisitorBase<T> *) const;
+  template <typename T> T accept(CanonicalNodeMutableVisitorBase<T> *);
 };
 
 template <CanonicalOpcode Code> struct CanonicalNodeImpl : public CanonicalNode
index 02408b4..bb8c8b7 100644 (file)
@@ -42,6 +42,23 @@ template <typename T> T CanonicalNode::accept(CanonicalNodeVisitorBase<T> *v) co
   throw std::runtime_error{"NYI"};
 }
 
+template <typename T> T CanonicalNode::accept(CanonicalNodeMutableVisitorBase<T> *v)
+{
+  switch (this->opcode())
+  {
+#define CANONICAL_NODE(OPCODE, CLASS) \
+  case CanonicalOpcode::OPCODE:       \
+    return v->visit(dynamic_cast<CLASS *>(this));
+
+#include "CanonicalNodes.lst"
+#undef CANONICAL_NODE
+  default:
+    break;
+  }
+
+  throw std::runtime_error{"NYI"};
+}
+
 } // namespace loco
 
 #endif // __LOCO_IR_CANONICAL_NODE_IMPL_H__
index 50bb11f..425d779 100644 (file)
@@ -20,8 +20,9 @@
 namespace loco
 {
 
-// NOTE This forward declaration SHOULD BE aligned with Node delcaration in "CanonicalNodeVisitor.h"
+// NOTE These forward declarations SHOULD BE aligned with "CanonicalNodeVisitor.h"
 template <typename T> struct CanonicalNodeVisitorBase;
+template <typename T> struct CanonicalNodeMutableVisitorBase;
 
 } // namespace loco
 
index 5d1e35d..b9ffd54 100644 (file)
@@ -49,6 +49,31 @@ template <typename T> struct CanonicalNodeVisitor : public CanonicalNodeVisitorB
   virtual T visit(const Node *) { throw std::runtime_error{"Not implemented, yet"}; }
 };
 
+/**
+ * DO NOT use this class. Use CanonicalNodeMutableVisitor instead.
+ */
+template <typename T> struct CanonicalNodeMutableVisitorBase
+{
+  virtual ~CanonicalNodeMutableVisitorBase() = default;
+
+#define CANONICAL_NODE(OPCODE, CLASS) virtual T visit(CLASS *) = 0;
+#include "CanonicalNodes.lst"
+#undef CANONICAL_NODE
+};
+
+template <typename T> struct CanonicalNodeMutableVisitor : public CanonicalNodeMutableVisitorBase<T>
+{
+  virtual ~CanonicalNodeMutableVisitor() = default;
+
+#define CANONICAL_NODE(OPCODE, CLASS) \
+  virtual T visit(CLASS *node) { return visit(static_cast<Node *>(node)); }
+#include "CanonicalNodes.lst"
+#undef CANONICAL_NODE
+
+  /// @brief Default fallback
+  virtual T visit(Node *) { throw std::runtime_error{"Not implemented, yet"}; }
+};
+
 } // namespace loco
 
 #endif // __LOCO_IR_CANONICAL_NODE_VISITOR_H__
index e2ef063..6d0229a 100644 (file)
@@ -53,3 +53,21 @@ TEST(CanonicalNodeTest, visitor)
 
   ASSERT_EQ(node.accept(&v), 1);
 }
+
+TEST(CanonicalNodeTest, mutable_visitor)
+{
+  struct ResetForward final : public loco::CanonicalNodeMutableVisitor<void>
+  {
+    void visit(loco::Forward *node) final { node->input(nullptr); }
+  };
+
+  loco::Pull pull_node;
+  loco::Forward forward_node;
+
+  forward_node.input(&pull_node);
+
+  ResetForward v;
+  forward_node.accept(&v);
+
+  ASSERT_EQ(forward_node.input(), nullptr);
+}