[Fix] Add missing expr visitor for any (#6082)
authorHaichen Shen <shenhaichen@gmail.com>
Fri, 17 Jul 2020 14:48:38 +0000 (07:48 -0700)
committerGitHub <noreply@github.com>
Fri, 17 Jul 2020 14:48:38 +0000 (07:48 -0700)
include/tvm/tir/expr_functor.h
src/tir/ir/expr_functor.cc

index a6c90b3..3f73d21 100644 (file)
@@ -150,6 +150,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
   virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const AnyNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExprDefault_(const Object* op, Args...) {
     LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
     return R();
@@ -194,6 +195,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
     IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
     IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
     IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
+    IR_EXPR_FUNCTOR_DISPATCH(AnyNode);
     return vtable;
   }
 };
@@ -245,6 +247,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
   void VisitExpr_(const IntImmNode* op) override;
   void VisitExpr_(const FloatImmNode* op) override;
   void VisitExpr_(const StringImmNode* op) override;
+  void VisitExpr_(const AnyNode* op) override;
 };
 
 /*!
@@ -291,6 +294,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
   PrimExpr VisitExpr_(const IntImmNode* op) override;
   PrimExpr VisitExpr_(const FloatImmNode* op) override;
   PrimExpr VisitExpr_(const StringImmNode* op) override;
+  PrimExpr VisitExpr_(const AnyNode* op) override;
 };
 
 }  // namespace tir
index 0118228..166f950 100644 (file)
@@ -32,6 +32,8 @@ void ExprVisitor::VisitExpr_(const SizeVarNode* op) {
   this->VisitExpr_(static_cast<const VarNode*>(op));
 }
 
+void ExprVisitor::VisitExpr_(const AnyNode* op) {}
+
 void ExprVisitor::VisitExpr_(const LoadNode* op) {
   this->VisitExpr(op->index);
   this->VisitExpr(op->predicate);
@@ -119,6 +121,8 @@ PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) {
   return this->VisitExpr_(static_cast<const VarNode*>(op));
 }
 
+PrimExpr ExprMutator::VisitExpr_(const AnyNode* op) { return GetRef<PrimExpr>(op); }
+
 PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {
   PrimExpr index = this->VisitExpr(op->index);
   PrimExpr predicate = this->VisitExpr(op->predicate);