* Use hash of ADT name and constructor idx to generate tag, add reverse mapping to module and use where appropriate
* Lint and build fixes
* Add round-tripping test for getting constructors by tag
* Use int64_t everywhere for tags
* Add additional identity check
* Bring out _arg_to_ast again
* Use 8-bit hash of GTV name as MSB of tag, index as LSB for more readable tags
* Use int32 instead of int64 for tag
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
- mutable int tag = -1;
+ mutable int32_t tag = -1;
ConstructorNode() {}
class ConstructorValue;
struct ConstructorValueNode : ValueNode {
- int tag;
+ int32_t tag;
tvm::Array<Value> fields;
v->Visit("constructor", &constructor);
}
- TVM_DLL static ConstructorValue make(int tag,
+ TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<Value> fields,
Constructor construtor = {});
#include <tvm/relay/type.h>
#include <string>
#include <vector>
+#include <unordered_map>
namespace tvm {
namespace relay {
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;
/*!
- * \brief Lookup a global function by its variable.
+ * \brief Look up a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
TVM_DLL Function Lookup(const GlobalVar& var) const;
/*!
- * \brief Lookup a global function by its string name
+ * \brief Look up a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
TVM_DLL Function Lookup(const std::string& name) const;
/*!
- * \brief Lookup a global type definition by its variable.
+ * \brief Look up a global type definition by its variable.
* \param var The var of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const;
/*!
- * \brief Lookup a global type definition by its name.
+ * \brief Look up a global type definition by its name.
* \param var The name of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const std::string& var) const;
/*!
+ * \brief Look up a constructor by its tag.
+ * \param tag The tag for the constructor.
+ * \return The constructor object.
+ */
+ TVM_DLL Constructor LookupTag(const int32_t tag);
+
+ /*!
* \brief Update the functions inside this environment by
* functions in another environment.
* \param other The other environment.
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
private:
+ /*! \brief Helper function for registering a typedef's constructors */
+ void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);
+
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
* that ensures global uniqueness.
*/
tvm::Map<std::string, GlobalTypeVar> global_type_var_map_;
+
+ /*! \brief A map from constructor tags to constructor objects
+ * for convenient access
+ */
+ std::unordered_map<int32_t, Constructor> constructor_tag_map_;
};
struct Module : public NodeRef {
_make.RefValue, value)
-def _arg_to_ast(arg):
+def _arg_to_ast(mod, arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue):
- return Tuple([_arg_to_ast(field) for field in arg.fields])
+ return Tuple([_arg_to_ast(mod, field) for field in arg.fields])
elif isinstance(arg, tuple):
- return Tuple([_arg_to_ast(field) for field in arg])
+ return Tuple([_arg_to_ast(mod, field) for field in arg])
elif isinstance(arg, RefValue):
- return RefCreate(_arg_to_ast(arg.value))
+ return RefCreate(_arg_to_ast(mod, arg.value))
elif isinstance(arg, ConstructorValue):
- return Call(arg.constructor, [_arg_to_ast(field) for field in arg.fields])
+ return Call(mod.get_constructor(arg.tag),
+ [_arg_to_ast(mod, field) for field in arg.fields])
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
elif isinstance(arg, Constant):
if binds:
scope_builder = ScopeBuilder()
for key, value in binds.items():
- scope_builder.let(key, _arg_to_ast(value))
+ scope_builder.let(key, _arg_to_ast(self.mod, value))
scope_builder.ret(expr)
expr = scope_builder.get()
relay_args = []
for arg in args:
- relay_args.append(_arg_to_ast(arg))
+ relay_args.append(_arg_to_ast(self.mod, arg))
# Set the entry function for the module.
if expr is None:
"""
return _module.Module_GetGlobalTypeVar(self, name)
+ def get_constructor(self, tag):
+ """Look up an ADT constructor by tag.
+
+ Parameters
+ ----------
+ tag: int
+ The tag for a constructor.
+
+ Returns
+ -------
+ constructor: Constructor
+ The constructor associated with the given tag,
+
+ Raises
+ ------
+ tvm.TVMError if the corresponding constructor cannot be found.
+ """
+ return _module.Module_LookupTag(self, tag)
+
@staticmethod
def from_expr(expr):
return _module.Module_FromExpr(expr)
p->stream << "RefValueNode(" << node->value << ")";
});
-ConstructorValue ConstructorValueNode::make(int tag,
+ConstructorValue ConstructorValueNode::make(int32_t tag,
tvm::Array<Value> fields,
Constructor constructor) {
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint))
<< "Duplicate global type definition name " << kv.first->var->name_hint;
n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first);
+ n->RegisterConstructors(kv.first, kv.second);
}
return Module(n);
AddUnchecked(var, checked_func);
}
+void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
+ // We hash the global type var name to use as a globally unique prefix for tags.
+ // The hash will be used as the most significant byte of the tag, with the index of
+ // the constructor in the less significant bytes
+ size_t hash = std::hash<std::string>()(var->var->name_hint);
+ int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
+ for (size_t i = 0; i < type->constructors.size(); ++i) {
+ type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
+ constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i];
+ }
+}
+
void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
this->type_definitions.Set(var, type);
// set global type var map
CHECK(!global_type_var_map_.count(var->var->name_hint))
<< "Duplicate global type definition name " << var->var->name_hint;
global_type_var_map_.Set(var->var->name_hint, var);
- for (size_t i = 0; i < type->constructors.size(); ++i) {
- type->constructors[i]->tag = i;
- }
+ RegisterConstructors(var, type);
// need to kind check at the end because the check can look up
// a definition potentially
return this->LookupDef(id);
}
+Constructor ModuleNode::LookupTag(const int32_t tag) {
+ auto it = constructor_tag_map_.find(tag);
+ CHECK(it != constructor_tag_map_.end())
+ << "There is no constructor with the tag " << tag;
+ return (*it).second;
+}
+
void ModuleNode::Update(const Module& mod) {
for (auto pair : mod->functions) {
this->Update(pair.first, pair.second);
return mod->LookupDef(var);
});
+TVM_REGISTER_API("relay._module.Module_LookupTag")
+.set_body_typed<Constructor(Module, int32_t)>([](Module mod, int32_t tag) {
+ return mod->LookupTag(tag);
+ });
+
TVM_REGISTER_API("relay._module.Module_FromExpr")
.set_body_typed<Module(Expr)>([](Expr e) {
return ModuleNode::FromExpr(e);
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for module functionality."""
+import tvm
+from tvm import relay
+from tvm.relay import Module
+from tvm.relay.prelude import Prelude
+from tvm.relay.testing import add_nat_definitions
+
+def constructor_list(p):
+ return [p.nil, p.cons, p.rose, p.some, p.none, p.z, p.s]
+
+
+def adt_list(p):
+ return [p.nat, p.l, p.optional, p.tree]
+
+
+def test_constructor_tag_round_trip():
+ mod1 = Module()
+ p1 = Prelude(mod1)
+ add_nat_definitions(p1)
+ mod2 = Module()
+ p2 = Prelude(mod2)
+ add_nat_definitions(p2)
+
+ # ensure hashes match across modules
+ ctors1 = constructor_list(p1)
+ ctors2 = constructor_list(p2)
+
+ for i in range(len(ctors1)):
+ tag = ctors1[i].tag
+ ctor = mod2.get_constructor(tag)
+ assert ctor == ctors2[i]
+ assert ctor.name_hint == ctors1[i].name_hint
+
+
+def test_constructor_tag_differences():
+ # ensure that if we have the type data for a given ADT, the tags
+ # for the constructors of the *same ADT* are simple offsets from
+ # each other
+ mod = Module()
+ p = Prelude(mod)
+ add_nat_definitions(p)
+
+ adts = adt_list(p)
+ for adt in adts:
+ data = mod[adt]
+ for i in range(len(data.constructors) - 1):
+ ctor1 = data.constructors[i]
+ ctor2 = data.constructors[i + 1]
+ assert ctor2.tag - ctor1.tag == 1
+ # make sure there is something present at the MSB
+ assert ctor1.tag - i != 0
+ assert ctor2.tag - (i + 1) != 0