"tvm::relay::Span",
"tvm::relay::TempExpr",
"tvm::relay::TensorType",
- "tvm::relay::TensorValue",
"tvm::relay::Tuple",
"tvm::relay::TupleGetItem",
"tvm::relay::TupleType",
- "tvm::relay::TupleValue",
"tvm::relay::Type",
"tvm::relay::TypeCall",
"tvm::relay::TypeConstraint",
* Given a Relay module, and a Relay expression it produces a value.
*
* The interpreter's values are a naive representation of the values that
- * can be produced by a Relay program and are exposed via tvm::Node's
- * system to Python for introspection and debugging.
+ * can be produced by a Relay program and are exposed via TVM's object
+ * protocol to Python for introspection and debugging.
*
* The interpreter's intent is to serve as a reference semantics for the Relay IR,
* as well as for debugging and testing.
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/vm.h>
namespace tvm {
namespace relay {
runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(IRModule mod, DLContext context, Target target);
-/*! \brief A Relay closure, i.e a scope and a function. */
-class Closure;
-
-/*! \brief The container type of Closures. */
-class ClosureNode : public Object {
+/*! \brief The container type of Closures used by the interpreter. */
+class InterpreterClosureObj : public runtime::vm::ClosureObj {
public:
/*! \brief The set of free variables in the closure.
*
*/
Function func;
- ClosureNode() {}
+ InterpreterClosureObj() {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("env", &env);
v->Visit("func", &func);
}
- TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func);
-
- static constexpr const char* _type_key = "relay.Closure";
- TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
+ static constexpr const char* _type_key = "interpreter.Closure";
+ TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterClosureObj, runtime::vm::ClosureObj);
};
-class Closure : public ObjectRef {
+class InterpreterClosure : public runtime::vm::Closure {
public:
- TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode);
+ TVM_DLL InterpreterClosure(tvm::Map<Var, ObjectRef> env, Function func);
+ TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure,
+ InterpreterClosureObj);
};
-/*! \brief A Relay Recursive Closure. A closure that has a name. */
-class RecClosure;
-
/*! \brief The container type of RecClosure. */
-class RecClosureNode : public Object {
+class RecClosureObj : public Object {
public:
/*! \brief The closure. */
- Closure clos;
+ InterpreterClosure clos;
/*! \brief variable the closure bind to. */
Var bind;
- RecClosureNode() {}
+ RecClosureObj() {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("clos", &clos);
v->Visit("bind", &bind);
}
- TVM_DLL static RecClosure make(Closure clos, Var bind);
-
- static constexpr const char* _type_key = "relay.RecClosure";
- TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
+ static constexpr const char* _type_key = "interpreter.RecClosure";
+ TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureObj, Object);
};
class RecClosure : public ObjectRef {
public:
- TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode);
-};
-
-/*! \brief A tuple value. */
-class TupleValue;
-
-/*! \brief Tuple (x, ... y). */
-struct TupleValueNode : Object {
- tvm::Array<ObjectRef> fields;
-
- TupleValueNode() {}
-
- void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
-
- TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);
-
- static constexpr const char* _type_key = "relay.TupleValue";
- TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
-};
-
-class TupleValue : public ObjectRef {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
+ TVM_DLL RecClosure(InterpreterClosure clos, Var bind);
+ TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureObj);
};
-/*! \brief A reference value. */
-class RefValue;
-
-struct RefValueNode : Object {
+struct RefValueObj : Object {
mutable ObjectRef value;
- RefValueNode() {}
+ RefValueObj() {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
}
- TVM_DLL static RefValue make(ObjectRef val);
-
static constexpr const char* _type_key = "relay.RefValue";
- TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object);
+ TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object);
};
class RefValue : public ObjectRef {
public:
- TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode);
+ TVM_DLL RefValue(ObjectRef val);
+ TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueObj);
};
-/*! \brief An ADT constructor value. */
-class ConstructorValue;
-
-struct ConstructorValueNode : Object {
+struct ConstructorValueObj : Object {
int32_t tag;
tvm::Array<ObjectRef> fields;
v->Visit("constructor", &constructor);
}
- TVM_DLL static ConstructorValue make(int32_t tag,
- tvm::Array<ObjectRef> fields,
- Constructor construtor = {});
-
static constexpr const char* _type_key = "relay.ConstructorValue";
- TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object);
+ TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueObj, Object);
};
class ConstructorValue : public ObjectRef {
public:
- TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode);
+ TVM_DLL ConstructorValue(int32_t tag,
+ tvm::Array<ObjectRef> fields,
+ Constructor construtor = {});
+
+ TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
};
} // namespace relay
enum TypeIndex {
/*! \brief Root object type. */
kRoot = 0,
- kVMTensor = 1,
- kVMClosure = 2,
- kVMADT = 3,
- kRuntimeModule = 4,
+ kClosure = 1,
+ kVMADT = 2,
+ kRuntimeModule = 3,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
#define TVM_RUNTIME_VM_H_
#include <tvm/runtime/object.h>
+#include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <memory>
#include <string>
#include <unordered_map>
+#include <utility>
#include <vector>
namespace tvm {
namespace runtime {
namespace vm {
-/*! \brief An object representing a closure. */
+/*!
+ * \brief An object representing a closure. This object is used by both the
+ * Relay VM and interpreter.
+ */
class ClosureObj : public Object {
public:
- /*! \brief The index into the VM function table. */
+ static constexpr const uint32_t _type_index = TypeIndex::kClosure;
+ static constexpr const char* _type_key = "Closure";
+ TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
+};
+
+/*! \brief reference to closure. */
+class Closure : public ObjectRef {
+ public:
+ TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
+};
+
+/*!
+ * \brief An object representing a vm closure.
+ */
+class VMClosureObj : public ClosureObj {
+ public:
+ /*!
+ * \brief The index into the function list. The function could be any
+ * function object that is compatible to the VM runtime.
+ */
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<ObjectRef> free_vars;
- static constexpr const uint32_t _type_index = TypeIndex::kVMClosure;
static constexpr const char* _type_key = "vm.Closure";
- TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object);
+ TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj);
};
/*! \brief reference to closure. */
-class Closure : public ObjectRef {
+class VMClosure : public Closure {
public:
- Closure(size_t func_index, std::vector<ObjectRef> free_vars);
-
- TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
+ VMClosure(size_t func_index, std::vector<ObjectRef> free_vars);
+ TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj);
};
/*! \brief Magic number for NDArray list file */
# under the License.
"""Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs
-from ._ffi.object import Object, register_object
+from tvm import ndarray as _nd
from . import _api_internal
+from ._ffi.object import Object, register_object, getitem_helper
+from ._ffi.function import _init_api
@register_object
class Array(Object):
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2
+
+
+@register_object("vm.ADT")
+class ADT(Object):
+ """Algebatic data type(ADT) object.
+
+ Parameters
+ ----------
+ tag : int
+ The tag of ADT.
+
+ fields : list[Object] or tuple[Object]
+ The source tuple.
+ """
+ def __init__(self, tag, fields):
+ for f in fields:
+ assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \
+ "tvm NDArray type, but received : {0}".format(type(f))
+ self.__init_handle_by_constructor__(_ADT, tag, *fields)
+
+ @property
+ def tag(self):
+ return _GetADTTag(self)
+
+ def __getitem__(self, idx):
+ return getitem_helper(
+ self, _GetADTFields, len(self), idx)
+
+ def __len__(self):
+ return _GetADTSize(self)
+
+
+def tuple_object(fields=None):
+ """Create a ADT object from source tuple.
+
+ Parameters
+ ----------
+ fields : list[Object] or tuple[Object]
+ The source tuple.
+
+ Returns
+ -------
+ ret : ADT
+ The created object.
+ """
+ fields = fields if fields else []
+ for f in fields:
+ assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \
+ "NDArray type, but received : {0}".format(type(f))
+ return _Tuple(*fields)
+
+
+_init_api("tvm.container")
from . import feature
from .backend import vm
from .backend import profiler_vm
-from .backend import vmobj
# Root operators
from .op import Op
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""The VM Object FFI namespace."""
-from tvm._ffi.function import _init_api
-
-_init_api("_vmobj", __name__)
import numpy as np
+from tvm import container
from . import _backend
from .. import _make, analysis, transform
from .. import module
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder
-@register_relay_node
-class TupleValue(Object):
- """A tuple value produced by the interpreter."""
- def __init__(self, *fields):
- self.__init_handle_by_constructor__(
- _make.TupleValue, fields)
-
- def __getitem__(self, field_no):
- return self.fields[field_no]
-
- def __len__(self):
- return len(self.fields)
-
- def __str__(self):
- body = ','.join(str(f) for f in self.fields)
- return '({0})'.format(body)
-
- def __repr__(self):
- body = ','.join(repr(f) for f in self.fields)
- return '({0})'.format(body)
-
- def __iter__(self):
- return iter(self.fields)
-
-
-@register_relay_node
-class Closure(Object):
- """A closure produced by the interpreter."""
-
-
-@register_relay_node
-class RecClosure(Object):
- """A recursive closure produced by the interpreter."""
-
@register_relay_node
class ConstructorValue(Object):
def _arg_to_ast(mod, arg):
if isinstance(arg, nd.NDArray):
return Constant(arg.copyto(nd.cpu(0)))
- elif isinstance(arg, TupleValue):
- return Tuple([_arg_to_ast(mod, field) for field in arg.fields])
+ elif isinstance(arg, container.ADT):
+ return Tuple([_arg_to_ast(mod, field) for field in arg])
elif isinstance(arg, tuple):
return Tuple([_arg_to_ast(mod, field) for field in arg])
elif isinstance(arg, RefValue):
import numpy as np
import tvm
-from tvm import autotvm
+from tvm import autotvm, container
+from tvm.object import Object
from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from . import _vm
-from . import vmobj as _obj
from .interpreter import Executor
-ADT = _obj.ADT
-
def _convert(arg, cargs):
if isinstance(arg, _expr.Constant):
cargs.append(arg.data)
- elif isinstance(arg, _obj.Object):
+ elif isinstance(arg, Object):
cargs.append(arg)
elif isinstance(arg, np.ndarray):
nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0))
field_args = []
for field in arg:
_convert(field, field_args)
- cargs.append(_obj.tuple_object(field_args))
+ cargs.append(container.tuple_object(field_args))
elif isinstance(arg, (_base.numeric_types, bool)):
dtype = "int32" if isinstance(arg, (int, bool)) else "float32"
value = tvm.nd.array(np.array(arg, dtype=dtype), ctx=tvm.cpu(0))
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Runtime Object API."""
-from __future__ import absolute_import as _abs
-
-from tvm._ffi.object import Object, register_object, getitem_helper
-from tvm import ndarray as _nd
-from . import _vmobj
-
-
-@register_object("vm.ADT")
-class ADT(Object):
- """Algebatic data type(ADT) object.
-
- Parameters
- ----------
- tag : int
- The tag of ADT.
-
- fields : list[Object] or tuple[Object]
- The source tuple.
- """
- def __init__(self, tag, fields):
- for f in fields:
- assert isinstance(f, (Object, _nd.NDArray)), "Expect object or "
- "tvm NDArray type, but received : {0}".format(type(f))
- self.__init_handle_by_constructor__(
- _vmobj.ADT, tag, *fields)
-
- @property
- def tag(self):
- return _vmobj.GetADTTag(self)
-
- def __getitem__(self, idx):
- return getitem_helper(
- self, _vmobj.GetADTFields, len(self), idx)
-
- def __len__(self):
- return _vmobj.GetADTNumberOfFields(self)
-
-
-def tuple_object(fields):
- """Create a ADT object from source tuple.
-
- Parameters
- ----------
- fields : list[Object] or tuple[Object]
- The source tuple.
-
- Returns
- -------
- ret : ADT
- The created object.
- """
- for f in fields:
- assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm "
- "NDArray type, but received : {0}".format(type(f))
- return _vmobj.Tuple(*fields)
"""Common utilities"""
from __future__ import absolute_import as _abs
import logging
+import numpy as np
import tvm
-import numpy as np
from topi.util import get_const_tuple
from .. import expr as _expr
from .. import module as _module
#pylint: disable=invalid-name
"""Utilities for testing and benchmarks"""
from __future__ import absolute_import as _abs
+import numpy as np
import tvm
import tvm.relay as relay
from tvm.relay import transform
from tvm.relay import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem, create_executor
from tvm.relay import TensorType, TupleType
-import numpy as np
from . import mlp
from . import resnet
# import numpy
# import tvm
# from tvm import relay
+# from tvm import import container as _container
# from tvm import nd
-# from tvm.relay.backend.interpreter import RefValue, TupleValue, ConstructorValue
+# from tvm.relay.backend.interpreter import RefValue, ConstructorValue
PROLOGUE = [
ast.Import([alias('numpy', None)]),
ast.Import([alias('tvm', None)]),
ast.ImportFrom('tvm', [alias('relay', None)], 0),
ast.ImportFrom('tvm', [alias('nd', None)], 0),
+ ast.ImportFrom('tvm', [alias('container', '_container')],
+ 0),
ast.ImportFrom('tvm.relay.backend.interpreter',
[alias('RefValue', None),
- alias('TupleValue', None),
alias('ConstructorValue', None)],
- 0)
+ 0),
]
class PythonConverter(ExprFunctor):
for i in range(len(arg_type.fields)):
ret += convert_input(
ast.Subscript(
- ast.Attribute(py_input, 'fields', Load()),
+ py_input,
ast.Index(Num(i)), Load()),
arg_type.fields[i])
return ret
assignments += inner_assignments
extra_args += inner_args
fields.append(inner_output)
- return (assignments, extra_args, self.create_call('TupleValue', fields))
+ fields = [ast.List(fields, Load())]
+ return (assignments, extra_args, self.create_call('_container.tuple_object', fields))
# create a function to wrap the call of the lowered op and return
# a call to that function
def visit_tuple(self, tup: Expr):
fields, ret_defs = self.convert_fields(tup.fields)
- return (self.create_call('TupleValue', fields), ret_defs)
+ fields = [ast.List(fields, Load())]
+ return (self.create_call('_container.tuple_object', fields), ret_defs)
def visit_tuple_getitem(self, tgi: Expr):
thunk_name, [],
ref_defs + val_defs + [
Assign([ast.Attribute(ref, 'value', Store())], val),
- Return(self.create_call('TupleValue', []))
+ Return(self.create_call('_container.tuple_object', []))
])
return (self.create_call(thunk_name, []), [thunk])
* \brief An interpreter for the Relay IR.
*/
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/object.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
using namespace runtime;
-inline const PackedFunc& GetPackedFunc(const std::string& name) {
- const PackedFunc* pf = tvm::runtime::Registry::Get(name);
- CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
- return *pf;
-}
-
-/* Object Implementation */
-Closure ClosureNode::make(tvm::Map<Var, ObjectRef> env, Function func) {
- ObjectPtr<ClosureNode> n = make_object<ClosureNode>();
+InterpreterClosure::InterpreterClosure(tvm::Map<Var, ObjectRef> env,
+ Function func) {
+ ObjectPtr<InterpreterClosureObj> n = make_object<InterpreterClosureObj>();
n->env = std::move(env);
n->func = std::move(func);
- return Closure(n);
+ data_ = std::move(n);
}
-TVM_REGISTER_GLOBAL("relay._make.Closure")
-.set_body_typed(ClosureNode::make);
-
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<ClosureNode>([](const ObjectRef& ref, NodePrinter* p) {
- auto* node = static_cast<const ClosureNode*>(ref.get());
- p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
- });
+.set_dispatch<InterpreterClosureObj >([](const ObjectRef& ref, NodePrinter* p) {
+ auto* node = static_cast<const InterpreterClosureObj*>(ref.get());
+ p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")";
+});
+inline const PackedFunc& GetPackedFunc(const std::string& name) {
+ const PackedFunc* pf = tvm::runtime::Registry::Get(name);
+ CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
+ return *pf;
+}
// TODO(@jroesch): this doesn't support mutual letrec
/* Object Implementation */
-RecClosure RecClosureNode::make(Closure clos, Var bind) {
- ObjectPtr<RecClosureNode> n = make_object<RecClosureNode>();
+RecClosure::RecClosure(InterpreterClosure clos, Var bind) {
+ ObjectPtr<RecClosureObj> n = make_object<RecClosureObj>();
n->clos = std::move(clos);
n->bind = std::move(bind);
- return RecClosure(n);
+ data_ = std::move(n);
}
-TVM_REGISTER_GLOBAL("relay._make.RecClosure")
-.set_body_typed(RecClosureNode::make);
-
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<RecClosureNode>([](const ObjectRef& ref, NodePrinter* p) {
- auto* node = static_cast<const RecClosureNode*>(ref.get());
- p->stream << "RecClosureNode(" << node->clos << ")";
+.set_dispatch<RecClosureObj>([](const ObjectRef& ref, NodePrinter* p) {
+ auto* node = static_cast<const RecClosureObj*>(ref.get());
+ p->stream << "RecClosureObj(" << node->clos << ")";
});
-TupleValue TupleValueNode::make(tvm::Array<ObjectRef> value) {
- ObjectPtr<TupleValueNode> n = make_object<TupleValueNode>();
- n->fields = value;
- return TupleValue(n);
-}
-
-TVM_REGISTER_GLOBAL("relay._make.TupleValue")
-.set_body_typed(TupleValueNode::make);
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TupleValueNode>([](const ObjectRef& ref, NodePrinter* p) {
- auto* node = static_cast<const TupleValueNode*>(ref.get());
- p->stream << "TupleValueNode(" << node->fields << ")";
- });
-
-
-RefValue RefValueNode::make(ObjectRef value) {
- ObjectPtr<RefValueNode> n = make_object<RefValueNode>();
+RefValue::RefValue(ObjectRef value) {
+ ObjectPtr<RefValueObj> n = make_object<RefValueObj>();
n->value = value;
- return RefValue(n);
+ data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("relay._make.RefValue")
-.set_body_typed(RefValueNode::make);
+.set_body_typed([](ObjectRef value){
+ return RefValue(value);
+});
-TVM_REGISTER_NODE_TYPE(RefValueNode);
+TVM_REGISTER_NODE_TYPE(RefValueObj);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<RefValueNode>([](const ObjectRef& ref, NodePrinter* p) {
- auto* node = static_cast<const RefValueNode*>(ref.get());
- p->stream << "RefValueNode(" << node->value << ")";
+.set_dispatch<RefValueObj>([](const ObjectRef& ref, NodePrinter* p) {
+ auto* node = static_cast<const RefValueObj*>(ref.get());
+ p->stream << "RefValueObj(" << node->value << ")";
});
-ConstructorValue ConstructorValueNode::make(int32_t tag,
- tvm::Array<ObjectRef> fields,
- Constructor constructor) {
- ObjectPtr<ConstructorValueNode> n = make_object<ConstructorValueNode>();
+ConstructorValue::ConstructorValue(int32_t tag,
+ tvm::Array<ObjectRef> fields,
+ Constructor constructor) {
+ ObjectPtr<ConstructorValueObj> n = make_object<ConstructorValueObj>();
n->tag = tag;
n->fields = fields;
n->constructor = constructor;
- return ConstructorValue(n);
+ data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("relay._make.ConstructorValue")
-.set_body_typed(ConstructorValueNode::make);
+.set_body_typed([](int32_t tag, tvm::Array<ObjectRef> fields,
+ Constructor constructor) {
+ return ConstructorValue(tag, fields, constructor);
+});
-TVM_REGISTER_NODE_TYPE(ConstructorValueNode);
+TVM_REGISTER_NODE_TYPE(ConstructorValueObj);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<ConstructorValueNode>([](const ObjectRef& ref, NodePrinter* p) {
- auto* node = static_cast<const ConstructorValueNode*>(ref.get());
- p->stream << "ConstructorValueNode(" << node->tag << ","
+.set_dispatch<ConstructorValueObj>([](const ObjectRef& ref, NodePrinter* p) {
+ auto* node = static_cast<const ConstructorValueObj*>(ref.get());
+ p->stream << "ConstructorValueObj(" << node->tag << ","
<< node->fields << ")";
});
class InterpreterState;
/*! \brief A container capturing the state of the interpreter. */
-class InterpreterStateNode : public Object {
+class InterpreterStateObj : public Object {
public:
using Frame = tvm::Map<Var, ObjectRef>;
using Stack = tvm::Array<Frame>;
static InterpreterState make(Expr current_expr, Stack stack);
static constexpr const char* _type_key = "relay.InterpreterState";
- TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateNode, Object);
+ TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateObj, Object);
};
class InterpreterState : public ObjectRef {
public:
- TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateNode);
+ TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateObj);
};
-InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) {
- ObjectPtr<InterpreterStateNode> n = make_object<InterpreterStateNode>();
+InterpreterState InterpreterStateObj::make(Expr current_expr, Stack stack) {
+ ObjectPtr<InterpreterStateObj> n = make_object<InterpreterStateObj>();
n->current_expr = std::move(current_expr);
n->stack = std::move(stack);
return InterpreterState(n);
values.push_back(field_value);
}
- return TupleValueNode::make(values);
+ return ADT::Tuple(values);
}
ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) {
}
// We must use mutation here to build a self referential closure.
- auto closure = ClosureNode::make(captured_mod, func);
+ InterpreterClosure closure(captured_mod, func);
if (letrec_name.defined()) {
- return RecClosureNode::make(closure, letrec_name);
+ return RecClosure(closure, letrec_name);
}
return std::move(closure);
}
fset_input(arg_counter++, arg, true);
}
} else {
- const TupleValueNode* tuple = arg.as<TupleValueNode>();
- CHECK(tuple != nullptr);
+ const ADT adt = Downcast<ADT>(arg);
if (state & kNeedInputData) {
- for (size_t i = 0; i < tuple->fields.size(); ++i) {
- fset_input(arg_counter++, tuple->fields[i], false);
+ for (size_t i = 0; i < adt.size(); ++i) {
+ fset_input(arg_counter++, adt[i], false);
}
}
if (state & kNeedInputShape) {
- for (size_t i = 0; i < tuple->fields.size(); ++i) {
- fset_input(arg_counter++, tuple->fields[i], true);
+ for (size_t i = 0; i < adt.size(); ++i) {
+ fset_input(arg_counter++, adt[i], true);
}
}
}
}
// Marshal the arguments.
- // Handle tuple input/output by flattening them.
+ // Handle adt input/output by flattening them.
size_t arg_len = 0;
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->IsInstance<NDArray::ContainerType>()) {
++arg_len;
} else {
- const auto* tvalue = args[i].as<TupleValueNode>();
- arg_len += tvalue->fields.size();
+ auto adt = Downcast<ADT>(args[i]);
+ arg_len += adt.size();
}
}
size_t num_inputs = arg_len;
if (arg->IsInstance<NDArray::ContainerType>()) {
fset_input(arg_counter++, arg);
} else {
- const TupleValueNode* tuple = arg.as<TupleValueNode>();
- CHECK(tuple != nullptr);
- for (size_t i = 0; i < tuple->fields.size(); ++i) {
- fset_input(arg_counter++, tuple->fields[i]);
+ auto adt = Downcast<ADT>(arg);
+ for (size_t i = 0; i < adt.size(); ++i) {
+ fset_input(arg_counter++, adt[i]);
}
}
}
TVMRetValue rv;
if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
CHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
- Array<ObjectRef> fields;
+ std::vector<ObjectRef> fields;
for (size_t i = 0; i < rtype->fields.size(); ++i) {
if (is_dyn) {
auto sh = out_shapes[i];
}
}
packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
- return TupleValueNode::make(fields);
+ return ADT::Tuple(fields);
} else {
ObjectRef out_tensor;
if (is_dyn) {
}
// Invoke the closure
- ObjectRef Invoke(const Closure& closure,
+ ObjectRef Invoke(const InterpreterClosure& closure,
const tvm::Array<ObjectRef>& args,
const Var& bind = Var()) {
// Get a reference to the function inside the closure.
}
if (bind.defined()) {
- locals.Set(bind, RecClosureNode::make(closure, bind));
+ locals.Set(bind, RecClosure(closure, bind));
}
return WithFrame<ObjectRef>(Frame(locals), [&]() { return Eval(func->body); });
"fusing and lowering";
}
if (auto con = call->op.as<ConstructorNode>()) {
- return ConstructorValueNode::make(con->tag, args, GetRef<Constructor>(con));
+ return ConstructorValue(con->tag, args, GetRef<Constructor>(con));
}
// Now we just evaluate and expect to find a closure.
ObjectRef fn_val = Eval(call->op);
- if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
- auto closure = GetRef<Closure>(closure_node);
+ if (const InterpreterClosureObj* closure_node = fn_val.as<InterpreterClosureObj>()) {
+ auto closure = GetRef<InterpreterClosure>(closure_node);
return this->Invoke(closure, args);
- } else if (const RecClosureNode* closure_node = fn_val.as<RecClosureNode>()) {
+ } else if (const RecClosureObj* closure_node = fn_val.as<RecClosureObj>()) {
return this->Invoke(closure_node->clos, args, closure_node->bind);
} else {
LOG(FATAL) << "internal error: type error, expected function value in the call "
ObjectRef VisitExpr_(const TupleGetItemNode* op) final {
ObjectRef val = Eval(op->tuple);
- auto product_node = val.as<TupleValueNode>();
- CHECK(product_node)
- << "interal error: when evaluating TupleGetItem expected a tuple value";
- CHECK_LT(static_cast<size_t>(op->index), product_node->fields.size())
+ const auto* adt_obj = val.as<ADTObj>();
+ CHECK(adt_obj)
+ << "interal error: when evaluating TupleGetItem expected an ADT value";
+ auto adt = GetRef<ADT>(adt_obj);
+ CHECK_LT(static_cast<size_t>(op->index), adt.size())
<< "internal error: index out of bounds";
- return product_node->fields[op->index];
+ return adt[op->index];
}
ObjectRef VisitExpr_(const IfNode* op) final {
ObjectRef VisitExpr_(const RefWriteNode* op) final {
ObjectRef r = Eval(op->ref);
- if (const RefValueNode* rv = r.as<RefValueNode>()) {
+ if (const RefValueObj* rv = r.as<RefValueObj>()) {
rv->value = Eval(op->value);
- return TupleValueNode::make({});
+ return ADT::Tuple(std::vector<ObjectRef>());
} else {
LOG(FATAL) << "type error, type system should have caught this";
return ObjectRef();
}
ObjectRef VisitExpr_(const RefCreateNode* op) final {
- return RefValueNode::make(Eval(op->value));
+ return RefValue(Eval(op->value));
}
ObjectRef VisitExpr_(const RefReadNode* op) final {
ObjectRef r = Eval(op->ref);
- if (const RefValueNode* rv = r.as<RefValueNode>()) {
+ if (const RefValueObj* rv = r.as<RefValueObj>()) {
return rv->value;
} else {
LOG(FATAL) << "type error, type system should have caught this";
}
bool VisitPattern_(const PatternConstructorNode* op, const ObjectRef& v) final {
- const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
+ const ConstructorValueObj* cvn = v.as<ConstructorValueObj>();
CHECK(cvn) << "need to be a constructor for match";
CHECK_NE(op->constructor->tag, -1);
CHECK_NE(cvn->tag, -1);
}
bool VisitPattern_(const PatternTupleNode* op, const ObjectRef& v) final {
- const TupleValueNode* tvn = v.as<TupleValueNode>();
- CHECK(tvn) << "need to be a tuple for match";
- CHECK_EQ(op->patterns.size(), tvn->fields.size());
+ auto adt = Downcast<ADT>(v);
+ CHECK_EQ(op->patterns.size(), adt.size());
for (size_t i = 0; i < op->patterns.size(); ++i) {
- if (!VisitPattern(op->patterns[i], tvn->fields[i])) {
+ if (!VisitPattern(op->patterns[i], adt[i])) {
return false;
}
}
}
InterpreterState get_state(Expr e = Expr()) const {
- InterpreterStateNode::Stack stack;
+ InterpreterStateObj::Stack stack;
for (auto fr : this->stack_.frames) {
- InterpreterStateNode::Frame frame = fr.locals;
+ InterpreterStateObj::Frame frame = fr.locals;
stack.push_back(frame);
}
- auto state = InterpreterStateNode::make(e, stack);
+ auto state = InterpreterStateObj::make(e, stack);
return state;
}
TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter")
.set_body_typed(CreateInterpreter);
-TVM_REGISTER_NODE_TYPE(ClosureNode);
-TVM_REGISTER_NODE_TYPE(TupleValueNode);
-
} // namespace relay
} // namespace tvm
#include <tvm/relay/transform.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/container.h>
#include "pattern_util.h"
namespace tvm {
<< "invalid dimension after constant eval";
}
return ConstantNode::make(nd_array);
- } else if (const auto* val = value.as<TupleValueNode>()) {
+ } else if (const auto* val = value.as<runtime::ADTObj>()) {
+ runtime::ADT adt = GetRef<runtime::ADT>(val);
Array<Expr> fields;
- for (ObjectRef field : val->fields) {
- fields.push_back(ObjectToExpr(field));
+ for (size_t i = 0; i < adt.size(); ++i) {
+ fields.push_back(ObjectToExpr(adt[i]));
}
return TupleNode::make(fields);
} else {
if (v->IsInstance<runtime::NDArray::ContainerType>()) {
auto nd_array = Downcast<runtime::NDArray>(v);
return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array)));
- } else if (const TupleValueNode* op = v.as<TupleValueNode>()) {
+ } else if (const runtime::ADTObj* op = v.as<runtime::ADTObj>()) {
std::vector<PStatic> fields;
tvm::Array<Expr> fields_dyn;
- for (const ObjectRef& field : op->fields) {
- PStatic ps = Reify(field, ll);
+ auto adt = GetRef<runtime::ADT>(op);
+ for (size_t i = 0; i < adt.size(); ++i) {
+ PStatic ps = Reify(adt[i], ll);
fields.push_back(ps);
fields_dyn.push_back(ps->dynamic);
}
*/
/*!
- * \file src/runtime/vm/object.cc
- * \brief VM related objects.
+ * \file src/runtime/container.cc
+ * \brief Implementations of common plain old data (POD) containers.
*/
-#include <tvm/support/logging.h>
#include <tvm/runtime/container.h>
+#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/vm.h>
-#include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h>
-#include <tvm/runtime/c_runtime_api.h>
-#include "../runtime_base.h"
namespace tvm {
namespace runtime {
-namespace vm {
-
-Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
- auto ptr = make_object<ClosureObj>();
- ptr->func_index = func_index;
- ptr->free_vars = std::move(free_vars);
- data_ = std::move(ptr);
-}
+using namespace vm;
-TVM_REGISTER_GLOBAL("_vmobj.GetADTTag")
+TVM_REGISTER_GLOBAL("container._GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
*rv = static_cast<int64_t>(adt.tag());
});
-TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields")
+TVM_REGISTER_GLOBAL("container._GetADTSize")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
});
-TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
+TVM_REGISTER_GLOBAL("container._GetADTFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
int idx = args[1];
*rv = adt[idx];
});
-TVM_REGISTER_GLOBAL("_vmobj.Tuple")
+TVM_REGISTER_GLOBAL("container._Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) {
*rv = ADT::Tuple(fields);
});
-TVM_REGISTER_GLOBAL("_vmobj.ADT")
+TVM_REGISTER_GLOBAL("container._ADT")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
-} // namespace vm
+
} // namespace runtime
} // namespace tvm
-
-using namespace tvm::runtime;
-
-int TVMGetObjectTag(TVMObjectHandle handle, int* tag) {
- API_BEGIN();
- int res = static_cast<int>(static_cast<Object*>(handle)->type_index());
- *tag = res;
- API_END();
-}
namespace runtime {
namespace vm {
+VMClosure::VMClosure(size_t func_index, std::vector<ObjectRef> free_vars) {
+ auto ptr = make_object<VMClosureObj>();
+ ptr->func_index = func_index;
+ ptr->free_vars = std::move(free_vars);
+ data_ = std::move(ptr);
+}
inline Storage make_storage(size_t size, size_t alignment, DLDataType dtype_hint, TVMContext ctx) {
// We could put cache in here, from ctx to storage allocator.
}
case Opcode::InvokeClosure: {
auto object = ReadRegister(instr.closure);
- const auto* closure = object.as<ClosureObj>();
+ const auto* closure = object.as<VMClosureObj>();
std::vector<ObjectRef> args;
for (auto free_var : closure->free_vars) {
for (Index i = 0; i < instr.num_freevar; i++) {
free_vars.push_back(ReadRegister(instr.free_vars[i]));
}
- WriteRegister(instr.dst, Closure(instr.func_index, free_vars));
+ WriteRegister(instr.dst, VMClosure(instr.func_index, free_vars));
pc_++;
goto main_loop;
}
def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
- elif isinstance(o, tvm.relay.backend.vmobj.ADT):
+ elif isinstance(o, tvm.container.ADT):
result = []
for f in o:
result.extend(vmobj_to_list(f))
return result
- elif isinstance(o, tvm.relay.backend.interpreter.TupleValue):
- result = []
- for f in o.fields:
- result.append(vmobj_to_list(f))
- return result
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if o.constructor.name_hint == 'Cons':
tl = vmobj_to_list(o.fields[1])
import tvm
from tvm.contrib import graph_runtime
-from tvm import relay
+from tvm import relay, container
from tvm.relay import testing
from tvm.relay import vm
-from tvm.relay import vmobj as _obj
def benchmark_execution(mod,
ftimer = rly_vm.mod.time_evaluator("invoke", ctx, number=number,
repeat=repeat)
# Measure in millisecond.
- prof_res = np.array(ftimer("main", _obj.Tensor(data)).results) * 1000
+ prof_res = np.array(ftimer("main", data).results) * 1000
print("Mean vm inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
def vmobj_to_list(o, dtype="float32"):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
- elif isinstance(o, tvm.relay.backend.vmobj.ADT):
+ elif isinstance(o, tvm.container.ADT):
if len(o) == 0:
tensor_nil = p.get_var("tensor_nil", dtype=dtype)
if tensor_nil.tag == o.tag:
import tvm
import tvm.testing
from tvm import nd
-from tvm import relay
-from tvm.relay.backend.interpreter import TupleValue
+from tvm import relay, container
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor
def test_tuple_value():
- tv = TupleValue(relay.const(1), relay.const(2), relay.const(3))
+ tv = container.tuple_object([relay.const(1), relay.const(2),
+ relay.const(3)])
np.testing.assert_allclose(tv[0].data.asnumpy(), 1)
np.testing.assert_allclose(tv[1].data.asnumpy(), 2)
np.testing.assert_allclose(tv[2].data.asnumpy(), 3)
], prelude.cons)
ref_value = RefValue(nd.array(np.random.rand(1, 10).astype('float32')))
- tuple_value = TupleValue(*[
+ tuple_value = container.tuple_object([
nd.array(np.random.rand(1, 10).astype('float32')) for _ in range(10)
])
res_tuple = id_func(tuple_value)
for i in range(10):
- tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(),
- tuple_value.fields[i].asnumpy())
+ tvm.testing.assert_allclose(res_tuple[i].asnumpy(),
+ tuple_value[i].asnumpy())
def test_tuple_passing():
x = relay.var('x', type_annotation=relay.ty.TupleType([
out = f((10, 8))
tvm.testing.assert_allclose(out.asnumpy(), np.array(10))
# Second use a tuple value.
- value_tuple = TupleValue(nd.array(np.array(11)), nd.array(np.array(12)))
+ value_tuple = container.tuple_object([nd.array(np.array(11)),
+ nd.array(np.array(12))])
out = f(value_tuple)
tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
from tvm import relay
from tvm.relay.testing import to_python, run_as_python
from tvm.relay.prelude import Prelude
-from tvm.relay.backend.interpreter import TupleValue, RefValue, ConstructorValue
+from tvm.container import ADT
+from tvm.relay.backend.interpreter import RefValue, ConstructorValue
# helper: uses a dummy let binding to sequence a list
# of expressions: expr1; expr2; expr3, etc.
assert np.array_equal(candidate.asnumpy(), np.array(val))
-# assert that the candidate is a TupleValue with the indicate number of fields
-def assert_tuple_value(candidate, fields):
- assert isinstance(candidate, TupleValue)
- assert len(candidate.fields) == fields
+# assert that the candidate is an ADT with the indicated number of fields
+def assert_adt_len(candidate, fields):
+ assert isinstance(candidate, ADT)
+ assert len(candidate) == fields
# assert that the candidate is a ConstructorValue with the approrpaite constructor
def test_create_empty_tuple():
empty = relay.Tuple([])
tup_val = run_as_python(empty)
- assert_tuple_value(tup_val, 0)
+ assert_adt_len(tup_val, 0)
def test_create_scalar():
])
])
tup_val = run_as_python(relay_tup)
- assert_tuple_value(tup_val, 3)
+ assert_adt_len(tup_val, 3)
for i in range(2):
- assert_tensor_value(tup_val.fields[i], i + 1)
- assert_tuple_value(tup_val.fields[2], 2)
+ assert_tensor_value(tup_val[i], i + 1)
+ assert_adt_len(tup_val[2], 2)
for i in range(2):
- assert_tensor_value(tup_val.fields[2].fields[i], i + 3)
+ assert_tensor_value(tup_val[2][i], i + 3)
def test_tuple_get_item():
v = relay.Var('v')
let = relay.Let(v, relay.Tuple([]), relay.Tuple([v, v]))
tup_val = run_as_python(let)
- assert_tuple_value(tup_val, 2)
- assert_tuple_value(tup_val.fields[0], 0)
- assert_tuple_value(tup_val.fields[1], 0)
+ assert_adt_len(tup_val, 2)
+ assert_adt_len(tup_val[0], 0)
+ assert_adt_len(tup_val[1], 0)
def test_create_ref():
relay_ref = relay.RefCreate(relay.Tuple([]))
ref_val = run_as_python(relay_ref)
assert isinstance(ref_val, RefValue)
- assert_tuple_value(ref_val.value, 0)
+ assert_adt_len(ref_val.value, 0)
def test_ref_read():
v = relay.Var('v')
assign = relay.Let(v, relay.RefCreate(relay.Tuple([])), relay.RefRead(v))
read_val = run_as_python(assign)
- assert_tuple_value(read_val, 0)
+ assert_adt_len(read_val, 0)
def test_ref_write():
initial_write = relay.Let(v, relay.RefCreate(relay.Tuple([relay.const(1)])),
relay.RefWrite(v, relay.Tuple([relay.const(2)])))
write_val = run_as_python(initial_write)
- assert_tuple_value(write_val, 0)
+ assert_adt_len(write_val, 0)
# now ensure that the value, once written, can be read back
# (we read the value before and after mutation)
seq(relay.RefWrite(v, relay.Tuple([relay.const(2)])),
relay.Tuple([relay.RefRead(w), relay.RefRead(v)]))))
read_val = run_as_python(read_after_write)
- assert_tuple_value(read_val, 2)
- assert_tuple_value(read_val.fields[0], 1)
- assert_tuple_value(read_val.fields[1], 1)
- assert_tensor_value(read_val.fields[0].fields[0], 1)
- assert_tensor_value(read_val.fields[1].fields[0], 2)
+ assert_adt_len(read_val, 2)
+ assert_adt_len(read_val[0], 1)
+ assert_adt_len(read_val[1], 1)
+ assert_tensor_value(read_val[0][0], 1)
+ assert_tensor_value(read_val[1][0], 2)
def test_if():
call2 = relay.Let(f, ident, f(relay.const(2)))
call_val1 = run_as_python(call1)
- assert_tuple_value(call_val1, 0)
+ assert_adt_len(call_val1, 0)
call_val2 = run_as_python(call2)
assert_tensor_value(call_val2, 2)
assert_tensor_value(call_val1, 1)
call_val2 = run_as_python(call2, mod)
- assert_tuple_value(call_val2, 2)
- assert_tensor_value(call_val2.fields[0], 2)
- assert_tensor_value(call_val2.fields[1], 2)
+ assert_adt_len(call_val2, 2)
+ assert_tensor_value(call_val2[0], 2)
+ assert_tensor_value(call_val2[1], 2)
def test_constructor():
box_val_tup = run_as_python(init_box_tup, mod)
assert_constructor_value(box_val_tup, box_ctor, 1)
- assert_tuple_value(box_val_tup.fields[0], 0)
+ assert_adt_len(box_val_tup.fields[0], 0)
def test_match_wildcard():
assert_tensor_value(val.fields[1].fields[0], 2)
assert_constructor_value(val.fields[1].fields[1], p.cons, 2)
assert_tensor_value(val.fields[1].fields[1].fields[0], 3)
- assert_constructor_value(val.fields[1].fields[1].fields[1], p.nil, 0)
+ assert_constructor_value(val.fields[1].fields[1].fields[1], p.nil, 0)
def test_global_recursion():
call2 = copy_def(p.cons(relay.Tuple([]), p.nil()))
val2 = run_as_python(call2, mod)
assert_constructor_value(val2, p.cons, 2)
- assert_tuple_value(val2.fields[0], 0)
+ assert_adt_len(val2.fields[0], 0)
assert_constructor_value(val2.fields[1], p.nil, 0)
])
tup_val = run_as_python(expr, mod)
- assert_tuple_value(tup_val, 3)
- assert_tensor_value(tup_val.fields[0], 2)
- assert_tensor_value(tup_val.fields[1], 3)
- assert_tensor_value(tup_val.fields[2], 4)
+ assert_adt_len(tup_val, 3)
+ assert_tensor_value(tup_val[0], 2)
+ assert_tensor_value(tup_val[1], 3)
+ assert_tensor_value(tup_val[2], 4)
def test_ref_execution_order():
])))
tup_val = run_as_python(expr)
- assert_tuple_value(tup_val, 5)
- assert_tensor_value(tup_val.fields[0], 1)
- assert_tensor_value(tup_val.fields[1], 2)
- assert_tensor_value(tup_val.fields[2], 3)
- assert_tensor_value(tup_val.fields[3], 4)
- assert_tensor_value(tup_val.fields[4], 5)
+ assert_adt_len(tup_val, 5)
+ assert_tensor_value(tup_val[0], 1)
+ assert_tensor_value(tup_val[1], 2)
+ assert_tensor_value(tup_val[2], 3)
+ assert_tensor_value(tup_val[3], 4)
+ assert_tensor_value(tup_val[4], 5)
def test_op_add():
args.append(relay.const(data))
call = relay.stack(relay.Tuple(args), axis)
call_val = run_as_python(call)
+ type(call_val)
assert_tensor_value(call_val, ref_res)
verify_stack([(2,), (2,), (2,)], -1)
ref_res = np.split(x, indices_or_sections, axis=axis)
call = relay.split(relay.const(x), indices_or_sections, axis=axis)
call_val = run_as_python(call)
- assert_tuple_value(call_val, len(ref_res))
+ assert_adt_len(call_val, len(ref_res))
for i in range(len(ref_res)):
- assert_tensor_value(call_val.fields[i], ref_res[i])
+ assert_tensor_value(call_val[i], ref_res[i])
verify_split((2, 3), 2)
verify_split((5, 3), [3])
def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
- elif isinstance(o, tvm.relay.backend.vm.ADT):
+ elif isinstance(o, tvm.container.ADT):
result = []
for f in o:
result.extend(vmobj_to_list(f))
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import numpy as np
-import tvm
-from tvm.relay import vm
-
-def test_adt():
- arr = tvm.nd.array([1,2,3])
- y = vm.ADT(0, [arr, arr])
-
- assert len(y) == 2
- assert isinstance(y, vm.ADT)
- y[0:1][-1] == arr
- assert y.tag == 0
- assert isinstance(arr, tvm.nd.NDArray)
-
-
-if __name__ == "__main__":
- test_adt()
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import tvm
+from tvm import nd, relay
+from tvm import container as _container
+
+
+def test_adt_constructor():
+ arr = nd.array([1, 2, 3])
+ fields = [arr, arr]
+ y = _container.ADT(0, [arr, arr])
+
+ assert len(y) == 2
+ assert isinstance(y, _container.ADT)
+ y[0:1][-1] == arr
+ assert y.tag == 0
+ assert isinstance(arr, nd.NDArray)
+
+
+def test_tuple_object():
+ x = relay.var(
+ 'x',
+ type_annotation=relay.ty.TupleType([
+ relay.ty.TensorType((), 'int32'),
+ relay.ty.TensorType((), 'int32')
+ ]))
+
+ fn = relay.Function([x], relay.expr.TupleGetItem(x, 0))
+ mod = relay.Module.from_expr(fn)
+
+ exe = relay.create_executor(
+ kind="vm", mod=mod, ctx=nd.cpu(), target="llvm")
+ f = exe.evaluate()
+ value_tuple = _container.tuple_object(
+ [nd.array(np.array(11)),
+ nd.array(np.array(12))])
+ # pass an ADT object to evaluate
+ out = f(value_tuple)
+ tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
+
+
+if __name__ == "__main__":
+ test_adt_constructor()
+ test_tuple_object()