[Relay][Module] Make tags for ADT constructors and ConstructorValues more robust...
authorSteven S. Lyubomirsky <sslyu@cs.washington.edu>
Fri, 5 Jul 2019 20:58:16 +0000 (13:58 -0700)
committerJared Roesch <roeschinc@gmail.com>
Fri, 5 Jul 2019 20:58:16 +0000 (13:58 -0700)
* Use hash of ADT name and constructor idx to generate tag, add reverse mapping to module and use where appropriate

* Lint and build fixes

* Add round-tripping test for getting constructors by tag

* Use int64_t everywhere for tags

* Add additional identity check

* Bring out _arg_to_ast again

* Use 8-bit hash of GTV name as MSB of tag, index as LSB for more readable tags

* Use int32 instead of int64 for tag

include/tvm/relay/adt.h
include/tvm/relay/interpreter.h
include/tvm/relay/module.h
python/tvm/relay/backend/interpreter.py
python/tvm/relay/module.py
src/relay/backend/interpreter.cc
src/relay/ir/module.cc
tests/python/relay/test_ir_module.py [new file with mode: 0644]

index 9e4e00c..2a6507b 100644 (file)
@@ -114,7 +114,7 @@ class ConstructorNode : public ExprNode {
   /*! \brief The datatype the constructor will construct. */
   GlobalTypeVar belong_to;
   /*! \brief Index in the table of constructors (set when the type is registered). */
-  mutable int tag = -1;
+  mutable int32_t tag = -1;
 
   ConstructorNode() {}
 
index 68b7cca..d05099f 100644 (file)
@@ -182,7 +182,7 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
 class ConstructorValue;
 
 struct ConstructorValueNode : ValueNode {
-  int tag;
+  int32_t tag;
 
   tvm::Array<Value> fields;
 
@@ -195,7 +195,7 @@ struct ConstructorValueNode : ValueNode {
     v->Visit("constructor", &constructor);
   }
 
-  TVM_DLL static ConstructorValue make(int tag,
+  TVM_DLL static ConstructorValue make(int32_t tag,
                                        tvm::Array<Value> fields,
                                        Constructor construtor = {});
 
index 4a3ff0b..389f0c1 100644 (file)
@@ -32,6 +32,7 @@
 #include <tvm/relay/type.h>
 #include <string>
 #include <vector>
+#include <unordered_map>
 
 namespace tvm {
 namespace relay {
@@ -133,34 +134,41 @@ class ModuleNode : public RelayNode {
   TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;
 
   /*!
-   * \brief Lookup a global function by its variable.
+   * \brief Look up a global function by its variable.
    * \param var The global var to lookup.
    * \returns The function named by the variable argument.
    */
   TVM_DLL Function Lookup(const GlobalVar& var) const;
 
   /*!
-   * \brief Lookup a global function by its string name
+   * \brief Look up a global function by its string name
    * \param name The name of the function.
    * \returns The function named by the argument.
    */
   TVM_DLL Function Lookup(const std::string& name) const;
 
   /*!
-   * \brief Lookup a global type definition by its variable.
+   * \brief Look up a global type definition by its variable.
    * \param var The var of the global type definition.
    * \return The type definition.
    */
   TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const;
 
   /*!
-   * \brief Lookup a global type definition by its name.
+   * \brief Look up a global type definition by its name.
    * \param var The name of the global type definition.
    * \return The type definition.
    */
   TVM_DLL TypeData LookupDef(const std::string& var) const;
 
   /*!
+   * \brief Look up a constructor by its tag.
+   * \param tag The tag for the constructor.
+   * \return The constructor object.
+   */
+  TVM_DLL Constructor LookupTag(const int32_t tag);
+
+  /*!
    * \brief Update the functions inside this environment by
    *        functions in another environment.
    * \param other The other environment.
@@ -185,6 +193,9 @@ class ModuleNode : public RelayNode {
   TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
 
  private:
+  /*! \brief Helper function for registering a typedef's constructors */
+  void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);
+
   /*! \brief A map from string names to global variables that
    * ensures global uniqueness.
    */
@@ -194,6 +205,11 @@ class ModuleNode : public RelayNode {
    * that ensures global uniqueness.
    */
   tvm::Map<std::string, GlobalTypeVar> global_type_var_map_;
+
+  /*! \brief A map from constructor tags to constructor objects
+   * for convenient access
+   */
+  std::unordered_map<int32_t, Constructor> constructor_tag_map_;
 };
 
 struct Module : public NodeRef {
index 5b7d9ed..5ca09b0 100644 (file)
@@ -114,17 +114,18 @@ class RefValue(Value):
             _make.RefValue, value)
 
 
-def _arg_to_ast(arg):
+def _arg_to_ast(mod, arg):
     if isinstance(arg, TensorValue):
         return Constant(arg.data.copyto(nd.cpu(0)))
     elif isinstance(arg, TupleValue):
-        return Tuple([_arg_to_ast(field) for field in arg.fields])
+        return Tuple([_arg_to_ast(mod, field) for field in arg.fields])
     elif isinstance(arg, tuple):
-        return Tuple([_arg_to_ast(field) for field in arg])
+        return Tuple([_arg_to_ast(mod, field) for field in arg])
     elif isinstance(arg, RefValue):
-        return RefCreate(_arg_to_ast(arg.value))
+        return RefCreate(_arg_to_ast(mod, arg.value))
     elif isinstance(arg, ConstructorValue):
-        return Call(arg.constructor, [_arg_to_ast(field) for field in arg.fields])
+        return Call(mod.get_constructor(arg.tag),
+                    [_arg_to_ast(mod, field) for field in arg.fields])
     elif isinstance(arg, np.ndarray):
         return Constant(nd.array(arg))
     elif isinstance(arg, Constant):
@@ -231,7 +232,7 @@ class Executor(object):
         if binds:
             scope_builder = ScopeBuilder()
             for key, value in binds.items():
-                scope_builder.let(key, _arg_to_ast(value))
+                scope_builder.let(key, _arg_to_ast(self.mod, value))
             scope_builder.ret(expr)
             expr = scope_builder.get()
 
@@ -294,7 +295,7 @@ class Interpreter(Executor):
 
             relay_args = []
             for arg in args:
-                relay_args.append(_arg_to_ast(arg))
+                relay_args.append(_arg_to_ast(self.mod, arg))
 
             # Set the entry function for the module.
             if expr is None:
index 097dbbb..aeeedb8 100644 (file)
@@ -156,6 +156,25 @@ class Module(RelayNode):
         """
         return _module.Module_GetGlobalTypeVar(self, name)
 
+    def get_constructor(self, tag):
+        """Look up an ADT constructor by tag.
+
+        Parameters
+        ----------
+        tag: int
+            The tag for a constructor.
+
+        Returns
+        -------
+        constructor: Constructor
+           The constructor associated with the given tag,
+
+        Raises
+        ------
+        tvm.TVMError if the corresponding constructor cannot be found.
+        """
+        return _module.Module_LookupTag(self, tag)
+
     @staticmethod
     def from_expr(expr):
         return _module.Module_FromExpr(expr)
index 7c97bef..913d7ad 100644 (file)
@@ -103,7 +103,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
                               p->stream << "RefValueNode(" << node->value << ")";
                             });
 
-ConstructorValue ConstructorValueNode::make(int tag,
+ConstructorValue ConstructorValueNode::make(int32_t tag,
                                             tvm::Array<Value> fields,
                                             Constructor constructor) {
   NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
index 4286be2..a616f5e 100644 (file)
@@ -53,6 +53,7 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
     CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint))
       << "Duplicate global type definition name " << kv.first->var->name_hint;
     n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first);
+    n->RegisterConstructors(kv.first, kv.second);
   }
 
   return Module(n);
@@ -108,15 +109,25 @@ void ModuleNode::Add(const GlobalVar& var,
   AddUnchecked(var, checked_func);
 }
 
+void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
+  // We hash the global type var name to use as a globally unique prefix for tags.
+  // The hash will be used as the most significant byte of the tag, with the index of
+  // the constructor in the less significant bytes
+  size_t hash = std::hash<std::string>()(var->var->name_hint);
+  int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
+  for (size_t i = 0; i < type->constructors.size(); ++i) {
+    type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
+    constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i];
+  }
+}
+
 void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
   this->type_definitions.Set(var, type);
   // set global type var map
   CHECK(!global_type_var_map_.count(var->var->name_hint))
     << "Duplicate global type definition name " << var->var->name_hint;
   global_type_var_map_.Set(var->var->name_hint, var);
-  for (size_t i = 0; i < type->constructors.size(); ++i) {
-    type->constructors[i]->tag = i;
-  }
+  RegisterConstructors(var, type);
 
   // need to kind check at the end because the check can look up
   // a definition potentially
@@ -159,6 +170,13 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
   return this->LookupDef(id);
 }
 
+Constructor ModuleNode::LookupTag(const int32_t tag) {
+  auto it = constructor_tag_map_.find(tag);
+  CHECK(it != constructor_tag_map_.end())
+    << "There is no constructor with the tag " << tag;
+  return (*it).second;
+}
+
 void ModuleNode::Update(const Module& mod) {
   for (auto pair : mod->functions) {
     this->Update(pair.first, pair.second);
@@ -236,6 +254,11 @@ TVM_REGISTER_API("relay._module.Module_LookupDef_str")
   return mod->LookupDef(var);
 });
 
+TVM_REGISTER_API("relay._module.Module_LookupTag")
+.set_body_typed<Constructor(Module, int32_t)>([](Module mod, int32_t tag) {
+    return mod->LookupTag(tag);
+  });
+
 TVM_REGISTER_API("relay._module.Module_FromExpr")
 .set_body_typed<Module(Expr)>([](Expr e) {
   return ModuleNode::FromExpr(e);
diff --git a/tests/python/relay/test_ir_module.py b/tests/python/relay/test_ir_module.py
new file mode 100644 (file)
index 0000000..72a92c8
--- /dev/null
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for module functionality."""
+import tvm
+from tvm import relay
+from tvm.relay import Module
+from tvm.relay.prelude import Prelude
+from tvm.relay.testing import add_nat_definitions
+
+def constructor_list(p):
+    return [p.nil, p.cons, p.rose, p.some, p.none, p.z, p.s]
+
+
+def adt_list(p):
+    return [p.nat, p.l, p.optional, p.tree]
+
+
+def test_constructor_tag_round_trip():
+    mod1 = Module()
+    p1 = Prelude(mod1)
+    add_nat_definitions(p1)
+    mod2 = Module()
+    p2 = Prelude(mod2)
+    add_nat_definitions(p2)
+
+    # ensure hashes match across modules
+    ctors1 = constructor_list(p1)
+    ctors2 = constructor_list(p2)
+
+    for i in range(len(ctors1)):
+        tag = ctors1[i].tag
+        ctor = mod2.get_constructor(tag)
+        assert ctor == ctors2[i]
+        assert ctor.name_hint == ctors1[i].name_hint
+
+
+def test_constructor_tag_differences():
+    # ensure that if we have the type data for a given ADT, the tags
+    # for the constructors of the *same ADT* are simple offsets from
+    # each other
+    mod = Module()
+    p = Prelude(mod)
+    add_nat_definitions(p)
+
+    adts = adt_list(p)
+    for adt in adts:
+        data = mod[adt]
+        for i in range(len(data.constructors) - 1):
+            ctor1 = data.constructors[i]
+            ctor2 = data.constructors[i + 1]
+            assert ctor2.tag - ctor1.tag == 1
+            # make sure there is something present at the MSB
+            assert ctor1.tag - i != 0
+            assert ctor2.tag - (i + 1) != 0