TVM_DLL static TupleGetItem make(Expr tuple, int index);
- static constexpr const char * _type_key = "relay.TupleGetItem";
+ static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
+/*! \brief Create a new Reference out of initial value. */
+class RefCreate;
+class RefCreateNode : public ExprNode {
+ public:
+ /*! \brief The initial value of the Reference. */
+ Expr value;
+
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("value", &value);
+ v->Visit("span", &span);
+ v->Visit("_checked_type_", &checked_type_);
+ }
+
+ TVM_DLL static RefCreate make(Expr value);
+
+ static constexpr const char* _type_key = "relay.RefCreate";
+ TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr);
+
+/*! \brief Get value out of Reference. */
+class RefRead;
+class RefReadNode : public ExprNode {
+ public:
+ /*! \brief The Reference Expression. */
+ Expr ref;
+
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("ref", &ref);
+ v->Visit("span", &span);
+ v->Visit("_checked_type_", &checked_type_);
+ }
+
+ TVM_DLL static RefRead make(Expr ref);
+
+ static constexpr const char* _type_key = "relay.RefRead";
+ TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr);
+
+/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
+class RefWrite;
+class RefWriteNode : public ExprNode {
+ public:
+ /*! \brief The Reference Expression. */
+ Expr ref;
+ /*! \brief The value to write into. */
+ Expr value;
+
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("ref", &ref);
+ v->Visit("value", &value);
+ v->Visit("span", &span);
+ v->Visit("_checked_type_", &checked_type_);
+ }
+
+ TVM_DLL static RefWrite make(Expr ref, Expr value);
+
+ static constexpr const char* _type_key = "relay.RefWrite";
+ TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);
+
/*!
* \brief Base class of the temporary expression.
*
virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
+ RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
+ RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
+ RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
return vtable;
}
};
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
+ void VisitExpr_(const RefCreateNode* op) override;
+ void VisitExpr_(const RefReadNode* op) override;
+ void VisitExpr_(const RefWriteNode* op) override;
virtual void VisitType(const Type& t);
protected:
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
+ Expr VisitExpr_(const RefCreateNode* op) override;
+ Expr VisitExpr_(const RefReadNode* op) override;
+ Expr VisitExpr_(const RefWriteNode* op) override;
/*!
* \brief Used to visit the types inside of expressions.
*
RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);
+/*! \brief A reference value. */
+class RefValue;
+
+struct RefValueNode : ValueNode {
+ mutable Value value;
+
+ RefValueNode() {}
+
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("value", &value);
+ }
+
+ TVM_DLL static RefValue make(Value val);
+
+ static constexpr const char* _type_key = "relay.RefValue";
+ TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode);
+};
+
+RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
} // namespace relay
} // namespace tvm
RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
+/*!
+ * \brief The type of reference values.
+ */
+class RefType;
+/*!
+ * \brief Reference Type in relay.
+ */
+class RefTypeNode : public TypeNode {
+ public:
+ /*! \brief The type of value in the Reference. */
+ Type value;
+
+ RefTypeNode() {}
+
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("value", &value);
+ v->Visit("span", &span);
+ }
+
+ TVM_DLL static RefType make(Type value);
+
+ static constexpr const char* _type_key = "relay.RefType";
+ TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type);
+
class TypeReporter;
/*!
TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type
+RefType = ty.RefType
# Expr
Expr = expr.Expr
Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem
-
-# ExprFunctor
-ExprFunctor = expr_functor.ExprFunctor
-ExprMutator = expr_functor.ExprMutator
+RefCreate = expr.RefCreate
+RefRead = expr.RefRead
+RefWrite = expr.RefWrite
# helper functions
var = expr.var
const = expr.const
bind = expr.bind
+# ExprFunctor
+ExprFunctor = expr_functor.ExprFunctor
+ExprMutator = expr_functor.ExprMutator
+
# Parser
fromtext = parser.fromtext
def visit_op(self, _):
raise Exception("can not compile op in non-eta expanded form")
+ def visit_ref_create(self, _):
+ raise RuntimeError("reference not supported")
+
+ def visit_ref_read(self, _):
+ raise RuntimeError("reference not supported")
+
+ def visit_ref_write(self, _):
+ raise RuntimeError("reference not supported")
+
def _get_json(self):
"""
Convert the sequence of nodes stored by the compiler into the
def __iter__(self):
return iter(self.fields)
+
@register_relay_node
class Closure(Value):
"""A closure produced by the interpreter."""
return str(self.data)
+@register_relay_node
+class RefValue(Value):
+ def __init__(self, value):
+ self.__init_handle_by_constructor__(
+ _make.RefValue, value)
+
+
def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(_nd.cpu(0)))
_make.TupleGetItem, tuple_value, index)
+@register_relay_node
+class RefCreate(Expr):
+ """Create a new reference from initial value.
+ Parameters
+ ----------
+ value: tvm.relay.Expr
+ The initial value.
+ """
+ def __init__(self, value):
+ self.__init_handle_by_constructor__(_make.RefCreate, value)
+
+
+@register_relay_node
+class RefRead(Expr):
+ """Get the value inside the reference.
+ Parameters
+ ----------
+ ref: tvm.relay.Expr
+ The reference.
+ """
+ def __init__(self, ref):
+ self.__init_handle_by_constructor__(_make.RefRead, ref)
+
+
+@register_relay_node
+class RefWrite(Expr):
+ """
+ Update the value inside the reference.
+ The whole expression will evaluate to an empty tuple.
+ Parameters
+ ----------
+ ref: tvm.relay.Expr
+ The reference.
+ value: tvm.relay.Expr
+ The new value.
+ """
+ def __init__(self, ref, value):
+ self.__init_handle_by_constructor__(_make.RefWrite, ref, value)
+
+
class TempExpr(Expr):
"""Baseclass of all TempExpr.
res = self.visit_constant(expr)
elif isinstance(expr, Op):
res = self.visit_op(expr)
+ elif isinstance(expr, RefCreate):
+ res = self.visit_ref_create(expr)
+ elif isinstance(expr, RefRead):
+ res = self.visit_ref_read(expr)
+ elif isinstance(expr, RefWrite):
+ res = self.visit_ref_write(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))
def visit_constant(self, _):
raise NotImplementedError()
+ def visit_ref_create(self, _):
+ raise NotImplementedError()
+
+ def visit_ref_write(self, _):
+ raise NotImplementedError()
+
+ def visit_ref_read(self, _):
+ raise NotImplementedError()
class ExprMutator(ExprFunctor):
"""
def visit_match(self, m):
return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern])
- def visit_ref_new(self, r):
- return RefNew(self.visit(r.value))
+ def visit_ref_create(self, r):
+ return RefCreate(self.visit(r.value))
def visit_ref_write(self, r):
return RefWrite(self.visit(r.ref), self.visit(r.value))
func, args, num_inputs, attrs)
+@register_relay_node
+class RefType(Type):
+ """Reference Type in relay.
+
+ Parameters
+ ----------
+ value: Type
+ The value type.
+ """
+ def __init__(self, value):
+ self.__init_handle_by_constructor__(_make.RefType, value)
+
+
def scalar_type(dtype):
"""Creates a scalar type.
*ret = TensorValueNode::make(data);
});
+RefValue RefValueNode::make(Value value) {
+ NodePtr<RefValueNode> n = make_node<RefValueNode>();
+ n->value = value;
+ return RefValue(n);
+}
+
+TVM_REGISTER_API("relay._make.RefValue")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ *ret = RefValueNode::make(args[0]);
+ });
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<RefValueNode>([](const RefValueNode* node,
+ tvm::IRPrinter* p) {
+ p->stream << "RefValueNode(" << node->value << ")";
+ });
+
/*!
* \brief A stack frame in the Relay interpreter.
*
}
}
+ Value VisitExpr_(const RefWriteNode* op) final {
+ Value r = Eval(op->ref);
+ if (const RefValueNode* rv = r.as<RefValueNode>()) {
+ rv->value = Eval(op->value);
+ return TupleValueNode::make({});
+ } else {
+ LOG(FATAL) << "type error, type system should have caught this";
+ return Value();
+ }
+ }
+
+ Value VisitExpr_(const RefCreateNode* op) final {
+ return RefValueNode::make(Eval(op->value));
+ }
+
+ Value VisitExpr_(const RefReadNode* op) final {
+ Value r = Eval(op->ref);
+ if (const RefValueNode* rv = r.as<RefValueNode>()) {
+ return rv->value;
+ } else {
+ LOG(FATAL) << "type error, type system should have caught this";
+ return Value();
+ }
+ }
+
InterpreterState get_state(Expr e = Expr()) const {
InterpreterStateNode::Stack stack;
for (auto fr : this->stack_.frames) {
return false;
}
}
+
+ bool VisitType_(const RefTypeNode* lhs, const Type& other) final {
+ if (const RefTypeNode* rhs = other.as<RefTypeNode>()) {
+ return TypeEqual(lhs->value, rhs->value);
+ }
+ return false;
+ }
+
// Expr equal checking.
bool NDArrayEqual(const runtime::NDArray& lhs,
const runtime::NDArray& rhs) {
}
}
+ bool VisitExpr_(const RefCreateNode* op, const Expr& e2) final {
+ if (const RefCreateNode* nr = e2.as<RefCreateNode>()) {
+ return ExprEqual(op->value, nr->value);
+ } else {
+ return false;
+ }
+ }
+
+ bool VisitExpr_(const RefReadNode* op, const Expr& e2) final {
+ if (const RefReadNode* r = e2.as<RefReadNode>()) {
+ return ExprEqual(op->ref, r->ref);
+ } else {
+ return false;
+ }
+ }
+
+ bool VisitExpr_(const RefWriteNode* op, const Expr& e2) final {
+ if (const RefWriteNode* r = e2.as<RefWriteNode>()) {
+ return ExprEqual(op->ref, r->ref) && ExprEqual(op->value, r->value);
+ } else {
+ return false;
+ }
+ }
private:
// whether to map open terms.
bool map_free_var_{false};
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});
+RefCreate RefCreateNode::make(Expr value) {
+ NodePtr<RefCreateNode> n = make_node<RefCreateNode>();
+ n->value = std::move(value);
+ return RefCreate(n);
+}
-TVM_REGISTER_API("relay._expr.TempExprRealize")
+TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
+ *ret = RefCreateNode::make(args[0]);
+});
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) {
+ p->stream << "RefCreateNode(" << node->value << ")";
+});
+
+RefRead RefReadNode::make(Expr ref) {
+ NodePtr<RefReadNode> n = make_node<RefReadNode>();
+ n->ref = std::move(ref);
+ return RefRead(n);
+}
+
+TVM_REGISTER_API("relay._make.RefRead")
.set_body([](TVMArgs args, TVMRetValue* ret) {
- TempExpr temp = args[0];
- *ret = temp->Realize();
+ *ret = RefReadNode::make(args[0]);
+});
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<RefReadNode>([](const RefReadNode* node, tvm::IRPrinter* p) {
+ p->stream << "RefReadNode(" << node->ref << ")";
});
+RefWrite RefWriteNode::make(Expr ref, Expr value) {
+ NodePtr<RefWriteNode> n = make_node<RefWriteNode>();
+ n->ref = std::move(ref);
+ n->value = std::move(value);
+ return RefWrite(n);
+}
+
+TVM_REGISTER_API("relay._make.RefWrite")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ *ret = RefWriteNode::make(args[0], args[1]);
+});
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) {
+ p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
+});
+
+TVM_REGISTER_API("relay._expr.TempExprRealize")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ TempExpr temp = args[0];
+ *ret = temp->Realize();
+});
} // namespace relay
} // namespace tvm
}
}
+Expr ExprMutator::VisitExpr_(const RefCreateNode* op) {
+ Expr value = this->Mutate(op->value);
+ if (value.same_as(op->value)) {
+ return GetRef<Expr>(op);
+ } else {
+ return RefCreateNode::make(value);
+ }
+}
+
+Expr ExprMutator::VisitExpr_(const RefReadNode* op) {
+ Expr ref = this->Mutate(op->ref);
+ if (ref.same_as(op->ref)) {
+ return GetRef<Expr>(op);
+ } else {
+ return RefReadNode::make(ref);
+ }
+}
+
+Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
+ Expr ref = this->Mutate(op->ref);
+ Expr value = this->Mutate(op->value);
+ if (ref.same_as(op->ref) && value.same_as(op->value)) {
+ return GetRef<Expr>(op);
+ } else {
+ return RefWriteNode::make(ref, value);
+ }
+}
+
Type ExprMutator::VisitType(const Type& t) { return t; }
void ExprVisitor::VisitExpr(const Expr& expr) {
this->VisitExpr(op->tuple);
}
+void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) {
+ this->VisitExpr(op->value);
+}
+
+void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) {
+ this->VisitExpr(op->ref);
+}
+
+void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) {
+ this->VisitExpr(op->ref);
+ this->VisitExpr(op->value);
+}
+
void ExprVisitor::VisitType(const Type& t) { return; }
// visitor to implement apply
// Hash handler for Relay.
class RelayHashHandler:
- public AttrsHashHandler,
- public TypeFunctor<size_t(const Type&)>,
- public ExprFunctor<size_t(const Expr&)> {
+ public AttrsHashHandler,
+ public TypeFunctor<size_t(const Type&)>,
+ public ExprFunctor<size_t(const Expr&)> {
public:
explicit RelayHashHandler() {}
return hash;
}
+ size_t VisitType_(const RefTypeNode* rtn) final {
+ size_t hash = std::hash<std::string>()(RefTypeNode::_type_key);
+ hash = Combine(hash, TypeHash(rtn->value));
+ return hash;
+ }
+
// Expr hashing.
size_t NDArrayHash(const runtime::NDArray& array) {
size_t hash = std::hash<uint8_t>()(array->dtype.code);
return hash;
}
+ size_t VisitExpr_(const RefCreateNode* rn) final {
+ size_t hash = std::hash<std::string>()(RefCreateNode::_type_key);
+ hash = Combine(hash, ExprHash(rn->value));
+ return hash;
+ }
+
+ size_t VisitExpr_(const RefReadNode* rn) final {
+ size_t hash = std::hash<std::string>()(RefReadNode::_type_key);
+ hash = Combine(hash, ExprHash(rn->ref));
+ return hash;
+ }
+
+ size_t VisitExpr_(const RefWriteNode* rn) final {
+ size_t hash = std::hash<std::string>()(RefWriteNode::_type_key);
+ hash = Combine(hash, ExprHash(rn->ref));
+ hash = Combine(hash, ExprHash(rn->value));
+ return hash;
+ }
private:
// renaming of NodeRef to indicate two nodes equals to each other
std::unordered_map<NodeRef, size_t, NodeHash, NodeEqual> hash_map_;
return id;
}
+ TextValue VisitExpr_(const RefCreateNode* op) final {
+ TextValue value = GetValue(op->value);
+ TextValue id = this->AllocTempVar();
+ this->PrintIndent();
+ stream_ << id << " = " << "RefCreate(" << op->value << ")";
+ this->PrintEndInst("\n");
+ return id;
+ }
+
+ TextValue VisitExpr_(const RefReadNode* op) final {
+ TextValue ref = GetValue(op->ref);
+ TextValue id = this->AllocTempVar();
+ this->PrintIndent();
+ stream_ << id << " = " << "RefRead(" << ref << ")";
+ this->PrintEndInst("\n");
+ return id;
+ }
+
+ TextValue VisitExpr_(const RefWriteNode* op) final {
+ TextValue ref = GetValue(op->ref);
+ TextValue value = GetValue(op->value);
+ TextValue id = this->AllocTempVar();
+ this->PrintIndent();
+ stream_ << id << " = " << "RefWrite(" << ref << ", " << value << ")";
+ this->PrintEndInst("\n");
+ return id;
+ }
+
/*!
* \brief Print the type to os
* \param type The type to be printed.
os << "]";
}
+ void VisitType_(const RefTypeNode* node, std::ostream& os) final {
+ VisitTypeDefault_(node, os);
+ }
+
void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*)
// by default always print as meta-data
os << meta_.GetMetaNode(GetRef<NodeRef>(node));
p->stream << "TupleTypeNode(" << node->fields << ")";
});
+RefType RefTypeNode::make(Type value) {
+ NodePtr<RefTypeNode> n = make_node<RefTypeNode>();
+ n->value = std::move(value);
+ return RefType(n);
+}
+
+TVM_REGISTER_API("relay._make.RefType")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ *ret = RefTypeNode::make(args[0]);
+});
+
+TVM_REGISTER_NODE_TYPE(RefTypeNode);
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<RefTypeNode>([](const RefTypeNode* node,
+ tvm::IRPrinter* p) {
+ p->stream << "RefTypeNode(" << node->value << ")";
+});
+
} // namespace relay
} // namespace tvm
}
}
+void TypeVisitor::VisitType_(const RefTypeNode* op) {
+ this->VisitType(op->value);
+}
+
void TypeVisitor::VisitType_(const TypeRelationNode* op) {
for (const Type& t : op->args) {
this->VisitType(t);
}
}
+Type TypeMutator::VisitType_(const RefTypeNode* op) {
+ return RefTypeNode::make(this->VisitType(op->value));
+}
+
Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
Array<Type> new_args = MutateArray(type_rel->args);
if (new_args.same_as(type_rel->args)) {
virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
-
+ virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Node* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->type_key();
throw; // unreachable, written to stop compiler warning
RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
+ RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode);
return vtable;
}
};
void VisitType_(const FuncTypeNode* op) override;
void VisitType_(const TupleTypeNode* op) override;
void VisitType_(const TypeRelationNode* op) override;
+ void VisitType_(const RefTypeNode* op) override;
};
// Mutator that transform a type to another one.
Type VisitType_(const FuncTypeNode* op) override;
Type VisitType_(const TupleTypeNode* op) override;
Type VisitType_(const TypeRelationNode* type_rel) override;
+ Type VisitType_(const RefTypeNode* op) override;
private:
Array<Type> MutateArray(Array<Type> arr);
current->extern_ref = true;
}
}
+
void AddNode(const tvm::Node* key) {
auto it = graph_.node_map.find(key);
CHECK(it != graph_.node_map.end())
}
// Post order tree
- void VisitExpr_(const FunctionNode* op) {
+ void VisitExpr_(const FunctionNode* op) final {
for (auto param : op->params) {
this->Update(param, nullptr, kOpaque);
}
ExprVisitor::VisitExpr_(op);
}
- void VisitExpr_(const ConstantNode* op) {
+ void VisitExpr_(const ConstantNode* op) final {
this->AddNode(op);
Node* node = graph_.node_map.at(op);
DataType dtype = TVMType2Type(op->data->dtype);
}
}
- void VisitExpr_(const CallNode* call) {
+ void VisitExpr_(const CallNode* call) final {
CHECK(graph_.node_map.count(call));
Node* node = graph_.node_map.at(call);
static auto fpattern =
this->AddNode(call);
}
- void VisitExpr_(const TupleNode* op) {
+ void VisitExpr_(const TupleNode* op) final {
CHECK(graph_.node_map.count(op));
Node* tuple_node = graph_.node_map.at(op);
tuple_node->pattern = kInjective;
this->AddNode(op);
}
- void VisitExpr_(const TupleGetItemNode* op) {
+ void VisitExpr_(const TupleGetItemNode* op) final {
CHECK(graph_.node_map.count(op));
Node* node = graph_.node_map.at(op);
this->Update(op->tuple, node, kOpaque);
this->AddNode(op);
}
- void VisitExpr_(const VarNode* op) {
+ void VisitExpr_(const VarNode* op) final {
this->AddNode(op);
}
- void VisitExpr_(const LetNode* op) {
+ void VisitExpr_(const LetNode* op) final {
// do not fuse through let.
this->Update(op->var, nullptr, kOpaque);
this->Update(op->value, nullptr, kOpaque);
this->AddNode(op);
}
- void VisitExpr_(const IfNode* op) {
+ void VisitExpr_(const IfNode* op) final {
// do not fuse through if.
this->Update(op->cond, nullptr, kOpaque);
this->Update(op->true_branch, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
+
+ void VisitExpr_(const RefCreateNode* op) final {
+ this->Update(op->value, nullptr, kOpaque);
+ ExprVisitor::VisitExpr_(op);
+ this->AddNode(op);
+ }
+
+ void VisitExpr_(const RefReadNode* op) final {
+ this->Update(op->ref, nullptr, kOpaque);
+ ExprVisitor::VisitExpr_(op);
+ this->AddNode(op);
+ }
+
+ void VisitExpr_(const RefWriteNode* op) final {
+ this->Update(op->ref, nullptr, kOpaque);
+ this->Update(op->value, nullptr, kOpaque);
+ ExprVisitor::VisitExpr_(op);
+ this->AddNode(op);
+ }
};
IndexedForwardGraph IndexedForwardGraph::Create(
valid = valid && IsTypeKind(op->ret_type);
}
+ void VisitType_(const RefTypeNode* op) override {
+ // tuples should only contain normal types
+ this->VisitType(op->value);
+ valid = valid && IsTypeKind(op->value);
+ }
+
void VisitType_(const TypeRelationNode* op) override {
// arguments to type relation should be normal types
for (const Type& t : op->args) {
auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {});
return solver_.Resolve(ret);
}
+
+ Type VisitExpr_(const RefCreateNode* op) final {
+ return RefTypeNode::make(GetType(op->value));
+ }
+
+ Type VisitExpr_(const RefReadNode* op) final {
+ Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
+ this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefRead>(op));
+ return it;
+ }
+
+ Type VisitExpr_(const RefWriteNode* op) final {
+ Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
+ this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefWrite>(op));
+ this->Unify(GetType(op->value), it, GetRef<RefWrite>(op));
+ return TupleTypeNode::make({});
+ }
};
class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op);
}
+ Expr VisitExpr_(const RefCreateNode* op) final {
+ return AttachCheckedType(op);
+ }
+
+ Expr VisitExpr_(const RefReadNode* op) final {
+ return AttachCheckedType(op);
+ }
+
+ Expr VisitExpr_(const RefWriteNode* op) final {
+ return AttachCheckedType(op);
+ }
+
// attach checked type to the mutated node.
template<typename T>
Expr AttachCheckedType(const T* op) {
}
// default: unify only if alpha-equal
- Type VisitTypeDefault_(const Node* op, const Type& tn) override {
+ Type VisitTypeDefault_(const Node* op, const Type& tn) final {
NodeRef nr = GetRef<NodeRef>(op);
Type t1 = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
if (!AlphaEqual(t1, tn)) {
return t1;
}
- Type VisitType_(const TupleTypeNode* op, const Type& tn) override {
+ Type VisitType_(const TupleTypeNode* op, const Type& tn) final {
const auto* ttn = tn.as<TupleTypeNode>();
if (!ttn || op->fields.size() != ttn->fields.size()) {
return Type(nullptr);
return TupleTypeNode::make(new_fields);
}
- Type VisitType_(const FuncTypeNode* op, const Type& tn) override {
+ Type VisitType_(const FuncTypeNode* op, const Type& tn) final {
const auto* ftn = tn.as<FuncTypeNode>();
if (!ftn
|| op->arg_types.size() != ftn->arg_types.size()
return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints);
}
+ Type VisitType_(const RefTypeNode* op, const Type& tn) final {
+ const auto* rtn = tn.as<RefTypeNode>();
+ if (!rtn) {
+ return Type(nullptr);
+ }
+ return RefTypeNode::make(Unify(op->value, rtn->value));
+ }
+
private:
TypeSolver* solver_;
};
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod)
+def test_ref():
+ mod = relay.Module()
+ three_with_ref = relay.GlobalVar('three_with_ref')
+ i = relay.Var('i')
+ iv = relay.Var('iv')
+ u = relay.Var('u')
+ uv = relay.Var('uv')
+ body = relay.add(iv, uv)
+ body = relay.Let(uv, relay.RefRead(i), body)
+ body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body)
+ body = relay.Let(iv, relay.RefRead(i), body)
+ body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
+ mod[three_with_ref] = relay.Function([], body)
+ check_eval(three_with_ref, [], 3, mod=mod)
+
+
def test_binds():
x = relay.var("x")
y = relay.add(x, x)
res = intrp.evaluate(y, binds={x: xx}).asnumpy()
tvm.testing.assert_allclose(xx + xx, res)
+
def test_kwargs_params():
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
res = intrp.evaluate(f)(x_data, **params).data
tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)
+
if __name__ == "__main__":
test_id()
test_add_const()
test_loop()
test_binds()
test_kwargs_params()
+ test_ref()
relay.TupleType([tp, tp]))
+def test_ref():
+ x = relay.var("x", "float32")
+ y = relay.var("y", "float32")
+ r = relay.RefCreate(x)
+ st = relay.scalar_type("float32")
+ assert relay.ir_pass.infer_type(r).checked_type == relay.RefType(st)
+ g = relay.RefRead(r)
+ assert relay.ir_pass.infer_type(g).checked_type == st
+ w = relay.RefWrite(r, y)
+ assert relay.ir_pass.infer_type(w).checked_type == relay.TupleType([])
+
+
def test_free_expr():
x = relay.var("x", "float32")
y = relay.add(x, x)
test_decl()
test_recursion()
test_tuple()
- test_generalized_tuple()
test_incomplete_call()
- test_generalized_call()
- test_call_with_type_args()
test_free_expr()
test_type_args()
- test_self_reference()
test_global_var_recursion()
test_equal()
+ test_ref()