[Relay] Reference (#2489)
author雾雨魔理沙 <lolisa@marisa.moe>
Fri, 15 Feb 2019 18:07:10 +0000 (10:07 -0800)
committerziheng <ziheng@apache.org>
Fri, 15 Feb 2019 18:07:10 +0000 (10:07 -0800)
* move

fix test

fix lint

fix test

add more code

fix lint

better type infer ability

* fix build

* address comment

25 files changed:
include/tvm/relay/expr.h
include/tvm/relay/expr_functor.h
include/tvm/relay/interpreter.h
include/tvm/relay/type.h
python/tvm/relay/__init__.py
python/tvm/relay/backend/graph_runtime_codegen.py
python/tvm/relay/backend/interpreter.py
python/tvm/relay/expr.py
python/tvm/relay/expr_functor.py
python/tvm/relay/ty.py
src/relay/backend/interpreter.cc
src/relay/ir/alpha_equal.cc
src/relay/ir/expr.cc
src/relay/ir/expr_functor.cc
src/relay/ir/hash.cc
src/relay/ir/text_printer.cc
src/relay/ir/type.cc
src/relay/ir/type_functor.cc
src/relay/ir/type_functor.h
src/relay/pass/fuse_ops.cc
src/relay/pass/kind_check.cc
src/relay/pass/type_infer.cc
src/relay/pass/type_solver.cc
tests/python/relay/test_backend_interpreter.py
tests/python/relay/test_type_infer.py

index 14b3cd91701c6203afcf3828b273f1dceff8ccd1..b9a57c5d4618fcc8d7edf6d24a0891b141ab9a34 100644 (file)
@@ -428,12 +428,78 @@ class TupleGetItemNode : public ExprNode {
 
   TVM_DLL static TupleGetItem make(Expr tuple, int index);
 
-  static constexpr const char * _type_key = "relay.TupleGetItem";
+  static constexpr const char* _type_key = "relay.TupleGetItem";
   TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
 };
 
 RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
 
+/*! \brief Create a new Reference out of initial value. */
+class RefCreate;
+class RefCreateNode : public ExprNode {
+ public:
+  /*! \brief The initial value of the Reference. */
+  Expr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("value", &value);
+    v->Visit("span", &span);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  TVM_DLL static RefCreate make(Expr value);
+
+  static constexpr const char* _type_key = "relay.RefCreate";
+  TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr);
+
+/*! \brief Get value out of Reference. */
+class RefRead;
+class RefReadNode : public ExprNode {
+ public:
+  /*! \brief The Reference Expression. */
+  Expr ref;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("ref", &ref);
+    v->Visit("span", &span);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  TVM_DLL static RefRead make(Expr ref);
+
+  static constexpr const char* _type_key = "relay.RefRead";
+  TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr);
+
+/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
+class RefWrite;
+class RefWriteNode : public ExprNode {
+ public:
+  /*! \brief The Reference Expression. */
+  Expr ref;
+  /*! \brief The value to write into. */
+  Expr value;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("ref", &ref);
+    v->Visit("value", &value);
+    v->Visit("span", &span);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  TVM_DLL static RefWrite make(Expr ref, Expr value);
+
+  static constexpr const char* _type_key = "relay.RefWrite";
+  TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);
+
 /*!
  * \brief Base class of the temporary expression.
  *
index 60b18218a313109ec2707e7292d4c864df3f6e97..e7b66bc1bbde0fa08c4dcf20fe1f3e423ab02863 100644 (file)
@@ -89,6 +89,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
   virtual R VisitExpr_(const OpNode* op,
                        Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExprDefault_(const Node* op, Args...) {
     throw Error(std::string("Do not have a default for ") + op->type_key());
   }
@@ -108,6 +111,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
     RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
     RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
     RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
     return vtable;
   }
 };
@@ -133,6 +139,9 @@ class ExprVisitor
   void VisitExpr_(const IfNode* op) override;
   void VisitExpr_(const OpNode* op) override;
   void VisitExpr_(const TupleGetItemNode* op) override;
+  void VisitExpr_(const RefCreateNode* op) override;
+  void VisitExpr_(const RefReadNode* op) override;
+  void VisitExpr_(const RefWriteNode* op) override;
   virtual void VisitType(const Type& t);
 
  protected:
@@ -168,6 +177,9 @@ class ExprMutator
   Expr VisitExpr_(const LetNode* op) override;
   Expr VisitExpr_(const IfNode* op) override;
   Expr VisitExpr_(const TupleGetItemNode* op) override;
+  Expr VisitExpr_(const RefCreateNode* op) override;
+  Expr VisitExpr_(const RefReadNode* op) override;
+  Expr VisitExpr_(const RefWriteNode* op) override;
   /*!
    * \brief Used to visit the types inside of expressions.
    *
index 1099ef0f3cfdd510bd0b55e0948db6a7eacc1a26..08aeef1827b6838bfcc850ca7899efb55ea0325c 100644 (file)
@@ -140,6 +140,25 @@ struct TensorValueNode : ValueNode {
 
 RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);
 
+/*! \brief A reference value. */
+class RefValue;
+
+struct RefValueNode : ValueNode {
+  mutable Value value;
+
+  RefValueNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("value", &value);
+  }
+
+  TVM_DLL static RefValue make(Value val);
+
+  static constexpr const char* _type_key = "relay.RefValue";
+  TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode);
+};
+
+RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
 
 }  // namespace relay
 }  // namespace tvm
