class ConstructorValue;
struct ConstructorValueNode : ValueNode {
- Constructor constructor;
+ int tag;
tvm::Array<Value> fields;
+ /*! \brief Optional field tracking ADT constructor. */
+ Constructor constructor;
+
void VisitAttrs(tvm::AttrVisitor* v) final {
- v->Visit("constructor", &constructor);
+ v->Visit("tag", &tag);
v->Visit("fields", &fields);
+ v->Visit("constructor", &constructor);
}
- TVM_DLL static ConstructorValue make(Constructor constructor,
- tvm::Array<Value> fields);
+ TVM_DLL static ConstructorValue make(int tag,
+ tvm::Array<Value> fields,
+                                              Constructor constructor = {});
static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);
@register_relay_node
class ConstructorValue(Value):
- def __init__(self, constructor, fields, types):
+ def __init__(self, tag, fields, constructor, types):
self.__init_handle_by_constructor__(
- _make.ConstructorValue, constructor, fields, types)
+ _make.ConstructorValue, tag, fields, constructor, types)
@register_relay_node
args: List[tvm.NDArray, np.ndarray]
The arguments to evaluate.
"""
-
mod = optimize(mod)
args = list(args)
assert isinstance(args, list)
def __init__(self, mod):
self.mod = mod
self.load_prelude()
-
self.define_list_adt()
self.define_list_hd()
self.define_list_tl()
# helper functions for working with nats
-def count(n):
+def count(prelude, n):
"""Takes a ConstructorValue corresponding to a nat ADT
and converts it into a Python integer. This is an example of
using an ADT value in Python.
"""
assert isinstance(n, ConstructorValue)
- if n.constructor.name_hint == 'z':
+ if n.tag == prelude.z.tag:
return 0
- assert n.constructor.name_hint == 's'
- return 1 + count(n.fields[0])
+ assert n.tag == prelude.s.tag
+ return 1 + count(prelude, n.fields[0])
def make_nat_value(prelude, n):
constructs a ConstructorValue representing that value as a nat.
"""
if n == 0:
- return ConstructorValue(prelude.z, [], [])
- return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], [])
+ return ConstructorValue(prelude.z.tag, [], None, [])
+ return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, [])
def make_nat_expr(prelude, n):
p->stream << "RefValueNode(" << node->value << ")";
});
-ConstructorValue ConstructorValueNode::make(Constructor constructor,
- tvm::Array<Value> fields) {
+ConstructorValue ConstructorValueNode::make(int tag,
+ tvm::Array<Value> fields,
+ Constructor constructor) {
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
- n->constructor = constructor;
+ n->tag = tag;
n->fields = fields;
+ n->constructor = constructor;
return ConstructorValue(n);
}
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
tvm::IRPrinter* p) {
- p->stream << "ConstructorValueNode(" << node->constructor
+ p->stream << "ConstructorValueNode(" << node->tag << ","
<< node->fields << ")";
});
"fusing and lowering";
}
if (auto con = call->op.as<ConstructorNode>()) {
- return ConstructorValueNode::make(GetRef<Constructor>(con), args);
+ return ConstructorValueNode::make(con->tag, args, GetRef<Constructor>(con));
}
// Now we just evaluate and expect to find a closure.
Value fn_val = Eval(call->op);
const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
CHECK(cvn) << "need to be a constructor for match";
CHECK_NE(op->constructor->tag, -1);
- CHECK_NE(cvn->constructor->tag, -1);
- if (op->constructor->tag == cvn->constructor->tag) {
- // todo(M.K.): should use ptr equality but it is broken
+ CHECK_NE(cvn->tag, -1);
+ if (op->constructor->tag == cvn->tag) {
CHECK_EQ(op->patterns.size(), cvn->fields.size());
for (size_t i = 0; i < op->patterns.size(); ++i) {
if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
ConstTensorShapeMap const_tensor_shape_map;
// List of lowered functions
std::vector<LoweredFunc> lowered_funcs;
+ // The functions that have been lowered.
+ std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
};
// Compute the constant pool, i.e a mapping from Constant node to constant index.
size_t registers_num;
CompileEngine engine;
- /*! \brief The functions that have been lowered. */
- std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
-
/*! \brief Global shared meta data */
VMCompilerContext* context;
void VisitExpr_(const MatchNode* match_node) {
auto match = GetRef<Match>(match_node);
- LOG(FATAL) << "translation of match nodes to the VM is"
+ LOG(FATAL) << "translation of match nodes to the VM is "
<< "currently unsupported" << std::endl;
}
}
void VisitExpr_(const GlobalVarNode* gvar) {
- LOG(FATAL) << "Global variables should only appear in the call position";
+ // TODO(wweic): Support Load GlobalVar into a register
+ LOG(FATAL) << "Loading GlobalVar into register is not yet supported";
}
void VisitExpr_(const IfNode* if_node) {
// TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1);
auto op_index = -1;
- if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) {
+ if (this->context->seen_funcs.find(cfunc->funcs[0]) == this->context->seen_funcs.end()) {
op_index = this->context->lowered_funcs.size();
this->context->lowered_funcs.push_back(cfunc->funcs[0]);
- seen_funcs[cfunc->funcs[0]] = op_index;
+ this->context->seen_funcs[cfunc->funcs[0]] = op_index;
} else {
- op_index = seen_funcs[cfunc->funcs[0]];
+ op_index = this->context->seen_funcs[cfunc->funcs[0]];
}
Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
std::vector<Index> args_registers;
for (auto arg : call_node->args) {
- CHECK(arg.as<VarNode>()) << "found: " << AsText(arg, false) << std::endl << arg;
this->VisitExpr(arg);
args_registers.push_back(last_register);
}
auto func = this->context->module->Lookup(global);
if (IsClosure(func)) {
auto arity = func->params.size();
- std::vector<Index> free_var_registers;
- for (size_t i = 0; i < arity; ++i) {
- free_var_registers.push_back(var_register_map.at(func->params[i]));
- }
- Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister()));
+ Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
} else {
Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
}
} else if (auto constructor_node = op.as<ConstructorNode>()) {
auto constructor = GetRef<Constructor>(constructor_node);
- auto tag = GetConstructorTag(constructor);
- Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister()));
+ Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers,
+ NewRegister()));
} else if (auto var_node = op.as<VarNode>()) {
VisitExpr(GetRef<Var>(var_node));
Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister()));
}
}
- size_t GetConstructorTag(tvm::relay::Constructor constructor) {
- auto it = this->context->tag_map.find(constructor);
- if (it != this->context->tag_map.end()) {
- return it->second;
- } else {
- auto tag = this->context->tag_map.size();
- this->context->tag_map[constructor] = tag;
- this->context->tag_index_map[tag] = constructor;
- return tag;
- }
- }
-
void VisitExpr_(const FunctionNode* func_node) {
if (!func_node->IsPrimitive()) {
LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
}
VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
- DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl;
+ DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false) << std::endl;
size_t params = func->params.size();
VMCompiler compiler(context);
compiler.Compile(func);
return res;
}
-Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) {
- CHECK(module.defined() && type.defined());
+Value VMToValue(const relay::Module& module, Object obj) {
+ CHECK(module.defined());
switch (obj->tag) {
case ObjectTag::kTensor: {
- CHECK(type.as<TensorTypeNode>()) << "VM internal error: return value must be a tensor";
return TensorValueNode::make(ToNDArray(obj));
}
case ObjectTag::kDatatype: {
- // const auto* tuple_type
- // const auto& data_type = obj.AsDatatype();
+ const auto& data_type = obj.AsDatatype();
- // tvm::Array<Value> fields;
- // for (size_t i = 0; i < data_type->fields.size(); ++i) {
- // fields.push_back(VMToValue(tag_index_map, data_type->fields[i]));
- // }
+ tvm::Array<Value> fields;
+ for (size_t i = 0; i < data_type->fields.size(); ++i) {
+ fields.push_back(VMToValue(module, data_type->fields[i]));
+ }
- // return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields);
- LOG(FATAL) << "fix me";
+ return ConstructorValueNode::make(data_type->tag, fields);
}
default:
LOG(FATAL) << "unsupported return value of type: " << obj->tag;
LOG(FATAL) << "expected function or module";
}
- auto return_type = module->Lookup(module->entry_func)->ret_type;
-
std::vector<Object> vm_args;
for (auto i = 3; i < args.size(); i++) {
Object obj = args[i];
auto result = EvaluateModule(module, {ctx}, vm_args);
DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
- *ret = VMToValue(module, return_type, result);
+ *ret = VMToValue(module, result);
});
} // namespace vm
Module updated_mod = mod;
// Execute the pass function and return a new module.
std::vector<std::pair<GlobalVar, Function> > updates;
- for (const auto& it : mod->functions) {
+ auto original = mod->functions;
+ for (const auto& it : original) {
auto updated_func = SkipFunction(it.second)
? it.second
: pass_func(it.second, updated_mod, pass_ctx);
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay import testing, create_executor
from tvm.relay.prelude import Prelude
-from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
+from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr
mod = relay.Module()
p = Prelude(mod)
add_nat_definitions(p)
+def count(e):
+ return count_(p, e)
+
ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
val = l
ret = []
while True:
- if val.constructor.name_hint == 'cons':
+ if val.tag == p.cons.tag:
ret.append(val.fields[0])
val = val.fields[1]
else:
- assert val.constructor.name_hint == 'nil'
+ assert val.tag == p.nil.tag
break
return ret
def tree_to_dict(t):
assert isinstance(t, ConstructorValue)
ret = {}
- assert t.constructor.name_hint == 'rose'
+ assert t.tag == p.rose.tag
ret['member'] = t.fields[0]
ret['children'] = []
for subtree in to_list(t.fields[1]):
prelude = relay.prelude.Prelude(mod)
intrp = create_executor("debug", mod)
- nil_value = ConstructorValue(prelude.nil, [], [])
- cons_value = ConstructorValue(prelude.cons, [
+ nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil, [])
+ cons_value = ConstructorValue(prelude.cons.tag, [
TensorValue(np.random.rand(1, 10).astype('float32')),
nil_value
- ], [relay.TensorType((1, 10), 'float32')])
+ ], prelude.cons, [relay.TensorType((1, 10), 'float32')])
ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
tuple_value = TupleValue(*[
id_func = intrp.evaluate(prelude.id)
res_nil = id_func(nil_value)
- assert res_nil.constructor == nil_value.constructor
+ assert res_nil.tag == nil_value.tag
assert len(res_nil.fields) == 0
res_cons = id_func(cons_value)
- assert res_cons.constructor == cons_value.constructor
+ assert res_cons.tag == cons_value.tag
assert len(res_cons.fields) == len(cons_value.fields)
tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(),
cons_value.fields[0].asnumpy())
assert isinstance(res_cons.fields[1], ConstructorValue)
- assert res_cons.fields[1].constructor == prelude.nil
+ assert res_cons.fields[1].tag == prelude.nil.tag
assert len(res_cons.fields[1].fields) == 0
res_ref = id_func(ref_value)
ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
- assert count(intrp.evaluate(add(s(z()), s(z())))) == 2
- assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
+ assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
+ assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext()
result = veval(f, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), j_data)
-@nottest
def test_list_constructor():
- # TODO(wweic): implement pattern match to support this test
def to_list(o):
if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy().tolist()]
cons = p.cons
l = p.l
+    # remove all other functions so VM compilation does not hit pattern matching
+ # TODO(wweic): remove the hack and implement pattern match
+ for v, _ in mod.functions.items():
+ mod[v] = relay.const(0)
+
one2 = cons(relay.const(1), nil())
one3 = cons(relay.const(2), one2)
one4 = cons(relay.const(3), one3)
result = veval(mod)()
obj = to_list(result)
- import pdb; pdb.set_trace()
tvm.testing.assert_allclose(obj, np.array([3,2,1]))
def test_let_tensor():