Support export ADT value in Python (#3299)
authorWei Chen <ipondering.weic@gmail.com>
Thu, 13 Jun 2019 01:21:19 +0000 (09:21 +0800)
committerJared Roesch <roeschinc@gmail.com>
Thu, 13 Jun 2019 01:21:19 +0000 (18:21 -0700)
* Support export ADT value in Python

* Cache original functions

* Cleanup

* Cleanup

13 files changed:
include/tvm/relay/interpreter.h
python/tvm/relay/backend/interpreter.py
python/tvm/relay/backend/vm.py
python/tvm/relay/prelude.py
python/tvm/relay/testing/nat.py
src/relay/backend/interpreter.cc
src/relay/backend/vm/compiler.cc
src/relay/backend/vm/vm.cc
src/relay/pass/pass_manager.cc
tests/python/relay/test_adt.py
tests/python/relay/test_backend_interpreter.py
tests/python/relay/test_pass_to_a_normal_form.py
tests/python/relay/test_vm.py

index 15c96bb..68b7cca 100644 (file)
@@ -182,17 +182,22 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
 class ConstructorValue;
 
 struct ConstructorValueNode : ValueNode {
-  Constructor constructor;
+  int tag;
 
   tvm::Array<Value> fields;
 
+  /*! \brief Optional field tracking ADT constructor. */
+  Constructor constructor;
+
   void VisitAttrs(tvm::AttrVisitor* v) final {
-    v->Visit("constructor", &constructor);
+    v->Visit("tag", &tag);
     v->Visit("fields", &fields);
+    v->Visit("constructor", &constructor);
   }
 
-  TVM_DLL static ConstructorValue make(Constructor constructor,
-                                       tvm::Array<Value> fields);
+  TVM_DLL static ConstructorValue make(int tag,
+                                       tvm::Array<Value> fields,
+                                       Constructor construtor = {});
 
   static constexpr const char* _type_key = "relay.ConstructorValue";
   TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);
index 593cf7c..ea25b97 100644 (file)
@@ -73,9 +73,9 @@ class Closure(Value):
 
 @register_relay_node
 class ConstructorValue(Value):
-    def __init__(self, constructor, fields, types):
+    def __init__(self, tag, fields, constructor, types):
         self.__init_handle_by_constructor__(
-            _make.ConstructorValue, constructor, fields, types)
+            _make.ConstructorValue, tag, fields, constructor, types)
 
 
 @register_relay_node
index 3b9946a..4cb3d61 100644 (file)
@@ -97,7 +97,6 @@ def _eval_vm(mod, ctx, *args):
     args: List[tvm.NDArray, np.ndarray]
         The arguments to evaluate.
     """
-
     mod = optimize(mod)
     args = list(args)
     assert isinstance(args, list)
index c801e49..da75b9d 100644 (file)
@@ -491,7 +491,6 @@ class Prelude:
     def __init__(self, mod):
         self.mod = mod
         self.load_prelude()
-
         self.define_list_adt()
         self.define_list_hd()
         self.define_list_tl()
index 4c0c87c..a76a340 100644 (file)
@@ -151,16 +151,16 @@ def add_nat_definitions(prelude):
 # helper functions for working with nats
 
 
-def count(n):
+def count(prelude, n):
     """Takes a ConstructorValue corresponding to a nat ADT
     and converts it into a Python integer. This is an example of
     using an ADT value in Python.
     """
     assert isinstance(n, ConstructorValue)
-    if n.constructor.name_hint == 'z':
+    if n.tag == prelude.z.tag:
         return 0
-    assert n.constructor.name_hint == 's'
-    return 1 + count(n.fields[0])
+    assert n.tag == prelude.s.tag
+    return 1 + count(prelude, n.fields[0])
 
 
 def make_nat_value(prelude, n):
@@ -168,8 +168,8 @@ def make_nat_value(prelude, n):
     constructs a ConstructorValue representing that value as a nat.
     """
     if n == 0:
-        return ConstructorValue(prelude.z, [], [])
-    return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], [])
+        return ConstructorValue(prelude.z.tag, [], None, [])
+    return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, [])
 
 
 def make_nat_expr(prelude, n):