index f3bcf2c0a1d9f7e0a504e32eb0c04ccd7d214b18..0ee265e5f3b0e706a5c00c762a339a4dfdce2089 100644 (file)
@@ -262,6 +262,33 @@ class TupleTypeNode : public TypeNode {
 
 RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
 
+/*!
+ * \brief The type of reference values.
+ */
+class RefType;
+/*!
+ * \brief Reference Type in relay.
+ */
+class RefTypeNode : public TypeNode {
+ public:
+  /*! \brief The type of value in the Reference. */
+  Type value;
+
+  RefTypeNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("value", &value);
+    v->Visit("span", &span);
+  }
+
+  TVM_DLL static RefType make(Type value);
+
+  static constexpr const char* _type_key = "relay.RefType";
+  TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type);
+
 class TypeReporter;
 
 /*!
index b9d4695b70f8812673ff982b2a4f3b86c48477b8..0af164bc7a73e7943669311bc94b566943f22250 100644 (file)
@@ -44,6 +44,7 @@ FuncType = ty.FuncType
 TypeRelation = ty.TypeRelation
 IncompleteType = ty.IncompleteType
 scalar_type = ty.scalar_type
+RefType = ty.RefType
 
 # Expr
 Expr = expr.Expr
@@ -56,15 +57,18 @@ Call = expr.Call
 Let = expr.Let
 If = expr.If
 TupleGetItem = expr.TupleGetItem
-
-# ExprFunctor
-ExprFunctor = expr_functor.ExprFunctor
-ExprMutator = expr_functor.ExprMutator
+RefCreate = expr.RefCreate
+RefRead = expr.RefRead
+RefWrite = expr.RefWrite
 
 # helper functions
 var = expr.var
 const = expr.const
 bind = expr.bind
 
+# ExprFunctor
+ExprFunctor = expr_functor.ExprFunctor
+ExprMutator = expr_functor.ExprMutator
+
 # Parser
 fromtext = parser.fromtext
index 15e0a81226cb03eccdf1c2cc72912c676729d680..cc510b2290cfbc329a6e83f0caa5801734b7865c 100644 (file)
@@ -283,6 +283,15 @@ class GraphRuntimeCodegen(ExprFunctor):
     def visit_op(self, _):
         raise Exception("can not compile op in non-eta expanded form")
 
+    def visit_ref_create(self, _):
+        raise RuntimeError("reference not supported")
+
+    def visit_ref_read(self, _):
+        raise RuntimeError("reference not supported")
+
+    def visit_ref_write(self, _):
+        raise RuntimeError("reference not supported")
+
     def _get_json(self):
         """
         Convert the sequence of nodes stored by the compiler into the
index 4a5ddcd8270c5c535362665adba1e5774f7d66c2..b21eab185c28ceae709aee6399978468825ae6bd 100644 (file)
@@ -45,6 +45,7 @@ class TupleValue(Value):
     def __iter__(self):
         return iter(self.fields)
 
+
 @register_relay_node
 class Closure(Value):
     """A closure produced by the interpreter."""
@@ -79,6 +80,13 @@ class TensorValue(Value):
         return str(self.data)
 
 
+@register_relay_node
+class RefValue(Value):
+    def __init__(self, value):
+        self.__init_handle_by_constructor__(
+            _make.RefValue, value)
+
+
 def _arg_to_ast(arg):
     if isinstance(arg, TensorValue):
         return Constant(arg.data.copyto(_nd.cpu(0)))
index 38ab0064e671ac4a23c0e65c97a30b38510ff1b1..71b89d0b4777011d4e2a2e32083f5a3b1aa24def 100644 (file)
@@ -327,6 +327,46 @@ class TupleGetItem(Expr):
             _make.TupleGetItem, tuple_value, index)
 
 
+@register_relay_node
+class RefCreate(Expr):
+    """Create a new reference from initial value.
+    Parameters
+    ----------
+    value: tvm.relay.Expr
+       The initial value.
+    """
+    def __init__(self, value):
+        self.__init_handle_by_constructor__(_make.RefCreate, value)
+
+
+@register_relay_node
+class RefRead(Expr):
+    """Get the value inside the reference.
+    Parameters
+    ----------
+    ref: tvm.relay.Expr
+         The reference.
+    """
+    def __init__(self, ref):
+        self.__init_handle_by_constructor__(_make.RefRead, ref)
+
+
+@register_relay_node
+class RefWrite(Expr):
+    """
+    Update the value inside the reference.
+    The whole expression will evaluate to an empty tuple.
+    Parameters
+    ----------
+    ref: tvm.relay.Expr
+        The reference.
+    value: tvm.relay.Expr
+        The new value.
+    """
+    def __init__(self, ref, value):
+        self.__init_handle_by_constructor__(_make.RefWrite, ref, value)
+
+
 class TempExpr(Expr):
     """Baseclass of all TempExpr.
 
index eafe5f09309fffe2eee78c30362c71201e5cb5d4..b22a4e7562e2e607ff5a8a5c2c1b9b2b5b4d89f2 100644 (file)
@@ -41,6 +41,12 @@ class ExprFunctor:
             res = self.visit_constant(expr)
         elif isinstance(expr, Op):
             res = self.visit_op(expr)
+        elif isinstance(expr, RefCreate):
+            res = self.visit_ref_create(expr)
+        elif isinstance(expr, RefRead):
+            res = self.visit_ref_read(expr)
+        elif isinstance(expr, RefWrite):
+            res = self.visit_ref_write(expr)
         else:
             raise Exception("warning unhandled case: {0}".format(type(expr)))
 
@@ -81,6 +87,14 @@ class ExprFunctor:
     def visit_constant(self, _):
         raise NotImplementedError()
 
+    def visit_ref_create(self, _):
+        raise NotImplementedError()
+
+    def visit_ref_write(self, _):
+        raise NotImplementedError()
+
+    def visit_ref_read(self, _):
+        raise NotImplementedError()
 
 class ExprMutator(ExprFunctor):
     """
@@ -145,8 +159,8 @@ class ExprMutator(ExprFunctor):
     def visit_match(self, m):
         return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern])
 
-    def visit_ref_new(self, r):
-        return RefNew(self.visit(r.value))
+    def visit_ref_create(self, r):
+        return RefCreate(self.visit(r.value))
 
     def visit_ref_write(self, r):
         return RefWrite(self.visit(r.ref), self.visit(r.value))
index 96dde5acb4dfeb7e7943c604416831ff9d732dee..bed293d1e3caf903ff9ed98a528845f38b3fbeb9 100644 (file)
@@ -210,6 +210,19 @@ class TypeRelation(TypeConstraint):
                                             func, args, num_inputs, attrs)
 
 
+@register_relay_node
+class RefType(Type):
+    """Reference Type in relay.
+
+    Parameters
+    ----------
+    value: Type
+        The value type.
+    """
+    def __init__(self, value):
+        self.__init_handle_by_constructor__(_make.RefType, value)
+
+
 def scalar_type(dtype):
     """Creates a scalar type.
 
index 396ff907951d913e0a46afb3ef2d73f34c04d10b..893e66b41b426bb7907b2519e5d9690fd8bd9c97 100644 (file)
@@ -75,6 +75,23 @@ TVM_REGISTER_API("relay._make.TensorValue")
     *ret = TensorValueNode::make(data);
   });
 
+RefValue RefValueNode::make(Value value) {
+  NodePtr<RefValueNode> n = make_node<RefValueNode>();
+  n->value = value;
+  return RefValue(n);
+}
+
+TVM_REGISTER_API("relay._make.RefValue")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = RefValueNode::make(args[0]);
+  });
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<RefValueNode>([](const RefValueNode* node,
+                               tvm::IRPrinter* p) {
+                              p->stream << "RefValueNode(" << node->value << ")";
+                            });
+
 /*!
  * \brief A stack frame in the Relay interpreter.
  *
@@ -432,6 +449,31 @@ class Interpreter :
     }
   }
 
+  Value VisitExpr_(const RefWriteNode* op) final {
+    Value r = Eval(op->ref);
+    if (const RefValueNode* rv = r.as<RefValueNode>()) {
+      rv->value = Eval(op->value);
+      return TupleValueNode::make({});
+    } else {
+      LOG(FATAL) << "type error, type system should have caught this";
+      return Value();
+    }
+  }
+
+  Value VisitExpr_(const RefCreateNode* op) final {
+    return RefValueNode::make(Eval(op->value));
+  }
+
+  Value VisitExpr_(const RefReadNode* op) final {
+    Value r = Eval(op->ref);
+    if (const RefValueNode* rv = r.as<RefValueNode>()) {
+      return rv->value;
+    } else {
+      LOG(FATAL) << "type error, type system should have caught this";
+      return Value();
+    }
+  }
+
   InterpreterState get_state(Expr e = Expr()) const {
     InterpreterStateNode::Stack stack;
     for (auto fr : this->stack_.frames) {
index 064343c834ea078cbde9b33d83d245288241621a..d0cc004994d4d439ec8861e854563ed71b6e39b2 100644 (file)
@@ -207,6 +207,14 @@ class AlphaEqualHandler:
       return false;
     }
   }
+
+  bool VisitType_(const RefTypeNode* lhs, const Type& other) final {
+    if (const RefTypeNode* rhs = other.as<RefTypeNode>()) {
+      return TypeEqual(lhs->value, rhs->value);
+    }
+    return false;
+  }
+
   // Expr equal checking.
   bool NDArrayEqual(const runtime::NDArray& lhs,
                     const runtime::NDArray& rhs) {
@@ -361,6 +369,29 @@ class AlphaEqualHandler:
     }
   }
 
+  bool VisitExpr_(const RefCreateNode* op, const Expr& e2) final {
+    if (const RefCreateNode* nr = e2.as<RefCreateNode>()) {
+      return ExprEqual(op->value, nr->value);
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitExpr_(const RefReadNode* op, const Expr& e2) final {
+    if (const RefReadNode* r = e2.as<RefReadNode>()) {
+      return ExprEqual(op->ref, r->ref);
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitExpr_(const RefWriteNode* op, const Expr& e2) final {
+    if (const RefWriteNode* r = e2.as<RefWriteNode>()) {
+      return ExprEqual(op->ref, r->ref) && ExprEqual(op->value, r->value);
+    } else {
+      return false;
+    }
+  }
  private:
   // whether to map open terms.
   bool map_free_var_{false};
index cdb2a32a0009b7cf1ec83c047f426d9f1e5e992a..bc6eee3ebc03c30705b9312df56df931cbb8cd8c 100644 (file)
@@ -271,13 +271,59 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
   p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
 });
 
+RefCreate RefCreateNode::make(Expr value) {
+  NodePtr<RefCreateNode> n = make_node<RefCreateNode>();
+  n->value = std::move(value);
+  return RefCreate(n);
+}
 
-TVM_REGISTER_API("relay._expr.TempExprRealize")
+TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
+  *ret = RefCreateNode::make(args[0]);
+});
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) {
+  p->stream << "RefCreateNode(" << node->value << ")";
+});
+
+RefRead RefReadNode::make(Expr ref) {
+  NodePtr<RefReadNode> n = make_node<RefReadNode>();
+  n->ref = std::move(ref);
+  return RefRead(n);
+}
+
+TVM_REGISTER_API("relay._make.RefRead")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    TempExpr temp = args[0];
-    *ret = temp->Realize();
+  *ret = RefReadNode::make(args[0]);
+});
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<RefReadNode>([](const RefReadNode* node, tvm::IRPrinter* p) {
+  p->stream << "RefReadNode(" << node->ref << ")";
 });
 
+RefWrite RefWriteNode::make(Expr ref, Expr value) {
+  NodePtr<RefWriteNode> n = make_node<RefWriteNode>();
+  n->ref = std::move(ref);
+  n->value = std::move(value);
+  return RefWrite(n);
+}
+
+TVM_REGISTER_API("relay._make.RefWrite")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  *ret = RefWriteNode::make(args[0], args[1]);
+});
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) {
+  p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
+});
+
+TVM_REGISTER_API("relay._expr.TempExprRealize")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  TempExpr temp = args[0];
+  *ret = temp->Realize();
+});
 
 }  // namespace relay
 }  // namespace tvm
index e7b4a918c9842cad6ee0e71a63c8da9e0ef9c77e..9bdfa00ce298fd9fc5d085216a493d4eb0c24362 100644 (file)
@@ -157,6 +157,34 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
   }
 }
 
+Expr ExprMutator::VisitExpr_(const RefCreateNode* op) {
+  Expr value = this->Mutate(op->value);
+  if (value.same_as(op->value)) {
+    return GetRef<Expr>(op);
+  } else {
+    return RefCreateNode::make(value);
+  }
+}
+
+Expr ExprMutator::VisitExpr_(const RefReadNode* op) {
+  Expr ref = this->Mutate(op->ref);
+  if (ref.same_as(op->ref)) {
+    return GetRef<Expr>(op);
+  } else {
+    return RefReadNode::make(ref);
+  }
+}
+
+Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
+  Expr ref = this->Mutate(op->ref);
+  Expr value = this->Mutate(op->value);
+  if (ref.same_as(op->ref) && value.same_as(op->value)) {
+    return GetRef<Expr>(op);
+  } else {
+    return RefWriteNode::make(ref, value);
+  }
+}
+
 Type ExprMutator::VisitType(const Type& t) { return t; }
 
 void ExprVisitor::VisitExpr(const Expr& expr) {
@@ -226,6 +254,19 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
   this->VisitExpr(op->tuple);
 }
 
+void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) {
+  this->VisitExpr(op->value);
+}
+
+void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) {
+  this->VisitExpr(op->ref);
+}
+
+void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) {
+  this->VisitExpr(op->ref);
+  this->VisitExpr(op->value);
+}
+
 void ExprVisitor::VisitType(const Type& t) { return; }
 
 // visitor to implement apply
index d7a8df98fa3fb45c8a08b2dc6ec70a7c6c43e3ed..d984bb051e43d513cc41d3f4e130436560cac5ee 100644 (file)
@@ -16,9 +16,9 @@ namespace relay {
 
 // Hash handler for Relay.
 class RelayHashHandler:
-      public AttrsHashHandler,
-      public TypeFunctor<size_t(const Type&)>,
-      public ExprFunctor<size_t(const Expr&)> {
+    public AttrsHashHandler,
+    public TypeFunctor<size_t(const Type&)>,
+    public ExprFunctor<size_t(const Expr&)> {
  public:
   explicit RelayHashHandler() {}
 
@@ -175,6 +175,12 @@ class RelayHashHandler:
     return hash;
   }
 
+  size_t VisitType_(const RefTypeNode* rtn) final {
+    size_t hash = std::hash<std::string>()(RefTypeNode::_type_key);
+    hash = Combine(hash, TypeHash(rtn->value));
+    return hash;
+  }
+
   // Expr hashing.
   size_t NDArrayHash(const runtime::NDArray& array) {
     size_t hash = std::hash<uint8_t>()(array->dtype.code);
@@ -280,6 +286,24 @@ class RelayHashHandler:
     return hash;
   }
 
+  size_t VisitExpr_(const RefCreateNode* rn) final {
+    size_t hash = std::hash<std::string>()(RefCreateNode::_type_key);
+    hash = Combine(hash, ExprHash(rn->value));
+    return hash;
+  }
+
+  size_t VisitExpr_(const RefReadNode* rn) final {
+    size_t hash = std::hash<std::string>()(RefReadNode::_type_key);
+    hash = Combine(hash, ExprHash(rn->ref));
+    return hash;
+  }
+
+  size_t VisitExpr_(const RefWriteNode* rn) final {
+    size_t hash = std::hash<std::string>()(RefWriteNode::_type_key);
+    hash = Combine(hash, ExprHash(rn->ref));
+    hash = Combine(hash, ExprHash(rn->value));
+    return hash;
+  }
  private:
   // renaming of NodeRef to indicate two nodes equals to each other
   std::unordered_map<NodeRef, size_t, NodeHash, NodeEqual> hash_map_;
index 8f6629a14f922ce591dcd3aea69dd751e3b70cb9..05179d584d84f1f1590a531ba7a9571643d6cf94 100644 (file)
@@ -363,6 +363,34 @@ class TextPrinter :
     return id;
   }
 
+  TextValue VisitExpr_(const RefCreateNode* op) final {
+    TextValue value = GetValue(op->value);
+    TextValue id = this->AllocTempVar();
+    this->PrintIndent();
+    stream_ << id << " = " << "RefCreate(" << op->value << ")";
+    this->PrintEndInst("\n");
+    return id;
+  }
+
+  TextValue VisitExpr_(const RefReadNode* op) final {
+    TextValue ref = GetValue(op->ref);
+    TextValue id = this->AllocTempVar();
+    this->PrintIndent();
+    stream_ << id << " = " << "RefRead(" << ref << ")";
+    this->PrintEndInst("\n");
+    return id;
+  }
+
+  TextValue VisitExpr_(const RefWriteNode* op) final {
+    TextValue ref = GetValue(op->ref);
+    TextValue value = GetValue(op->value);
+    TextValue id = this->AllocTempVar();
+    this->PrintIndent();
+    stream_ << id << " = " << "RefWrite(" << ref << ", " << value << ")";
+    this->PrintEndInst("\n");
+    return id;
+  }
+
   /*!
    * \brief Print the type to os
    * \param type The type to be printed.
@@ -405,6 +433,10 @@ class TextPrinter :
     os << "]";
   }
 
+  void VisitType_(const RefTypeNode* node, std::ostream& os) final {
+    VisitTypeDefault_(node, os);
+  }
+
   void VisitTypeDefault_(const Node* node, std::ostream& os) final {  // NOLINT(*)
     // by default always print as meta-data
     os << meta_.GetMetaNode(GetRef<NodeRef>(node));
index bbe6472609dfdbb2971867127a1b63e3d73160b1..e829d8abd63c26f94901d142bc17f5da28315aec 100644 (file)
@@ -164,5 +164,24 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
   p->stream << "TupleTypeNode(" << node->fields << ")";
 });
 
+RefType RefTypeNode::make(Type value) {
+  NodePtr<RefTypeNode> n = make_node<RefTypeNode>();
+  n->value = std::move(value);
+  return RefType(n);
+}
+
+TVM_REGISTER_API("relay._make.RefType")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  *ret = RefTypeNode::make(args[0]);
+});
+
+TVM_REGISTER_NODE_TYPE(RefTypeNode);
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<RefTypeNode>([](const RefTypeNode* node,
+                              tvm::IRPrinter* p) {
+  p->stream << "RefTypeNode(" << node->value << ")";
+});
+
 }  // namespace relay
 }  // namespace tvm
index 0ef1743cbbc4ac066ab81601121d887d77064bd5..100c633a2997bc3b405af6417f79b0835027ad56 100644 (file)
@@ -38,6 +38,10 @@ void TypeVisitor::VisitType_(const TupleTypeNode* op) {
   }
 }
 
+void TypeVisitor::VisitType_(const RefTypeNode* op) {
+  this->VisitType(op->value);
+}
+
 void TypeVisitor::VisitType_(const TypeRelationNode* op) {
   for (const Type& t : op->args) {
     this->VisitType(t);
@@ -119,6 +123,10 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) {
   }
 }
 
+Type TypeMutator::VisitType_(const RefTypeNode* op) {
+  return RefTypeNode::make(this->VisitType(op->value));
+}
+
 Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
   Array<Type> new_args = MutateArray(type_rel->args);
   if (new_args.same_as(type_rel->args)) {
index e8dfd2b7cd7cdc2a6c498bd57781e149fa217537..1be55e78eee6428a37c7aca5747f76d8e6f9e06b 100644 (file)
@@ -68,7 +68,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
   virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
-
+  virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitTypeDefault_(const Node* op, Args...) {
     LOG(FATAL) << "Do not have a default for " << op->type_key();
     throw;  // unreachable, written to stop compiler warning
@@ -86,6 +86,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
     RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
     RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
     RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode);
     return vtable;
   }
 };
@@ -101,6 +102,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
   void VisitType_(const FuncTypeNode* op) override;
   void VisitType_(const TupleTypeNode* op) override;
   void VisitType_(const TypeRelationNode* op) override;
+  void VisitType_(const RefTypeNode* op) override;
 };
 
 // Mutator that transform a type to another one.
@@ -112,6 +114,7 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
   Type VisitType_(const FuncTypeNode* op) override;
   Type VisitType_(const TupleTypeNode* op) override;
   Type VisitType_(const TypeRelationNode* type_rel) override;
+  Type VisitType_(const RefTypeNode* op) override;
 
  private:
   Array<Type> MutateArray(Array<Type> arr);
index 572c62cfab10bec22ef4ad71abca1949c7ece725..a6298ba448f3e8359e7fe8c25fa5ad491b77ee45 100644 (file)
@@ -162,6 +162,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
       current->extern_ref = true;
     }
   }
+
   void AddNode(const tvm::Node* key) {
     auto it = graph_.node_map.find(key);
     CHECK(it != graph_.node_map.end())
@@ -174,7 +175,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   }
 
   // Post order tree
-  void VisitExpr_(const FunctionNode* op) {
+  void VisitExpr_(const FunctionNode* op) final {
     for (auto param : op->params) {
       this->Update(param, nullptr, kOpaque);
     }
@@ -182,7 +183,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     ExprVisitor::VisitExpr_(op);
   }
 
-  void VisitExpr_(const ConstantNode* op) {
+  void VisitExpr_(const ConstantNode* op) final {
     this->AddNode(op);
     Node* node = graph_.node_map.at(op);
     DataType dtype = TVMType2Type(op->data->dtype);
@@ -202,7 +203,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     }
   }
 
-  void VisitExpr_(const CallNode* call) {
+  void VisitExpr_(const CallNode* call) final {
     CHECK(graph_.node_map.count(call));
     Node* node = graph_.node_map.at(call);
     static auto fpattern =
@@ -232,7 +233,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     this->AddNode(call);
   }
 
-  void VisitExpr_(const TupleNode* op) {
+  void VisitExpr_(const TupleNode* op) final {
     CHECK(graph_.node_map.count(op));
     Node* tuple_node = graph_.node_map.at(op);
     tuple_node->pattern = kInjective;
@@ -247,7 +248,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     this->AddNode(op);
   }
 
-  void VisitExpr_(const TupleGetItemNode* op) {
+  void VisitExpr_(const TupleGetItemNode* op) final {
     CHECK(graph_.node_map.count(op));
     Node* node = graph_.node_map.at(op);
     this->Update(op->tuple, node, kOpaque);
@@ -255,11 +256,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     this->AddNode(op);
   }
 
-  void VisitExpr_(const VarNode* op) {
+  void VisitExpr_(const VarNode* op) final {
     this->AddNode(op);
   }
 
-  void VisitExpr_(const LetNode* op) {
+  void VisitExpr_(const LetNode* op) final {
     // do not fuse through let.
     this->Update(op->var, nullptr, kOpaque);
     this->Update(op->value, nullptr, kOpaque);
@@ -268,7 +269,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     this->AddNode(op);
   }
 
-  void VisitExpr_(const IfNode* op) {
+  void VisitExpr_(const IfNode* op) final {
     // do not fuse through if.
     this->Update(op->cond, nullptr, kOpaque);
     this->Update(op->true_branch, nullptr, kOpaque);
@@ -276,6 +277,25 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     ExprVisitor::VisitExpr_(op);
     this->AddNode(op);
   }
+
+  void VisitExpr_(const RefCreateNode* op) final {
+    this->Update(op->value, nullptr, kOpaque);
+    ExprVisitor::VisitExpr_(op);
+    this->AddNode(op);
+  }
+
+  void VisitExpr_(const RefReadNode* op) final {
+    this->Update(op->ref, nullptr, kOpaque);
+    ExprVisitor::VisitExpr_(op);
+    this->AddNode(op);
+  }
+
+  void VisitExpr_(const RefWriteNode* op) final {
+    this->Update(op->ref, nullptr, kOpaque);
+    this->Update(op->value, nullptr, kOpaque);
+    ExprVisitor::VisitExpr_(op);
+    this->AddNode(op);
+  }
 };
 
 IndexedForwardGraph IndexedForwardGraph::Create(
index 7253a600dabfbe3ad5f538b6b064028fac6d816e..200f5385a37a75e539100a12e572c6d164f0c17e 100644 (file)
@@ -82,6 +82,12 @@ struct KindChecker : TypeVisitor {
     valid = valid && IsTypeKind(op->ret_type);
   }
 
+  void VisitType_(const RefTypeNode* op) override {
+    // tuples should only contain normal types
+    this->VisitType(op->value);
+    valid = valid && IsTypeKind(op->value);
+  }
+
   void VisitType_(const TypeRelationNode* op) override {
     // arguments to type relation should be normal types
     for (const Type& t : op->args) {
index b17c1c1f04399f12640474e225e329fd09f6cc3e..10ba3b127bbf6c373b0b0983e48d030cc772d171 100644 (file)
@@ -431,6 +431,23 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {});
     return solver_.Resolve(ret);
   }
+
+  Type VisitExpr_(const RefCreateNode* op) final {
+    return RefTypeNode::make(GetType(op->value));
+  }
+
+  Type VisitExpr_(const RefReadNode* op) final {
+    Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
+    this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefRead>(op));
+    return it;
+  }
+
+  Type VisitExpr_(const RefWriteNode* op) final {
+    Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
+    this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefWrite>(op));
+    this->Unify(GetType(op->value), it, GetRef<RefWrite>(op));
+    return TupleTypeNode::make({});
+  }
 };
 
 class TypeInferencer::Resolver : public ExprMutator {
@@ -480,6 +497,18 @@ class TypeInferencer::Resolver : public ExprMutator {
     return AttachCheckedType(op);
   }
 
+  Expr VisitExpr_(const RefCreateNode* op) final {
+    return AttachCheckedType(op);
+  }
+
+  Expr VisitExpr_(const RefReadNode* op) final {
+    return AttachCheckedType(op);
+  }
+
+  Expr VisitExpr_(const RefWriteNode* op) final {
+    return AttachCheckedType(op);
+  }
+
   // attach checked type to the mutated node.
   template<typename T>
   Expr AttachCheckedType(const T* op) {
index 617aafdc712cd180b0b27c8d012323e6a420925a..fcd39e7913391b69dd431d37181c428f56f99ecd 100644 (file)
@@ -116,7 +116,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
   }
 
   // default: unify only if alpha-equal
-  Type VisitTypeDefault_(const Node* op, const Type& tn) override {
+  Type VisitTypeDefault_(const Node* op, const Type& tn) final {
     NodeRef nr = GetRef<NodeRef>(op);
     Type t1 = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
     if (!AlphaEqual(t1, tn)) {
@@ -125,7 +125,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     return t1;
   }
 
-  Type VisitType_(const TupleTypeNode* op, const Type& tn) override {
+  Type VisitType_(const TupleTypeNode* op, const Type& tn) final {
     const auto* ttn = tn.as<TupleTypeNode>();
     if (!ttn || op->fields.size() != ttn->fields.size()) {
       return Type(nullptr);
@@ -142,7 +142,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     return TupleTypeNode::make(new_fields);
   }
 
-  Type VisitType_(const FuncTypeNode* op, const Type& tn) override {
+  Type VisitType_(const FuncTypeNode* op, const Type& tn) final {
     const auto* ftn = tn.as<FuncTypeNode>();
     if (!ftn
         || op->arg_types.size() != ftn->arg_types.size()
@@ -181,6 +181,14 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints);
   }
 
+  Type VisitType_(const RefTypeNode* op, const Type& tn) final {
+    const auto* rtn = tn.as<RefTypeNode>();
+    if (!rtn) {
+      return Type(nullptr);
+    }
+    return RefTypeNode::make(Unify(op->value, rtn->value));
+  }
+
  private:
   TypeSolver* solver_;
 };
index 38c340db424df1dbbd745afb2a1efa28dc1a12cb..801b3068eff06979f0f0711dcde6082208ffe253 100644 (file)
@@ -110,6 +110,22 @@ def test_loop():
     check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod)
 
 
+def test_ref():
+    mod = relay.Module()
+    three_with_ref = relay.GlobalVar('three_with_ref')
+    i = relay.Var('i')
+    iv = relay.Var('iv')
+    u = relay.Var('u')
+    uv = relay.Var('uv')
+    body = relay.add(iv, uv)
+    body = relay.Let(uv, relay.RefRead(i), body)
+    body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body)
+    body = relay.Let(iv, relay.RefRead(i), body)
+    body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
+    mod[three_with_ref] = relay.Function([], body)
+    check_eval(three_with_ref, [], 3, mod=mod)
+
+
 def test_binds():
     x = relay.var("x")
     y = relay.add(x, x)
@@ -118,6 +134,7 @@ def test_binds():
     res = intrp.evaluate(y, binds={x: xx}).asnumpy()
     tvm.testing.assert_allclose(xx + xx, res)
 
+
 def test_kwargs_params():
     x = relay.var("x", shape=(1, 10))
     y = relay.var("y", shape=(1, 10))
@@ -131,6 +148,7 @@ def test_kwargs_params():
     res = intrp.evaluate(f)(x_data, **params).data
     tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)
 
+
 if __name__ == "__main__":
     test_id()
     test_add_const()
@@ -140,3 +158,4 @@ if __name__ == "__main__":
     test_loop()
     test_binds()
     test_kwargs_params()
+    test_ref()
index ac4eb1b404dbca2c050fd2cd3cee4dbcc12e01a6..eeefbc6c3051bc0bbd081e8649e27aa718687ec7 100644 (file)
@@ -131,6 +131,18 @@ def test_tuple():
             relay.TupleType([tp, tp]))
 
 
+def test_ref():
+    x = relay.var("x", "float32")
+    y = relay.var("y", "float32")
+    r = relay.RefCreate(x)
+    st = relay.scalar_type("float32")
+    assert relay.ir_pass.infer_type(r).checked_type == relay.RefType(st)
+    g = relay.RefRead(r)
+    assert relay.ir_pass.infer_type(g).checked_type == st
+    w = relay.RefWrite(r, y)
+    assert relay.ir_pass.infer_type(w).checked_type == relay.TupleType([])
+
+
 def test_free_expr():
     x = relay.var("x", "float32")
     y = relay.add(x, x)
@@ -187,12 +199,9 @@ if __name__ == "__main__":
     test_decl()
     test_recursion()
     test_tuple()
-    test_generalized_tuple()
     test_incomplete_call()
-    test_generalized_call()
-    test_call_with_type_args()
     test_free_expr()
     test_type_args()
-    test_self_reference()
     test_global_var_recursion()
     test_equal()
+    test_ref()