#ifndef __COCO_IR_OP_H__
#define __COCO_IR_OP_H__
+#include <stdexcept>
+
namespace coco
{
virtual Conv2D *asConv2D(void) { return nullptr; }
virtual const Conv2D *asConv2D(void) const { return nullptr; }
+
+ template<typename T> struct Visitor
+ {
+ virtual ~Visitor() = default;
+
+ virtual T visit(const Conv2D *) = 0;
+ };
+
+ template<typename T> struct DefaultVisitor : public Visitor<T>
+ {
+ virtual ~DefaultVisitor() = default;
+
+ virtual T visit(const Conv2D *) override { throw std::runtime_error{"NYI"}; }
+ };
+
+ template<typename T> T accept(Visitor<T> *v)
+ {
+ if (auto op = asConv2D())
+ {
+ return v->visit(op);
+ }
+
+ throw std::runtime_error{"unreachable"};
+ }
+
+ template<typename T> T accept(Visitor<T> &v) { return accept(&v); }
+ template<typename T> T accept(Visitor<T> &&v) { return accept(&v); }
};
} // namespace coco
ASSERT_EQ(mutable_base->asConv2D(), immutable_base->asConv2D());
}
+//
+// Conv2D
+//
+namespace
+{
+struct IsConv2D : public coco::Op::DefaultVisitor<bool>
+{
+ bool visit(const coco::Conv2D *) override { return true; }
+};
+} // namespace
+
TEST(IR_OP_CONV2D, ker_update)
{
// Prepare a kernel object for testing
op.ker(obj);
ASSERT_EQ(op.ker(), obj);
}
+
+TEST(IR_OP_CONV2D, accept)
+{
+ // Test 'Conv2D' class
+ const coco::Conv2D::Param param{};
+ coco::Conv2D op{param};
+
+ ASSERT_TRUE(op.accept(IsConv2D{}));
+}