index d700c20..1cc81d5 100644 (file)
@@ -103,11 +103,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
                               p->stream << "RefValueNode(" << node->value << ")";
                             });
 
-ConstructorValue ConstructorValueNode::make(Constructor constructor,
-                                            tvm::Array<Value> fields) {
+ConstructorValue ConstructorValueNode::make(int tag,
+                                            tvm::Array<Value> fields,
+                                            Constructor constructor) {
   NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
-  n->constructor = constructor;
+  n->tag = tag;
   n->fields = fields;
+  n->constructor = constructor;
   return ConstructorValue(n);
 }
 
@@ -117,7 +119,7 @@ TVM_REGISTER_API("relay._make.ConstructorValue")
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
                                        tvm::IRPrinter* p) {
-  p->stream << "ConstructorValueNode(" << node->constructor
+  p->stream << "ConstructorValueNode(" << node->tag << ","
             << node->fields << ")";
 });
 
@@ -448,7 +450,7 @@ class Interpreter :
                     "fusing and lowering";
     }
     if (auto con = call->op.as<ConstructorNode>()) {
-      return ConstructorValueNode::make(GetRef<Constructor>(con), args);
+      return ConstructorValueNode::make(con->tag, args, GetRef<Constructor>(con));
     }
     // Now we just evaluate and expect to find a closure.
     Value fn_val = Eval(call->op);
@@ -544,9 +546,8 @@ class Interpreter :
     const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
     CHECK(cvn) << "need to be a constructor for match";
     CHECK_NE(op->constructor->tag, -1);
-    CHECK_NE(cvn->constructor->tag, -1);
-    if (op->constructor->tag == cvn->constructor->tag) {
-      // todo(M.K.): should use ptr equality but it is broken
+    CHECK_NE(cvn->tag, -1);
+    if (op->constructor->tag == cvn->tag) {
       CHECK_EQ(op->patterns.size(), cvn->fields.size());
       for (size_t i = 0; i < op->patterns.size(); ++i) {
         if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
index 07633fc..9b4ab6b 100644 (file)
@@ -80,6 +80,8 @@ struct VMCompilerContext {
   ConstTensorShapeMap const_tensor_shape_map;
   // List of lowered functions
   std::vector<LoweredFunc> lowered_funcs;
+  // The functions that have been lowered.
+  std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
 };
 
 // Compute the constant pool, i.e a mapping from Constant node to constant index.
@@ -184,9 +186,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
   size_t registers_num;
   CompileEngine engine;
 
-  /*! \brief The functions that have been lowered. */
-  std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
-
   /*! \brief Global shared meta data */
   VMCompilerContext* context;
 
@@ -260,7 +259,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
 
   void VisitExpr_(const MatchNode* match_node) {
     auto match = GetRef<Match>(match_node);
-    LOG(FATAL) << "translation of match nodes to the VM is"
+    LOG(FATAL) << "translation of match nodes to the VM is "
                << "currently unsupported" << std::endl;
   }
 
@@ -280,7 +279,8 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
   }
 
   void VisitExpr_(const GlobalVarNode* gvar) {
-    LOG(FATAL) << "Global variables should only appear in the call position";
+    // TODO(wweic): Support Load GlobalVar into a register
+    LOG(FATAL) << "Loading GlobalVar into register is not yet supported";
   }
 
   void VisitExpr_(const IfNode* if_node) {
@@ -405,12 +405,12 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
     // TODO(jroesch): support lowered funcs for multiple targets
     CHECK_EQ(cfunc->funcs.size(), 1);
     auto op_index = -1;
-    if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) {
+    if (this->context->seen_funcs.find(cfunc->funcs[0]) == this->context->seen_funcs.end()) {
       op_index = this->context->lowered_funcs.size();
       this->context->lowered_funcs.push_back(cfunc->funcs[0]);
-      seen_funcs[cfunc->funcs[0]] = op_index;
+      this->context->seen_funcs[cfunc->funcs[0]] = op_index;
     } else {
-      op_index = seen_funcs[cfunc->funcs[0]];
+      op_index = this->context->seen_funcs[cfunc->funcs[0]];
     }
 
     Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
@@ -429,7 +429,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
     std::vector<Index> args_registers;
 
     for (auto arg : call_node->args) {
-      CHECK(arg.as<VarNode>()) << "found: " << AsText(arg, false) << std::endl << arg;
       this->VisitExpr(arg);
       args_registers.push_back(last_register);
     }
@@ -449,18 +448,14 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
       auto func = this->context->module->Lookup(global);
       if (IsClosure(func)) {
         auto arity = func->params.size();
-        std::vector<Index> free_var_registers;
-        for (size_t i = 0; i < arity; ++i) {
-          free_var_registers.push_back(var_register_map.at(func->params[i]));
-        }
-        Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister()));
+        Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
       } else {
         Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
       }
     } else if (auto constructor_node = op.as<ConstructorNode>()) {
       auto constructor = GetRef<Constructor>(constructor_node);
-      auto tag = GetConstructorTag(constructor);
-      Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister()));
+      Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers,
+                                      NewRegister()));
     } else if (auto var_node = op.as<VarNode>()) {
       VisitExpr(GetRef<Var>(var_node));
       Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister()));
