[coco] Introduce template-based Op visitor (#845)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 31 Jul 2018 23:56:03 +0000 (08:56 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 31 Jul 2018 23:56:03 +0000 (08:56 +0900)
This commit introduces Op::Vistor class which allows users to analyze
the actual content of each operator.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/coco/core/include/coco/IR/Op.h
contrib/coco/core/src/IR/Op.test.cpp

index 2ab6ae7..630317c 100644 (file)
@@ -1,6 +1,8 @@
 #ifndef __COCO_IR_OP_H__
 #define __COCO_IR_OP_H__
 
+#include <stdexcept>
+
 namespace coco
 {
 
@@ -15,6 +17,33 @@ struct Op
 
   virtual Conv2D *asConv2D(void) { return nullptr; }
   virtual const Conv2D *asConv2D(void) const { return nullptr; }
+
+  template<typename T> struct Visitor
+  {
+    virtual ~Visitor() = default;
+
+    virtual T visit(const Conv2D *) = 0;
+  };
+
+  template<typename T> struct DefaultVisitor : public Visitor<T>
+  {
+    virtual ~DefaultVisitor() = default;
+
+    virtual T visit(const Conv2D *) override { throw std::runtime_error{"NYI"}; }
+  };
+
+  template<typename T> T accept(Visitor<T> *v)
+  {
+    if (auto op = asConv2D())
+    {
+      return v->visit(op);
+    }
+
+    throw std::runtime_error{"unreachable"};
+  }
+
+  template<typename T> T accept(Visitor<T> &v) { return accept(&v); }
+  template<typename T> T accept(Visitor<T> &&v) { return accept(&v); }
 };
 
 } // namespace coco
index ab68ef8..eee3398 100644 (file)
@@ -23,6 +23,17 @@ TEST(IR_OP_CONV2D, asConv2D)
   ASSERT_EQ(mutable_base->asConv2D(), immutable_base->asConv2D());
 }
 
+//
+// Conv2D
+//
+namespace
+{
+struct IsConv2D : public coco::Op::DefaultVisitor<bool>
+{
+  bool visit(const coco::Conv2D *) override { return true; }
+};
+} // namespace
+
 TEST(IR_OP_CONV2D, ker_update)
 {
   // Prepare a kernel object for testing
@@ -38,3 +49,12 @@ TEST(IR_OP_CONV2D, ker_update)
   op.ker(obj);
   ASSERT_EQ(op.ker(), obj);
 }
+
+TEST(IR_OP_CONV2D, accept)
+{
+  // Test 'Conv2D' class
+  const coco::Conv2D::Param param{};
+  coco::Conv2D op{param};
+
+  ASSERT_TRUE(op.accept(IsConv2D{}));
+}