return ret;
}
/*! \brief A set that contain all the Feature. */
- static FeatureSet AllFeature() {
+ static FeatureSet All() {
FeatureSet fs;
fs.bs_.flip();
return fs;
}
/*! \brief The empty set. Contain no Feature. */
- static FeatureSet NoFeature() {
+ static FeatureSet No() {
FeatureSet fs;
return fs;
}
"""
seq = transform.Sequential([transform.SimplifyInference(),
transform.FuseOps(0),
+ transform.ToANormalForm(),
transform.InferType()])
return seq(self.mod)
#include <tvm/relay/interpreter.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/debug.h>
+#include <tvm/relay/feature.h>
#include "compile_engine.h"
namespace tvm {
Target target) {
auto intrp = std::make_shared<Interpreter>(mod, context, target);
auto packed = [intrp](Expr expr) {
+ auto f = DetectFeature(expr);
+ CHECK(f.is_subset_of(FeatureSet::All() - fGraph));
return intrp->Eval(expr);
};
return TypedPackedFunc<Value(Expr)>(packed);
* \return the comparison result.
*/
bool TypeEqual(const Type& lhs, const Type& rhs) {
- auto compute = [&](){
+ auto compute = [&]() {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
return this->VisitType(lhs, rhs);
FeatureSet DetectFeature(const Expr& expr) {
if (!expr.defined()) {
- return FeatureSet::NoFeature();
+ return FeatureSet::No();
}
struct FeatureDetector : ExprVisitor {
std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
- FeatureSet fs = FeatureSet::NoFeature();
+ FeatureSet fs = FeatureSet::No();
+
void VisitExpr(const Expr& expr) final {
if (visited_.count(expr) == 0) {
+ visited_.insert(expr);
ExprVisitor::VisitExpr(expr);
} else {
if (!IsAtomic(expr)) {
void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \
STMT \
fs += f##CONSTRUCT_NAME; \
- ExprVisitor::VisitExpr_(op); \
}
-#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, {})
+#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, { \
+ ExprVisitor::VisitExpr_(op); \
+ })
DETECT_DEFAULT_CONSTRUCT(Var)
DETECT_DEFAULT_CONSTRUCT(GlobalVar)
DETECT_DEFAULT_CONSTRUCT(Constant)
DETECT_DEFAULT_CONSTRUCT(Tuple)
DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
- DETECT_DEFAULT_CONSTRUCT(Function)
+ DETECT_CONSTRUCT(Function, {
+ if (!op->IsPrimitive()) {
+ ExprVisitor::VisitExpr_(op);
+ }
+ })
DETECT_DEFAULT_CONSTRUCT(Op)
DETECT_DEFAULT_CONSTRUCT(Call)
DETECT_CONSTRUCT(Let, {
fs += fLetRec;
}
}
+ ExprVisitor::VisitExpr_(op);
})
DETECT_DEFAULT_CONSTRUCT(If)
DETECT_DEFAULT_CONSTRUCT(RefCreate)
}
FeatureSet DetectFeature(const Module& mod) {
- FeatureSet fs = FeatureSet::NoFeature();
+ FeatureSet fs = FeatureSet::No();
if (mod.defined()) {
for (const auto& f : mod->functions) {
fs += DetectFeature(f.second);
// Perform unification on two types and report the error at the expression
// or the span of the expression.
Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) {
- // TODO(tqchen, jroesch): propagate span to solver
try {
- // instantiate higher-order func types when unifying because
- // we only allow polymorphism at the top level
- Type first = t1;
- Type second = t2;
- if (auto* ft1 = t1.as<FuncTypeNode>()) {
- first = InstantiateFuncType(ft1);
- }
- if (auto* ft2 = t2.as<FuncTypeNode>()) {
- second = InstantiateFuncType(ft2);
- }
- return solver_.Unify(first, second, expr);
+ return solver_.Unify(t1, t2, expr);
} catch (const dmlc::Error &e) {
this->ReportFatalError(
expr,
const auto* ftn = tn.as<FuncTypeNode>();
if (!ftn
|| op->arg_types.size() != ftn->arg_types.size()
- || op->type_params.size() != ftn->type_params.size()
|| op->type_constraints.size() != ftn->type_constraints.size()) {
return Type(nullptr);
}
+ // without loss of generality, suppose op->type_params.size() >= ftn->type_params.size().
+ if (op->type_params.size() < ftn->type_params.size()) {
+ return VisitType_(ftn, GetRef<FuncType>(op));
+ }
+
// remap type vars so they match
Map<TypeVar, Type> subst_map;
- for (size_t i = 0; i < op->type_params.size(); i++) {
- subst_map.Set(ftn->type_params[i], op->type_params[i]);
+ tvm::Array<TypeVar> ft_type_params;
+ for (size_t i = 0; i < ftn->type_params.size(); ++i) {
+ subst_map.Set(op->type_params[i], ftn->type_params[i]);
+ ft_type_params.push_back(op->type_params[i]);
+ }
+
+ for (size_t i = ftn->type_params.size(); i < op->type_params.size(); ++i) {
+ subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType));
}
- auto ft1 = GetRef<FuncType>(op);
- auto ft2 = Downcast<FuncType>(Bind(GetRef<FuncType>(ftn), subst_map));
+ FuncType ft = FuncTypeNode::make(op->arg_types,
+ op->ret_type,
+ ft_type_params,
+ op->type_constraints);
+ auto ft1 = Downcast<FuncType>(Bind(ft, subst_map));
+ auto ft2 = GetRef<FuncType>(ftn);
Type ret_type = Unify(ft1->ret_type, ft2->ret_type);
std::vector<Type> arg_types;
- for (size_t i = 0; i < ft1->arg_types.size(); i++) {
+ for (size_t i = 0; i < ft2->arg_types.size(); ++i) {
Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]);
arg_types.push_back(arg_type);
}
std::vector<TypeConstraint> type_constraints;
- for (size_t i = 0; i < ft1->type_constraints.size(); i++) {
+ for (size_t i = 0; i < ft1->type_constraints.size(); ++i) {
Type unified_constraint = Unify(ft1->type_constraints[i],
ft2->type_constraints[i]);
const auto* tcn = unified_constraint.as<TypeConstraintNode>();
type_constraints.push_back(GetRef<TypeConstraint>(tcn));
}
- return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints);
+ return FuncTypeNode::make(arg_types, ret_type, ft2->type_params, type_constraints);
}
Type VisitType_(const RefTypeNode* op, const Type& tn) final {
Feature.fLet,
Feature.fRefCreate,
Feature.fRefRead,
- Feature.fRefWrite
+ Feature.fRefWrite,
+ Feature.fGraph
])
return tvm.nd.array(np.random.rand(*shape).astype(dtype))
+def test_id():
+ x = relay.var("x", shape=[])
+ id = run_infer_type(relay.Function([x], x))
+ id_cps = run_infer_type(to_cps(id))
+
+
+def test_double():
+ t = relay.TypeVar("t")
+ x = relay.var("x", t)
+ f = relay.var("f", relay.FuncType([t], t))
+ double = run_infer_type(relay.Function([f, x], f(f(x)), t, [t]))
+ double_cps = run_infer_type(to_cps(double))
+
+
# make sure cps work for recursion.
def test_recursion():
mod = relay.Module()
"""
from tvm import relay
from tvm.relay import op, transform, analysis
+from tvm.relay.analysis import assert_alpha_equal
def run_infer_type(expr, mod=None):
assert ft.checked_type == relay.FuncType([tt], relay.TupleType([]))
+def test_let_polymorphism():
+ id = relay.Var("id")
+ xt = relay.TypeVar("xt")
+ x = relay.Var("x", xt)
+ body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))])
+ body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
+ body = run_infer_type(body)
+ int32 = relay.TensorType((), "int32")
+ assert_alpha_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
+
+
if __name__ == "__main__":
test_free_expr()
test_dual_op()
test_constructor_type()
test_constructor_call()
test_adt_match()
+ test_let_polymorphism()