@@ -469,18 +464,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
     }
   }
 
-  size_t GetConstructorTag(tvm::relay::Constructor constructor) {
-    auto it = this->context->tag_map.find(constructor);
-    if (it != this->context->tag_map.end()) {
-      return it->second;
-    } else {
-      auto tag = this->context->tag_map.size();
-      this->context->tag_map[constructor] = tag;
-      this->context->tag_index_map[tag] = constructor;
-      return tag;
-    }
-  }
-
   void VisitExpr_(const FunctionNode* func_node) {
     if (!func_node->IsPrimitive()) {
       LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
@@ -549,7 +532,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
 }
 
 VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
-  DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl;
+  DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false) << std::endl;
   size_t params = func->params.size();
   VMCompiler compiler(context);
   compiler.Compile(func);
index 34d067b..cf0b952 100644 (file)
@@ -63,24 +63,21 @@ Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
   return res;
 }
 
-Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) {
-  CHECK(module.defined() && type.defined());
+Value VMToValue(const relay::Module& module, Object obj) {
+  CHECK(module.defined());
   switch (obj->tag) {
     case ObjectTag::kTensor: {
-      CHECK(type.as<TensorTypeNode>()) << "VM internal error: return value must be a tensor";
       return TensorValueNode::make(ToNDArray(obj));
     }
     case ObjectTag::kDatatype: {
-      // const auto* tuple_type
-      // const auto& data_type = obj.AsDatatype();
+      const auto& data_type = obj.AsDatatype();
 
-      // tvm::Array<Value> fields;
-      // for (size_t i = 0; i < data_type->fields.size(); ++i) {
-      //   fields.push_back(VMToValue(tag_index_map, data_type->fields[i]));
-      // }
+      tvm::Array<Value> fields;
+      for (size_t i = 0; i < data_type->fields.size(); ++i) {
+        fields.push_back(VMToValue(module, data_type->fields[i]));
+      }
 
-      // return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields);
-      LOG(FATAL) << "fix me";
+      return ConstructorValueNode::make(data_type->tag, fields);
     }
     default:
       LOG(FATAL) << "unsupported return value of type: " << obj->tag;
@@ -141,8 +138,6 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue
     LOG(FATAL) << "expected function or module";
   }
 
-  auto return_type = module->Lookup(module->entry_func)->ret_type;
-
   std::vector<Object> vm_args;
   for (auto i = 3; i < args.size(); i++) {
     Object obj = args[i];
@@ -151,7 +146,7 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue
 
   auto result = EvaluateModule(module, {ctx}, vm_args);
   DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
-  *ret = VMToValue(module, return_type, result);
+  *ret = VMToValue(module, result);
 });
 
 }  // namespace vm
index 500bdce..fa79a5e 100644 (file)
@@ -316,7 +316,8 @@ Module FunctionPassNode::operator()(const Module& mod,
   Module updated_mod = mod;
   // Execute the pass function and return a new module.
   std::vector<std::pair<GlobalVar, Function> > updates;
-  for (const auto& it : mod->functions) {
+  auto original = mod->functions;
+  for (const auto& it : original) {
     auto updated_func = SkipFunction(it.second)
                             ? it.second
                             : pass_func(it.second, updated_mod, pass_ctx);
index 77f4ab1..f3a08a8 100644 (file)
@@ -21,12 +21,15 @@ from tvm.relay.ir_pass import infer_type
 from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
 from tvm.relay import testing, create_executor
 from tvm.relay.prelude import Prelude
-from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
+from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr
 
 mod = relay.Module()
 p = Prelude(mod)
 add_nat_definitions(p)
 
+def count(e):
+    return count_(p, e)
+
 ctx = tvm.context("llvm", 0)
 intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
 
@@ -91,18 +94,18 @@ def to_list(l):
     val = l
     ret = []
     while True:
-        if val.constructor.name_hint == 'cons':
+        if val.tag == p.cons.tag:
             ret.append(val.fields[0])
             val = val.fields[1]
         else:
-            assert val.constructor.name_hint == 'nil'
+            assert val.tag == p.nil.tag
             break
     return ret
 
 def tree_to_dict(t):
     assert isinstance(t, ConstructorValue)
     ret = {}
-    assert t.constructor.name_hint == 'rose'
+    assert t.tag == p.rose.tag
     ret['member'] = t.fields[0]
     ret['children'] = []
     for subtree in to_list(t.fields[1]):
index 1e5e231..11ce11e 100644 (file)
@@ -183,11 +183,11 @@ def test_function_taking_adt_ref_tuple():
     prelude = relay.prelude.Prelude(mod)
     intrp = create_executor("debug", mod)
 
-    nil_value = ConstructorValue(prelude.nil, [], [])
-    cons_value = ConstructorValue(prelude.cons, [
+    nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil, [])
+    cons_value = ConstructorValue(prelude.cons.tag, [
         TensorValue(np.random.rand(1, 10).astype('float32')),
         nil_value
-    ], [relay.TensorType((1, 10), 'float32')])
+    ], prelude.cons, [relay.TensorType((1, 10), 'float32')])
 
     ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
     tuple_value = TupleValue(*[
@@ -197,16 +197,16 @@ def test_function_taking_adt_ref_tuple():
     id_func = intrp.evaluate(prelude.id)
 
     res_nil = id_func(nil_value)
-    assert res_nil.constructor == nil_value.constructor
+    assert res_nil.tag == nil_value.tag
     assert len(res_nil.fields) == 0
 
     res_cons = id_func(cons_value)
-    assert res_cons.constructor == cons_value.constructor
+    assert res_cons.tag == cons_value.tag
     assert len(res_cons.fields) == len(cons_value.fields)
     tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(),
                                 cons_value.fields[0].asnumpy())
     assert isinstance(res_cons.fields[1], ConstructorValue)
-    assert res_cons.fields[1].constructor == prelude.nil
+    assert res_cons.fields[1].tag == prelude.nil.tag
     assert len(res_cons.fields[1].fields) == 0
 
     res_ref = id_func(ref_value)
index f395580..db40c86 100644 (file)
@@ -142,8 +142,8 @@ def test_nat_add():
     ctx = tvm.context("llvm", 0)
     intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
     assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
-    assert count(intrp.evaluate(add(s(z()), s(z())))) == 2
-    assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
+    assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
+    assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
     assert "let" in mod[add].astext()
 
 
index d727e77..12e343b 100644 (file)
@@ -185,9 +185,7 @@ def test_tuple_second():
     result = veval(f, (i_data, j_data))
     tvm.testing.assert_allclose(result.asnumpy(), j_data)
 
-@nottest
 def test_list_constructor():
-    # TODO(wweic): implement pattern match to support this test
     def to_list(o):
         if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
             return [o.data.asnumpy().tolist()]
@@ -204,6 +202,11 @@ def test_list_constructor():
     cons = p.cons
     l = p.l
 
+    # remove all functions to not have pattern match to pass vm compilation
+    # TODO(wweic): remove the hack and implement pattern match
+    for v, _ in mod.functions.items():
+        mod[v] = relay.const(0)
+
     one2 = cons(relay.const(1), nil())
     one3 = cons(relay.const(2), one2)
     one4 = cons(relay.const(3), one3)
@@ -213,7 +216,6 @@ def test_list_constructor():
 
     result = veval(mod)()
     obj = to_list(result)
-    import pdb; pdb.set_trace()
     tvm.testing.assert_allclose(obj, np.array([3,2,1]))
 
 def test_let_tensor():