Implement explicit IR representation of memory alloction (#3560)
authorJared Roesch <roeschinc@gmail.com>
Fri, 1 Nov 2019 21:28:23 +0000 (16:28 -0500)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 1 Nov 2019 21:28:23 +0000 (14:28 -0700)
42 files changed:
CMakeLists.txt
include/tvm/relay/attrs/memory.h [new file with mode: 0644]
include/tvm/relay/base.h
include/tvm/relay/module.h
include/tvm/runtime/object.h
include/tvm/runtime/vm.h
python/tvm/relay/__init__.py
python/tvm/relay/backend/compile_engine.py
python/tvm/relay/debug.py
python/tvm/relay/expr.py
python/tvm/relay/memory_alloc.py [new file with mode: 0644]
python/tvm/relay/op/__init__.py
python/tvm/relay/op/memory/__init__.py [new file with mode: 0644]
python/tvm/relay/op/memory/_make.py [new file with mode: 0644]
python/tvm/relay/op/memory/memory.py [new file with mode: 0644]
python/tvm/relay/std/core.rly [new file with mode: 0644]
python/tvm/relay/ty.py
src/relay/backend/compile_engine.cc
src/relay/backend/interpreter.cc
src/relay/backend/utils.h
src/relay/backend/vm/compiler.cc
src/relay/ir/expr.cc
src/relay/ir/module.cc
src/relay/op/annotation/annotation.cc
src/relay/op/device_copy.cc
src/relay/op/memory/memory.cc [new file with mode: 0644]
src/relay/op/op_common.h
src/relay/op/tensor/unary.cc
src/relay/op/type_relations.cc
src/relay/op/type_relations.h
src/relay/pass/device_annotation.cc
src/relay/pass/fold_constant.cc
src/relay/pass/fuse_ops.cc
src/relay/pass/pass_manager.cc
src/relay/pass/type_infer.cc
src/relay/pass/type_solver.cc
src/runtime/vm/executable.cc
src/runtime/vm/memory_manager.cc
src/runtime/vm/memory_manager.h
src/runtime/vm/vm.cc
tests/python/relay/test_memory_alloc.py [new file with mode: 0644]
tests/python/relay/test_vm_serialization.py

index a5f5f14..2bea818 100644 (file)
@@ -272,6 +272,7 @@ add_library(tvm_runtime SHARED ${RUNTIME_SRCS})
 if(USE_RELAY_DEBUG)
   message(STATUS "Building Relay in debug mode...")
   set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_RELAY_DEBUG")
+  set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "DMLC_LOG_DEBUG")
 else()
   set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "NDEBUG")
 endif(USE_RELAY_DEBUG)
diff --git a/include/tvm/relay/attrs/memory.h b/include/tvm/relay/attrs/memory.h
new file mode 100644 (file)
index 0000000..2e279a5
--- /dev/null
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/attrs/memory.h
+ * \brief Attributes for memory operators.
+ */
+#ifndef TVM_RELAY_ATTRS_MEMORY_H_
+#define TVM_RELAY_ATTRS_MEMORY_H_
+
+#include <tvm/attrs.h>
+#include <tvm/relay/expr.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Options for allocating tensors.
+ */
+struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
+  Constant const_shape;
+  Array<IndexExpr> assert_shape;
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(AllocTensorAttrs, "relay.attrs.AllocTensorAttrs") {
+    TVM_ATTR_FIELD(dtype)
+      .describe(
+         "The dtype of the tensor to allocate.")
+      .set_default(Float(32, 1));
+    TVM_ATTR_FIELD(const_shape)
+      .describe(
+         "The shape of constant used to aid in type inference.");
+    TVM_ATTR_FIELD(assert_shape)
+      .describe(
+         "The shape to cast the return type of the allocation to, "\
+         "used to specify the shape obtained via further analysis.");
+  }
+};
+
+/*!
+ * \brief Options for the shape function operator.
+ */
+struct ShapeFuncAttrs : public tvm::AttrsNode<ShapeFuncAttrs> {
+  Array<Integer> is_input;
+
+  TVM_DECLARE_ATTRS(ShapeFuncAttrs, "relay.attrs.ShapeFuncAttrs") {
+    TVM_ATTR_FIELD(is_input)
+      .describe(
+         "A bool indicating whether the shape function should"\
+         "expect shape or input in each position.");
+  }
+};
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_ATTRS_MEMORY_H_
index 5a2326e..42a01f0 100644 (file)
@@ -47,6 +47,12 @@ namespace relay {
   (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
 }
 
+#define RELAY_DEBUG_INTERP(...) \
+{ auto fdebug = runtime::Registry::Get("relay.debug_interp"); \
+  CHECK(fdebug) << "Could not find Relay Python debugger function."; \
+  (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
+}
+
 /*!
  * \brief We always used NodeRef for referencing nodes.
  *
index 160ae5d..1ef7ca8 100644 (file)
@@ -76,7 +76,8 @@ class ModuleNode : public RelayNode {
   }
 
   TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
-                             tvm::Map<GlobalTypeVar, TypeData> global_type_defs);
+                             tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
+                             std::unordered_set<std::string> imports = {});
 
   /*!
    * \brief Add a function to the global environment.
@@ -235,6 +236,11 @@ class ModuleNode : public RelayNode {
    */
   TVM_DLL void ImportFromStd(const std::string& path);
 
+  /*!
+   * \brief The set of imported files.
+   */
+  TVM_DLL std::unordered_set<std::string> Imports() const;
+
   /*! \brief Construct a module from a standalone expression.
    *
    * Allows one to optionally pass a global function map and
index cc4a295..0aa7815 100644 (file)
@@ -283,6 +283,8 @@ class Object {
    * \note The deleter will be called when ref_counter_ becomes zero.
    */
   inline void DecRef();
+
+ private:
   /*!
    * \return The usage count of the cell.
    * \note We use stl style naming to be consistent with known API in shared_ptr.
@@ -675,6 +677,16 @@ struct ObjectEqual {
   operator bool() const { return data_ != nullptr; }                    \
   using ContainerType = ObjectName;
 
+#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \
+  TypeName() {}                                                             \
+  explicit TypeName(                                                        \
+      ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)                  \
+      : ParentType(n) {}                                                    \
+  ObjectName* operator->() {                                    \
+    return static_cast<ObjectName*>(data_.get());                     \
+  }                                                                         \
+  operator bool() const { return data_ != nullptr; }                        \
+  using ContainerType = ObjectName;
 
 // Implementations details below
 // Object reference counting.
index ee973cb..a196afd 100644 (file)
@@ -138,6 +138,7 @@ enum class Opcode {
   GetTag = 13U,
   LoadConsti = 14U,
   Fatal = 15U,
+  AllocStorage = 16U,
 };
 
 /*! \brief A single virtual machine instruction.
@@ -158,6 +159,8 @@ struct Instruction {
 
   union {
     struct /* AllocTensor Operands */ {
+      /*! \brief The storage to allocate from. */
+      RegName storage;
       /*! \brief The number of dimensions. */
       uint32_t ndim;
       /*! \brief The shape of tensor. */
@@ -166,6 +169,8 @@ struct Instruction {
       DLDataType dtype;
     } alloc_tensor;
     struct /* AllocTensorReg Operands */ {
+      /*! \brief The storage to allocate from. */
+      RegName storage;
       /*! \brief The register to read the shape out of. */
       RegName shape_register;
       /*! \brief The datatype of tensor to be allocated. */
@@ -253,6 +258,14 @@ struct Instruction {
       /*! \brief The free variables as an array. */
       RegName* free_vars;
     };
+    struct /* AllocStorage Operands */ {
+      /*! \brief The size of the allocation. */
+      RegName allocation_size;
+      /*! \brief The alignment of the allocation. */
+      RegName alignment;
+      /*! \brief The hint of the dtype. */
+      DLDataType dtype_hint;
+    } alloc_storage;
   };
 
   /*! \brief Construct a return instruction.
@@ -274,19 +287,23 @@ struct Instruction {
   static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
                                   const std::vector<RegName>& args);
   /*! \brief Construct an allocate tensor instruction with constant shape.
+   *  \param storage The storage to allocate out of.
    *  \param shape The shape of the tensor.
    *  \param dtype The dtype of the tensor.
    *  \param dst The destination register.
    *  \return The allocate tensor instruction.
    */
-  static Instruction AllocTensor(std::vector<int64_t> shape, DLDataType dtype, RegName dst);
+  static Instruction AllocTensor(RegName storage,
+                                 const std::vector<int64_t>& shape, DLDataType dtype, RegName dst);
   /*! \brief Construct an allocate tensor instruction with register.
+   *  \param storage The storage to allocate out of.
    *  \param shape_register The register containing the shape.
    *  \param dtype The dtype of the tensor.
    *  \param dst The destination register.
    *  \return The allocate tensor instruction.
    */
-  static Instruction AllocTensorReg(RegName shape_register, DLDataType dtype, RegName dst);
+  static Instruction AllocTensorReg(RegName storage,
+                                    RegName shape_register, DLDataType dtype, RegName dst);
   /*! \brief Construct an allocate datatype instruction.
    *  \param tag The datatype tag.
    *  \param num_fields The number of fields for the datatype.
@@ -295,7 +312,7 @@ struct Instruction {
    *  \return The allocate instruction tensor.
    */
   static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields,
-                                   RegName dst);
+                              RegName dst);
   /*! \brief Construct an allocate closure instruction.
    *  \param func_index The index of the function table.
    *  \param num_freevar The number of free variables.
@@ -364,6 +381,16 @@ struct Instruction {
    */
   static Instruction Move(RegName src, RegName dst);
 
+   /*! \brief Allocate a storage block.
+   *  \param size The size of the allocation.
+   *  \param alignment The allocation's alignment.
+   *  \param dtype_hint The data type hint for the allocator.
+   *  \param dst The destination to place the storage.
+   *  \return The alloc storage instruction.
+   */
+  static Instruction AllocStorage(RegName size, RegName alignment,
+                                  DLDataType dtype_hint, RegName dst);
+
   Instruction();
   Instruction(const Instruction& instr);
   Instruction& operator=(const Instruction& instr);
index bd3f5bd..c7cbcf0 100644 (file)
@@ -59,6 +59,8 @@ from . import quantize
 from . import qnn
 
 from .scope_builder import ScopeBuilder
+# Load Memory pass
+from . import memory_alloc
 
 # Required to traverse large programs
 setrecursionlimit(10000)
index 152da61..545f96a 100644 (file)
@@ -99,6 +99,10 @@ class CompileEngine(NodeBase):
             msg += "--------------------------\n"
             raise RuntimeError(msg)
 
+    def lower_shape_func(self, source_func, target=None):
+        key = _get_cache_key(source_func, target)
+        return _backend._CompileEngineLowerShapeFunc(self, key)
+
     def jit(self, source_func, target=None):
         """JIT a source_func to a tvm.Function.
 
index 8887a7e..de18352 100644 (file)
@@ -25,9 +25,14 @@ def _debugger_init(expr, stack):
     import pdb
     pdb.set_trace()
 
-# pylint: disable=unused-argument
 @register_func("relay.debug")
 def _debug(*args):
+    import pdb
+    pdb.set_trace()
+
+# pylint: disable=unused-argument
+@register_func("relay.debug_interp")
+def _debug_interp(*args):
     _, _, _, ist = args
     print("Relay Debugger")
     print("  You can manipulate the expression under evaluation with the name `expr`.")
index 8d59e99..e288c3d 100644 (file)
@@ -317,6 +317,9 @@ class Function(Expr):
 
         return _expr.FunctionSetParams(self, params)
 
+    def set_attribute(self, name, ref):
+        return _expr.FunctionSetAttr(self, name, ref)
+
 
 @register_relay_node
 class Call(Expr):
diff --git a/python/tvm/relay/memory_alloc.py b/python/tvm/relay/memory_alloc.py
new file mode 100644 (file)
index 0000000..7116af3
--- /dev/null
@@ -0,0 +1,279 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return,invalid-name,len-as-condition
+"""
+A pass for manifesting explicit memory allocations.
+"""
+import numpy as np
+from .expr_functor import ExprMutator
+from .scope_builder import ScopeBuilder
+from . import transform
+from . import op, ty, expr
+from .. import TVMType, register_func
+from .backend import compile_engine
+
+
+def is_primitive(call):
+    return hasattr(call.op, 'attrs') and int(call.op.attrs.Primitive) == 1
+
+# TODO(@jroesch): port to c++ and unify with existing code
+class LinearizeRetType:
+    """A linear view of a Relay type, handles a linear order
+       for nested tuples, and tensor types.
+    """
+
+    def __init__(self, typ):
+        """Initialize the linearizer."""
+        self.typ = typ
+
+    def unpack(self):
+        """Return the linear representation of the type."""
+        def _unpack(typ, out):
+            # TODO(@jroesch): replace with new flattening pass
+            if isinstance(typ, ty.TensorType):
+                out.append(typ)
+            elif isinstance(typ, ty.TupleType):
+                for field_ty in typ.fields:
+                    _unpack(field_ty, out)
+            else:
+                raise Exception(f"unsupported Relay type: {typ}")
+
+        output = []
+        _unpack(self.typ, output)
+        return output
+
+    def pack(self, seq):
+        """Repack a linear type as a nested type."""
+        def _pack(value, typ, out):
+            if isinstance(typ, ty.TensorType):
+                out.append(value)
+            elif isinstance(typ, ty.TupleType):
+                tuple_out = []
+                for i, field_ty in enumerate(typ.fields):
+                    _pack(value[i], field_ty, tuple_out)
+                out.append(expr.Tuple(tuple_out))
+            else:
+                raise Exception(f"unsupported Relay type: {typ}")
+
+        if len(seq) == 1:
+            return seq[0]
+        else:
+            out = []
+            _pack(seq, self.typ, out)
+            assert len(out) == 1, "must return fully packed type"
+            return out[0]
+
+
+class ManifestAllocPass(ExprMutator):
+    """A pass for explictly manifesting all memory allocations in Relay."""
+
+    def __init__(self, target_host):
+        self.invoke_tvm = op.memory.invoke_tvm_op
+        self.alloc_storage = op.memory.alloc_storage
+        self.alloc_tensor = op.memory.alloc_tensor
+        self.shape_func = op.memory.shape_func
+        self.scopes = [ScopeBuilder()]
+        self.target_host = target_host
+        self.compute_dtype = "int64"
+        super().__init__()
+
+    def current_scope(self):
+        return self.scopes[-1]
+
+    def shape_of(self, e):
+        return op.shape_of(e, self.compute_dtype)
+
+    def visit_tuple(self, tup):
+        scope = self.current_scope()
+        new_fields = []
+        for field in tup.fields:
+            field = self.visit(field)
+            if isinstance(field, expr.Constant):
+                field = scope.let('const', field)
+            new_fields.append(field)
+        return expr.Tuple(new_fields)
+
+    def compute_alignment(self, dtype):
+        dtype = TVMType(dtype)
+        align = (dtype.bits // 8) * dtype.lanes
+        # MAGIC CONSTANT FROM device_api.h
+        if align < 64:
+            align = 64
+
+        return expr.const(align, dtype="int64")
+
+    def compute_storage_in_relay(self, shape, dtype):
+        dtype = TVMType(dtype)
+        els = op.prod(shape)
+        num = expr.const(dtype.bits * dtype.lanes, self.compute_dtype)
+        num = num + expr.const(7, self.compute_dtype)
+        div = expr.const(8, self.compute_dtype)
+        return els * (num / div)
+
+    def compute_storage(self, tensor_type):
+        dtype = TVMType(tensor_type.dtype)
+        shape = [int(sh) for sh in tensor_type.shape]
+        size = 1
+        for sh in shape:
+            size *= sh
+        size *= (dtype.bits * dtype.lanes + 7) // 8
+        return expr.const(size, dtype=self.compute_dtype)
+
+    def make_static_allocation(self, scope, tensor_type, i):
+        """Allocate a tensor with a statically known shape."""
+        shape = [int(sh) for sh in tensor_type.shape]
+        if len(shape) == 0:
+            shape = expr.const(np.array([]).astype(
+                self.compute_dtype), dtype=self.compute_dtype)
+        else:
+            shape = expr.const(np.array(shape), dtype=self.compute_dtype)
+        size = self.compute_storage(tensor_type)
+        alignment = self.compute_alignment(tensor_type.dtype)
+        dtype = tensor_type.dtype
+        sto = scope.let(f"storage_{i}", self.alloc_storage(
+            size, alignment, dtype))
+        # TODO(@jroesch): There is a bug with typing based on the constant shape.
+        tensor = self.alloc_tensor(sto, shape, dtype, tensor_type.shape)
+        return scope.let(f"tensor_{i}", tensor)
+
+    def visit_let(self, let):
+        scope = ScopeBuilder()
+
+        self.scopes.append(scope)
+        while isinstance(let, expr.Let):
+            new_val = self.visit(let.value)
+            scope.let(let.var, new_val)
+            let = let.body
+
+        new_body = self.visit(let)
+        scope.ret(new_body)
+        self.scopes.pop()
+
+        return scope.get()
+
+    def visit_call(self, call):
+        if is_primitive(call):
+            # Because we are in ANF we do not need to visit the arguments.
+            scope = self.current_scope()
+            new_args = [self.visit(arg) for arg in call.args]
+            ins = expr.Tuple(new_args)
+            ret_type = call.checked_type
+
+            is_dynamic = ret_type.is_dynamic()
+            # TODO(@jroesch): restore this code, more complex then it seems
+            # for arg in call.args:
+            #     is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
+
+            if is_dynamic:
+                assert isinstance(ret_type, ty.TensorType)
+                shape_func_ins = []
+                engine = compile_engine.get()
+                cfunc = engine.lower_shape_func(call.op, self.target_host)
+                input_states = cfunc.shape_func_param_states
+
+                is_inputs = []
+                for i, (arg, state) in enumerate(zip(new_args, input_states)):
+                    state = int(state)
+                    # Pass Shapes
+                    if state == 2:
+                        sh_of = self.visit(self.shape_of(arg))
+                        shape_func_ins.append(
+                            scope.let(f"in_shape_{i}", sh_of))
+                        is_inputs.append(0)
+                    # Pass Inputs
+                    elif state == 1:
+                        new_arg = self.visit(arg)
+                        shape_func_ins.append(
+                            scope.let(f"in_shape_{i}", new_arg))
+                        is_inputs.append(1)
+                    # TODO(@jroesch): handle 3rd case
+                    else:
+                        raise Exception("unsupported shape function input state")
+
+                out_shapes = []
+                for i, out in enumerate(cfunc.outputs):
+                    tt = ty.TensorType(out.shape, out.dtype)
+                    alloc = self.make_static_allocation(scope, tt, i)
+                    alloc = scope.let(f"shape_func_out_{i}", alloc)
+                    out_shapes.append(alloc)
+
+                shape_call = self.shape_func(
+                    call.op,
+                    expr.Tuple(shape_func_ins),
+                    expr.Tuple(out_shapes), is_inputs)
+
+                scope.let("shape_func", shape_call)
+
+                out_types = []
+                out_types.append(call.checked_type)
+
+                storages = []
+                for out_shape, out_type in zip(out_shapes, out_types):
+                    size = self.compute_storage_in_relay(
+                        out_shape, out_type.dtype)
+                    alignment = self.compute_alignment(out_type.dtype)
+                    sto = scope.let(f"storage_{i}", self.alloc_storage(
+                        size, alignment, out_type.dtype))
+                    storages.append(sto)
+
+                outs = []
+                sh_ty_storage = zip(out_shapes, out_types, storages)
+                for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage):
+                    alloc = self.alloc_tensor(
+                        storage,
+                        out_shape,
+                        out_type.dtype,
+                        out_type.shape)
+                    alloc = scope.let(f"out_{i}", alloc)
+                    outs.append(alloc)
+
+                invoke = self.invoke_tvm(call.op, ins, expr.Tuple(outs))
+                scope.let("", invoke)
+                return outs[0]
+            else:
+                view = LinearizeRetType(ret_type)
+                out_tys = view.unpack()
+
+                outs = []
+                for i, out_ty in enumerate(out_tys):
+                    out = self.make_static_allocation(scope, out_ty, i)
+                    outs.append(out)
+
+                output = expr.Tuple(outs)
+                invoke = self.invoke_tvm(call.op, ins, output)
+                scope.let("", invoke)
+                return view.pack(output)
+        else:
+            return super().visit_call(call)
+
+
+@transform.function_pass(opt_level=0)
+class ManifestAlloc:
+    """The explicit pass wrapper around ManifestAlloc."""
+    def __init__(self, target_host):
+        self.target_host = target_host
+
+    def transform_function(self, func, mod, _):
+        # TODO(@jroesch): Is there a way to do one shot initilization?
+        # can we have def pass_init?
+        mod.import_from_std("core.rly")
+        ea = ManifestAllocPass(self.target_host)
+        func = ea.visit(func)
+        return func
+
+
+register_func("relay.transform.ManifestAlloc", ManifestAlloc)
index b8ef4df..a089cab 100644 (file)
@@ -28,6 +28,7 @@ from .transform import *
 from .algorithm import *
 from . import nn
 from . import annotation
+from . import memory
 from . import image
 from . import vision
 from . import contrib
diff --git a/python/tvm/relay/op/memory/__init__.py b/python/tvm/relay/op/memory/__init__.py
new file mode 100644 (file)
index 0000000..f3f7355
--- /dev/null
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""Operators for manipulating low level memory."""
+from __future__ import absolute_import as _abs
+from .memory import *
diff --git a/python/tvm/relay/op/memory/_make.py b/python/tvm/relay/op/memory/_make.py
new file mode 100644 (file)
index 0000000..cdf2dcc
--- /dev/null
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Constructor APIs"""
+from ...._ffi.function import _init_api
+
+_init_api("relay.op.memory._make", __name__)
diff --git a/python/tvm/relay/op/memory/memory.py b/python/tvm/relay/op/memory/memory.py
new file mode 100644 (file)
index 0000000..892ba88
--- /dev/null
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Operators for manipulating low-level memory."""
+from __future__ import absolute_import as _abs
+from . import _make
+
+def invoke_tvm_op(func, inputs, outputs):
+    """Call a primitive function with the TVM operator calling convention.
+
+    Parameters
+    ----------
+    inputs : tvm.relay.Expr
+        A tuple of the inputs to pass to the TVM function.
+
+    outputs : tvm.relay.Expr
+        A tuple of the outputs to pass to the TVM function.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The invoke_tvm_op call node.
+    """
+    return _make.invoke_tvm_op(func, inputs, outputs)
+
+def alloc_tensor(storage, shape, dtype='float32', assert_shape=None):
+    """Allocate a tensor with the provided shape, and dtype.
+
+    Parameters
+    ----------
+    storage : tvm.relay.Expr
+        The storage to allocate from.
+
+    shape : tvm.relay.Expr
+        The shape of the tensor to allocate.
+
+    dtype: str
+        The dtype of the tensor.
+
+    assert_shape: Control the static shape when computed by dynamic shape expression.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The alloc_tensor expression.
+    """
+    return _make.alloc_tensor(storage, shape, dtype, assert_shape)
+
+def alloc_storage(size, alignment, dtype_hint='float32'):
+    """Allocate a piece of tensor storage.
+
+    Parameters
+    ----------
+    size : tvm.relay.Expr
+        The size of the allocation.
+    alignment : tvm.relay.Expr
+        The alignment of the allocation.
+    dtype : str
+        The dtype_hint of the allocation.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The alloc_storage expression.
+    """
+    return _make.alloc_storage(size, alignment, dtype_hint)
+
+def shape_func(func, inputs, outputs, dependent=False):
+    """Invoke the shape function of the passed function.
+
+    Parameters
+    ----------
+    func : tvm.relay.Expr
+        The primitive function from which to compute the shape function.
+    inputs : tvm.relay.Tuple
+        The tupled inputs.
+    outputs : tvm.relay.Tuple
+        The tupled outputs.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The shape function expression.
+    """
+    return _make.shape_func(func, inputs, outputs, dependent)
diff --git a/python/tvm/relay/std/core.rly b/python/tvm/relay/std/core.rly
new file mode 100644 (file)
index 0000000..6a3facc
--- /dev/null
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+v0.0.4
+
+extern type Storage
index 99692fd..356fe0b 100644 (file)
@@ -52,6 +52,9 @@ class Type(RelayNode):
         """
         return TypeCall(self, args)
 
+    def is_dynamic(self):
+        return _make.IsDynamic(self)
+
 @register_relay_node
 class TensorType(Type):
     """A concrete TensorType in Relay.
@@ -317,7 +320,6 @@ class RefType(Type):
     def __init__(self, value):
         self.__init_handle_by_constructor__(_make.RefType, value)
 
-
 def scalar_type(dtype):
     """Creates a scalar type.
 
index 993c4bf..083fa5d 100644 (file)
@@ -72,6 +72,10 @@ bool IsDynamic(const Type& ty) {
   return v.is_dyn;
 }
 
+// TODO(@jroesch): MOVE ME
+TVM_REGISTER_API("relay._make.IsDynamic")
+.set_body_typed(IsDynamic);
+
 Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
   // for now, we always use int32 shape when possible
   // even if the result of shape inference becomes int64.
@@ -775,6 +779,12 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
       return self->Lower(key);
     });
 
+TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
+.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
+    [](CompileEngine self, CCacheKey key) {
+      return self->LowerShapeFunc(key);
+    });
+
 TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
 .set_body_typed<PackedFunc(CompileEngine, CCacheKey)>(
     [](CompileEngine self, CCacheKey key) {
index 8c6dace..962728e 100644 (file)
@@ -458,7 +458,7 @@ class Interpreter :
       if (dattrs->debug_func.defined()) {
         dattrs->debug_func(interp_state);
       } else {
-        RELAY_DEBUG(interp_state);
+        RELAY_DEBUG_INTERP(interp_state);
       }
 
       return args[0];
@@ -479,7 +479,8 @@ class Interpreter :
     if (const auto* tuple_type = func->body->checked_type().as<TupleTypeNode>()) {
       arg_len += tuple_type->fields.size();
     } else {
-      CHECK(func->body->checked_type().as<TensorTypeNode>());
+      CHECK(func->body->checked_type().as<TensorTypeNode>())
+        << func->body->checked_type();
       arg_len += 1;
     }
     std::vector<TVMValue> values(arg_len);
index f59b117..1436a13 100644 (file)
@@ -48,6 +48,19 @@ namespace backend {
 inline const PackedFunc* GetPackedFunc(const std::string& func_name) {
   return tvm::runtime::Registry::Get(func_name);
 }
+
+/*!
+ * \brief Get a typed packed function.
+ *
+ * \param func_name
+ * \return const PackedFunc*
+ */
+template <typename R, typename... Args>
+inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) {
+  auto *pf = GetPackedFunc(func_name);
+  CHECK(pf != nullptr) << "can not find packed function";
+  return runtime::TypedPackedFunc<R(Args...)>(*pf);
+}
 /*!
  * \brief Convert type to string
  *
index fab01bd..3cfea5c 100644 (file)
@@ -31,6 +31,7 @@
 #include <tvm/logging.h>
 #include <tvm/relay/transform.h>
 #include <tvm/runtime/vm.h>
+#include <tvm/relay/attrs/memory.h>
 #include <topi/tags.h>
 #include <algorithm>
 #include <iostream>
@@ -44,6 +45,7 @@
 #include "../../../runtime/vm/naive_allocator.h"
 #include "../../backend/compile_engine.h"
 #include "../../pass/pass_util.h"
+#include "../../op/op_common.h"
 #include "compiler.h"
 
 namespace tvm {
@@ -54,6 +56,12 @@ namespace transform {
 Pass LambdaLift();
 Pass InlinePrimitives();
 
+Pass ManifestAlloc(Target target_host) {
+  auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
+  CHECK(f != nullptr) << "could not load memory allocation pass";
+  return (*f)(target_host);
+}
+
 }  // namespace transform
 
 namespace vm {
@@ -194,6 +202,39 @@ TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause>
   return else_branch;
 }
 
+std::vector<int64_t> ToAllocTensorShape64(NDArray shape) {
+  std::vector<int64_t> raw_shape;
+  DLTensor tensor = shape.ToDLPack()->dl_tensor;
+  CHECK_EQ(tensor.ndim, 1u);
+  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
+
+  // TODO(@jroesch): we really need to standaridize the bit width of
+  // all of the shape manipulating code.
+  CHECK_EQ(tensor.dtype.bits, 64) << "found " << tensor.dtype.bits;
+  int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
+  for (auto i = 0; i < tensor.shape[0]; i++) {
+    raw_shape.push_back(int_ptr[i]);
+  }
+  return raw_shape;
+}
+
+
+std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
+  std::vector<int64_t> raw_shape;
+  DLTensor tensor = shape.ToDLPack()->dl_tensor;
+  CHECK_EQ(tensor.ndim, 1u);
+  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
+
+  // TODO(@jroesch): we really need to standaridize the bit width of
+  // all of the shape manipulating code.
+  CHECK_LE(tensor.dtype.bits, 32) << "found " << tensor.dtype.bits;
+  int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
+  for (auto i = 0; i < tensor.shape[0]; i++) {
+    raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
+  }
+  return raw_shape;
+}
+
 class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
  public:
   VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
@@ -248,13 +289,12 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       case Opcode::LoadConsti:
       case Opcode::Invoke:
       case Opcode::AllocClosure:
+      case Opcode::AllocStorage:
       case Opcode::Move:
       case Opcode::InvokeClosure:
         last_register_ = instr.dst;
         break;
       case Opcode::InvokePacked:
-        last_register_ = instr.packed_args[instr.arity - 1];
-        break;
       case Opcode::If:
       case Opcode::Ret:
       case Opcode::Goto:
@@ -302,7 +342,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
   }
 
   void VisitExpr_(const LetNode* let_node) {
-    DLOG(INFO) << AsText(let_node->value);
+    DLOG(INFO) << PrettyPrint(let_node->value);
     this->VisitExpr(let_node->value);
     var_register_map_.insert({let_node->var, this->last_register_});
     this->VisitExpr(let_node->body);
@@ -369,100 +409,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
     this->last_register_ = true_register;
   }
 
-  Index EmitGetShape(const TensorTypeNode* ttype, Index reg) {
-    bool const_shape = true;
-    std::vector<int64_t> shape;
-    for (auto dim : ttype->shape) {
-      if (auto kdim = dim.as<IntImm>()) {
-        shape.push_back(kdim->value);
-      } else {
-        const_shape = false;
-      }
-    }
-    if (const_shape) {
-      int64_t ndim = shape.size();
-      DLContext cpu_ctx;
-      cpu_ctx.device_type = kDLCPU;
-      cpu_ctx.device_id = 0;
-      NDArray shape_tensor;
-      if (ndim == 0) {
-        shape_tensor = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx);
-      } else {
-        shape_tensor = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
-        int64_t* dims = reinterpret_cast<int64_t*>(shape_tensor->data);
-        for (size_t i = 0; i < shape.size(); ++i) {
-          dims[i] = shape[i];
-        }
-      }
-      size_t konst_idx = context_->constants.size();
-      context_->constants.push_back(shape_tensor);
-      Emit(Instruction::LoadConst(konst_idx, NewRegister()));
-      return last_register_;
-    }
-    // For dynamic shape, we need insert shape_of op to get its shape at runtime
-    auto attrs = make_node<ShapeOfAttrs>();
-    attrs->dtype = Int(64);
-    static const Op& op = Op::Get("shape_of");
-    auto input = VarNode::make("input", GetRef<Type>(ttype));
-    auto expr = CallNode::make(op, {input}, Attrs(attrs), {});
-    auto func = FunctionNode::make({input}, expr, IncompleteTypeNode::make(Kind::kType), {});
-    auto mod = ModuleNode::make({}, {});
-    auto main_gv = GlobalVarNode::make("main");
-    mod->Add(main_gv, func);
-    func = mod->Lookup(main_gv);
-
-    // shape_of op has to be run on the host target
-    // TODO(@icemelon9): handle heterogeneous target, such as cuda
-    auto key = CCacheKeyNode::make(func, target_host_);
-    auto cfunc = engine_->Lower(key);
-    auto op_index = -1;
-    if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
-      op_index = context_->cached_funcs.size();
-      context_->cached_funcs.push_back(cfunc);
-      context_->seen_funcs[cfunc->funcs[0]] = op_index;
-    } else {
-      op_index = context_->seen_funcs[cfunc->funcs[0]];
-    }
-    std::vector<Index> arg_regs{reg};
-    int64_t ndim = ttype->shape.size();
-    if (ndim == 0) {
-      Emit(Instruction::AllocTensor({}, Int(64), NewRegister()));
-    } else {
-      Emit(Instruction::AllocTensor({ndim}, Int(64), NewRegister()));
-    }
-    Index shape_reg = last_register_;
-    arg_regs.push_back(shape_reg);
-    Emit(Instruction::InvokePacked(op_index, 2, 1, arg_regs));
-    return shape_reg;
-  }
-
-  std::vector<Index> EmitShapeFunc(const Type& ret_type, const Function& func,
-                                   const std::vector<Index>& unpacked_arg_regs) {
-    // Find the mapping from params to registers
-    int idx = 0;
-    std::vector<std::vector<Index>> param_regs;
-    std::vector<std::vector<const TensorTypeNode*>> param_types;
-    for (auto param : func->params) {
-      auto ty = param->checked_type();
-      std::vector<Index> regs;
-      std::vector<const TensorTypeNode*> types;
-      if (auto ttype = ty.as<TensorTypeNode>()) {
-        regs.push_back(unpacked_arg_regs[idx++]);
-        types.push_back(ttype);
-      } else if (const auto tuple_ty = ret_type.as<TupleTypeNode>()) {
-        for (size_t j = 0; j < tuple_ty->fields.size(); ++j, ++idx) {
-          regs.push_back(unpacked_arg_regs[idx]);
-          auto ttype = tuple_ty->fields[j].as<TensorTypeNode>();
-          CHECK(ttype);
-          types.push_back(ttype);
-        }
-      } else {
-        LOG(FATAL) << "unsupported parameter type " << ty;
-      }
-      param_regs.push_back(regs);
-      param_types.push_back(types);
-    }
-
+  void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
     // Lower shape function
     auto key = CCacheKeyNode::make(func, target_host_);
     auto cfunc = engine_->LowerShapeFunc(key);
@@ -476,125 +423,60 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
     }
 
     // Prepare input and output registers
-    std::vector<Index> shape_func_args;
-    std::vector<Index> shape_regs;
-    for (size_t i = 0; i < func->params.size(); ++i) {
-      int state = cfunc->shape_func_param_states[i]->value;
-      if (state & kNeedInputData) {
-        for (auto reg : param_regs[i]) {
-          // TODO(@icemelon9): Need to copy data here for heterogeneous exec
-          shape_func_args.push_back(reg);
-        }
-      }
-      if (state & kNeedInputShape) {
-        for (size_t j = 0; j < param_regs[i].size(); ++j) {
-          shape_func_args.push_back(EmitGetShape(param_types[i][j], param_regs[i][j]));
-        }
-      }
-    }
-    for (auto t : cfunc->outputs) {
-      int64_t ndim = t->shape[0].as<IntImm>()->value;
-      Emit(Instruction::AllocTensor({ndim}, t->dtype, NewRegister()));
-      shape_func_args.push_back(last_register_);
-      shape_regs.push_back(last_register_);
+    std::vector<Index> argument_registers;
+    for (auto input : inputs) {
+      auto reg = var_register_map_.find(Downcast<Var>(input));
+      CHECK(reg != var_register_map_.end())
+        << "internal error: all variables should be in the register mapping";
+      argument_registers.push_back(reg->second);
     }
 
-    int arity = shape_func_args.size();
-    int ret_count = shape_regs.size();
-    Emit(Instruction::InvokePacked(op_index, arity, ret_count, shape_func_args));
-
-    // Alloc return tensors given the shape regs
-    std::vector<DataType> ret_dtypes;
-    if (const auto* tuple_type = ret_type.as<TupleTypeNode>()) {
-      for (auto field : tuple_type->fields) {
-        const TensorTypeNode* tty = field.as<TensorTypeNode>();
-        CHECK(tty);
-        ret_dtypes.push_back(tty->dtype);
-      }
-    } else {
-      auto tty = ret_type.as<TensorTypeNode>();
-      CHECK(tty);
-      ret_dtypes.push_back(tty->dtype);
-    }
-    std::vector<Index> ret_regs;
-    for (size_t i = 0; i < shape_regs.size(); ++i) {
-      Emit(Instruction::AllocTensorReg(shape_regs[i], ret_dtypes[i], NewRegister()));
-      ret_regs.push_back(last_register_);
+    for (auto output : outputs) {
+      auto reg = var_register_map_.find(Downcast<Var>(output));
+      CHECK(reg != var_register_map_.end())
+        << "internal error: all variables should be in the register mapping";
+      argument_registers.push_back(reg->second);
     }
-    return ret_regs;
-  }
-
-  std::vector<Index> AllocReturnType(const Type& ret_type, const Function& func,
-                                     const std::vector<Index>& unpacked_arg_regs) {
-    auto op = func->body.as<CallNode>()->op;
-    // 1. If either func param types or ret type is dynamic, we need to insert
-    // shape func to perform type checking at runtime.
-    // 2. We skip the shape_of function since currently Relay doesn't support
-    // dynamic rank tensor.
-    if (op != Op::Get("shape_of") && IsDynamic(func->checked_type())) {
-      return EmitShapeFunc(ret_type, func, unpacked_arg_regs);
-    }
-    std::vector<Index> ret_regs;
-    auto alloc_tensor = [&](const TensorTypeNode* ttype) {
-      const TensorType& tensor_type = GetRef<TensorType>(ttype);
-      std::vector<int64_t> shape;
-      for (auto dim : tensor_type->shape) {
-        shape.push_back(Downcast<tvm::Integer>(dim)->value);
-      }
-      Emit(Instruction::AllocTensor(shape, Type2TVMType(tensor_type->dtype), NewRegister()));
-      ret_regs.push_back(last_register_);
-    };
-    if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
-      alloc_tensor(ttype);
-    } else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
-      for (auto field : ttype->fields) {
-        alloc_tensor(field.as<TensorTypeNode>());
-      }
-    } else {
-      LOG(FATAL) << "Unsupported return value type";
-    }
-    return ret_regs;
-  }
-
-  void EmitInvokePrimitive(const Function& func,
-                           const std::vector<Index>& arg_registers,
-                           const Type& ret_type) {
-    std::vector<Index> unpacked_arg_regs;
-    std::vector<Instruction> allocs;
-
-    // Arity calculation must flatten tuples.
-    size_t arity = 0;
-    CHECK_EQ(func->params.size(), arg_registers.size());
-    for (size_t i = 0; i < func->params.size(); i++) {
-      auto ty = func->params[i]->checked_type();
-      if (ty.as<TensorTypeNode>()) {
-        unpacked_arg_regs.push_back(arg_registers[i]);
-        arity += 1;
-      } else if (auto tuple_ty = ty.as<TupleTypeNode>()) {
-        for (size_t f = 0; f < tuple_ty->fields.size(); f++) {
-          const auto& field = tuple_ty->fields[f];
-          CHECK(field.as<TensorTypeNode>())
-            << "only supports non-nested tuples currently "
-            << "found " << field;
-          auto dst =  NewRegister();
-          Emit(Instruction::GetField(arg_registers[i], f, dst));
-          unpacked_arg_regs.push_back(dst);
-        }
-        arity += tuple_ty->fields.size();
-      } else {
-        LOG(FATAL) << "unsupported parameter type " << ty;
-      }
+
+    Emit(Instruction::InvokePacked(op_index,
+      argument_registers.size(),
+      outputs.size(),
+      argument_registers));
+  }
+
+  void EmitInvokeTVMOp(const Function& func,
+                       const Expr& inputs,
+                       const Expr& outputs) {
+    std::vector<Index> argument_registers;
+
+    CHECK(func->IsPrimitive())
+      << "internal error: invoke_tvm_op requires the first argument to be a relay::Function";
+
+    auto input_tuple = inputs.as<TupleNode>();
+    CHECK(input_tuple)
+      << "internal error: invoke_tvm_op inputs must be a tuple,"
+      << "please file a bug in the memory manifestation pass";
+
+    auto output_tuple = outputs.as<TupleNode>();
+    CHECK(output_tuple)
+      << "internal error: invoke_tvm_op outputs must be a tuple,"
+      << "please file a bug in the memory manifestation pass";
+
+    for (auto input : input_tuple->fields) {
+      auto reg = var_register_map_.find(Downcast<Var>(input));
+      CHECK(reg != var_register_map_.end())
+        << "internal error: all variables should be in the register mapping";
+      argument_registers.push_back(reg->second);
     }
 
-    auto ret_regs = AllocReturnType(ret_type, func, unpacked_arg_regs);
-    size_t return_count = ret_regs.size();
-    arity += return_count;
-    for (auto reg : ret_regs) {
-      unpacked_arg_regs.push_back(reg);
+    for (auto output : output_tuple->fields) {
+      auto reg = var_register_map_.find(Downcast<Var>(output));
+      CHECK(reg != var_register_map_.end())
+        << "internal error: all variables should be in the register mapping";
+      argument_registers.push_back(reg->second);
     }
 
     // Next generate the invoke instruction.
-    CHECK(func->IsPrimitive());
     Target target;
     if (targets_.size() == 1) {
       // homogeneous execution.
@@ -605,8 +487,10 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       // heterogeneous execution.
       LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
     }
+
     auto key = CCacheKeyNode::make(func, target);
     auto cfunc = engine_->Lower(key);
+
     // TODO(jroesch): support lowered funcs for multiple targets
     CHECK_EQ(cfunc->funcs.size(), 1);
     auto op_index = -1;
@@ -618,19 +502,99 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       op_index = context_->seen_funcs[cfunc->funcs[0]];
     }
 
-    Emit(Instruction::InvokePacked(op_index, arity, return_count, unpacked_arg_regs));
-
-    if (return_count > 1) {
-      // return value is a tuple, we need to create a tuple
-      std::vector<Index> fields_registers;
-      for (size_t i = arity - return_count; i < arity; ++i) {
-        fields_registers.push_back(unpacked_arg_regs[i]);
-      }
-      Emit(Instruction::AllocADT(0, return_count, fields_registers, NewRegister()));
-    }
+    Emit(Instruction::InvokePacked(op_index,
+      argument_registers.size(),
+      output_tuple->fields.size(),
+      argument_registers));
   }
 
   void VisitExpr_(const CallNode* call_node) {
+    Expr op = call_node->op;
+
+    // First we handle the case in which we are using an opaque
+    // operator used to define a sub-dialect, such as memory
+    // allocation operations.
+    if (op.as<OpNode>()) {
+      OpMatch<void> matcher;
+      matcher.Match("memory.invoke_tvm_op",
+        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+          CHECK_EQ(args.size(), 3);
+          EmitInvokeTVMOp(Downcast<Function>(args[0]), args[1], args[2]);
+      }).Match("memory.alloc_tensor",
+        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+          CHECK_EQ(args.size(), 2);
+
+          // Get the attributes.
+          auto alloc_attrs = attrs.as<AllocTensorAttrs>();
+          CHECK(alloc_attrs != nullptr)
+              << "must be the alloc tensor attrs";
+          auto dtype = alloc_attrs->dtype;
+
+          // The storage will be passed dynamically.
+          this->VisitExpr(args[0]);
+          auto storage_register = last_register_;
+
+          // If the shape is constant then we will emit a static tensor allocation instruction.
+          auto const_shape = args[1].as<ConstantNode>();
+
+          if (const_shape) {
+            NDArray shape = const_shape->data;
+            std::vector<int64_t> raw_shape;
+            DLTensor tensor = shape.ToDLPack()->dl_tensor;
+            // TODO(@jroesch): we need to get an RFC done to standarize this
+            if (tensor.dtype.bits == 64) {
+              raw_shape = ToAllocTensorShape64(shape);
+            } else if (tensor.dtype.bits == 32) {
+              raw_shape = ToAllocTensorShape32(shape);
+            } else {
+              LOG(FATAL) << "unsupported bitwidth: " << tensor.dtype.bits;
+            }
+
+            // Add context field.
+            Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister()));
+          } else {
+            this->VisitExpr(args[1]);
+            auto shape_register = last_register_;
+            Emit(Instruction::AllocTensorReg(
+              storage_register,
+              shape_register,
+              dtype,
+              NewRegister()));
+          }
+      }).Match("memory.alloc_storage",
+        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+          CHECK_EQ(args.size(), 2);
+          // Compute the size of the allocation.
+          this->VisitExpr(args[0]);
+          auto size_register = last_register_;
+
+          this->VisitExpr(args[1]);
+          auto alignment_register = last_register_;
+
+          // Get the dtype hint from the attributes.
+          auto alloc_attrs = attrs.as<AllocTensorAttrs>();
+          CHECK(alloc_attrs != nullptr)
+              << "must be the alloc tensor attrs";
+          auto dtype = alloc_attrs->dtype;
+
+          Emit(Instruction::AllocStorage(size_register, alignment_register, dtype, NewRegister()));
+      }).Match("memory.shape_func",
+        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+          CHECK_EQ(args.size(), 3);
+          auto shape_func = Downcast<Function>(args[0]);
+          auto inputs = Downcast<Tuple>(args[1]);
+          auto outputs = Downcast<Tuple>(args[2]);
+          EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
+      }).Match("memory.kill",
+        [](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+          LOG(FATAL) << "memory.kill is not yet supported";
+      });
+      matcher(GetRef<Call>(call_node));
+      return;
+    }
+
+    // In the case its not one of these specialized operators we will generate code
+    // for one of the "standard" cases.
     std::vector<Index> args_registers;
 
     for (auto arg : call_node->args) {
@@ -638,18 +602,16 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       args_registers.push_back(last_register_);
     }
 
-    Expr op = call_node->op;
-
-    if (auto func_node = op.as<FunctionNode>()) {
-      CHECK(func_node->IsPrimitive());
-      EmitInvokePrimitive(GetRef<Function>(func_node), args_registers, call_node->checked_type());
-    } else if (auto global_node = op.as<GlobalVarNode>()) {
+    if (auto global_node = op.as<GlobalVarNode>()) {
+      // In the case we are invoking a global we need to find its
+      // global ID, and then check whether it is closure invocation
+      // or whether it is a standard global, and emit the correct
+      // calling convention.
       auto global = GetRef<GlobalVar>(global_node);
       auto it = context_->global_map.find(global);
       CHECK(it != context_->global_map.end());
       DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
                       << " with func_index=" << it->second;
-
       auto func = context_->module->Lookup(global);
       if (IsClosure(func)) {
         auto arity = func->params.size();
@@ -658,14 +620,21 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
         Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
       }
     } else if (auto constructor_node = op.as<ConstructorNode>()) {
+      // In the constructor case, we simply need to find its tag
+      // and emit a call to allocate the data structure.
       auto constructor = GetRef<Constructor>(constructor_node);
       Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
                                       NewRegister()));
     } else if (auto var_node = op.as<VarNode>()) {
+      // If we are calling a variable, it must be the case that it is a closure so we
+      // emit invoke closure here.
       VisitExpr(GetRef<Var>(var_node));
       Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
     } else {
-      LOG(FATAL) << "unsupported case in vm compiler: " << op;
+      // Finally if there are any other cases this is a bug.
+      LOG(FATAL) << "internal error: unreachable code,"
+                 << "should be transformed away by previous passes"
+                 << PrettyPrint(GetRef<Expr>(call_node));
     }
   }
 
@@ -836,7 +805,6 @@ relay::Function VMCompiler::BindParamsByName(
   return ret;
 }
 
-
 void VMCompiler::Compile(Module mod,
                          const TargetsMap& targets,
                          const tvm::Target& target_host) {
@@ -852,8 +820,7 @@ void VMCompiler::Compile(Module mod,
   targets_ = targets;
   target_host_ = target_host;
 
-  // Run some optimizations first, this code should
-  // be moved to pass manager.
+  // Run the optimizations necessary to target the VM.
   context_.module = OptimizeModule(mod, targets_);
 
   // Populate the global map.
@@ -885,7 +852,7 @@ void VMCompiler::Compile(Module mod,
 
   // populate constants
   for (auto data : context_.constants) {
-    exec_->constants.push_back(runtime::vm::Tensor(data));
+    exec_->constants.push_back(vm::Tensor(data));
   }
 
   LibraryCodegen();
@@ -942,6 +909,15 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets)
   pass_seqs.push_back(transform::LambdaLift());
   pass_seqs.push_back(transform::InlinePrimitives());
 
+  // Manifest the allocations.
+  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
+  // Compute away possibly introduced constant computation.
+  pass_seqs.push_back(transform::FoldConstant());
+  // Fuse the shape functions.
+  pass_seqs.push_back(transform::FuseOps());
+  // Manifest the allocations needed for the shape functions.
+  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
+
   transform::Sequential seq(pass_seqs);
   transform::PassContext pass_ctx = PassContext::Current();
   // TODO(wweic): Support heterogenous execution
index c36b4c8..672cdab 100644 (file)
@@ -355,5 +355,11 @@ TVM_REGISTER_API("relay._expr.TempExprRealize")
   return temp->Realize();
 });
 
+TVM_REGISTER_API("relay._expr.FunctionSetAttr")
+.set_body_typed<Function(Function, std::string, NodeRef)>(
+  [](Function func, std::string name, NodeRef ref) {
+    return FunctionSetAttr(func, name, ref);
+});
+
 }  // namespace relay
 }  // namespace tvm
index cd5b1e6..8a90f14 100644 (file)
@@ -35,13 +35,16 @@ using tvm::IRPrinter;
 using namespace runtime;
 
 Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
-                        tvm::Map<GlobalTypeVar, TypeData> global_type_defs) {
+                        tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
+                        std::unordered_set<std::string> imports
+                        ) {
   auto n = make_node<ModuleNode>();
   n->functions = std::move(global_funcs);
   n->type_definitions = std::move(global_type_defs);
   n->global_type_var_map_ = {};
   n->global_var_map_ = {};
   n->constructor_tag_map_ = {};
+  n->import_set_ = imports;
 
   for (const auto& kv : n->functions) {
     // set global var map
@@ -283,9 +286,9 @@ Module ModuleNode::FromExpr(
 }
 
 void ModuleNode::Import(const std::string& path) {
-  LOG(INFO) << "Importing: " << path;
   if (this->import_set_.count(path) == 0) {
     this->import_set_.insert(path);
+    DLOG(INFO) << "Importing: " << path;
     std::fstream src_file(path, std::fstream::in);
     std::string file_contents {
       std::istreambuf_iterator<char>(src_file),
@@ -302,6 +305,10 @@ void ModuleNode::ImportFromStd(const std::string& path) {
   return this->Import(std_path + "/" + path);
 }
 
+std::unordered_set<std::string> ModuleNode::Imports() const {
+  return this->import_set_;
+}
+
 Module FromText(const std::string& source, const std::string& source_name) {
   auto* f = tvm::runtime::Registry::Get("relay.fromtext");
   CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
@@ -312,7 +319,10 @@ Module FromText(const std::string& source, const std::string& source_name) {
 TVM_REGISTER_NODE_TYPE(ModuleNode);
 
 TVM_REGISTER_API("relay._make.Module")
-.set_body_typed(ModuleNode::make);
+.set_body_typed<Module(tvm::Map<GlobalVar, Function>, tvm::Map<GlobalTypeVar, TypeData>)>(
+[](tvm::Map<GlobalVar, Function> funcs, tvm::Map<GlobalTypeVar, TypeData> types) {
+  return ModuleNode::make(funcs, types, {});
+});
 
 TVM_REGISTER_API("relay._module.Module_Add")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
index 5a8ad33..f79208b 100644 (file)
  * \brief Registration of annotation operators.
  */
 
+#include <tvm/expr.h>
 #include <tvm/relay/attrs/annotation.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <topi/elemwise.h>
 
-#include "../type_relations.h"
 #include "../../pass/alter_op_layout.h"
+#include "../type_relations.h"
 
 namespace tvm {
 namespace relay {
index 589fb29..229e6b6 100644 (file)
@@ -27,6 +27,7 @@
  * used as "barrier" to avoid fusing operators belonging to differen devices.
  */
 
+#include <tvm/expr.h>
 #include <tvm/relay/attrs/device_copy.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc
new file mode 100644 (file)
index 0000000..8f6dea0
--- /dev/null
@@ -0,0 +1,348 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/memory/memory.cc
+ * \brief Operators for manifest shape-aware memory allocation in Relay.
+ */
+
+#include <topi/elemwise.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/attrs/memory.h>
+
+#include "../op_common.h"
+#include "../../pass/alter_op_layout.h"
+#include "../type_relations.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(AllocTensorAttrs);
+TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);
+
+// The passing value in attrs and args doesn't seem super great.
+// We should consider a better solution, i.e the type relation
+// being able to see the arguments as well?
+TVM_REGISTER_API("relay.op.memory._make.alloc_storage")
+    .set_body_typed<Expr(Expr, Expr, DataType)>([](Expr size, Expr alignment, DataType dtype) {
+      auto attrs = make_node<AllocTensorAttrs>();
+      attrs->dtype = dtype;
+      static const Op& op = Op::Get("memory.alloc_storage");
+      return CallNode::make(op, {size, alignment}, Attrs(attrs), {});
+    });
+
+bool AllocStorageRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3u);
+  auto size_type = types[0];
+  auto tensor_type = size_type.as<TensorTypeNode>();
+  CHECK(tensor_type != nullptr);
+  CHECK_EQ(tensor_type->dtype, Int(64));
+  CHECK_EQ(tensor_type->shape.size(), 0);
+  auto align_type = types[1];
+  auto align_ttype = align_type.as<TensorTypeNode>();
+  CHECK(align_ttype != nullptr);
+  CHECK_EQ(align_ttype->dtype, Int(64));
+  CHECK_EQ(align_ttype->shape.size(), 0);
+  auto mod = reporter->GetModule();
+  CHECK(mod.defined());
+  auto storage_name = mod->GetGlobalTypeVar("Storage");
+  auto storage = TypeCallNode::make(storage_name, {});
+  reporter->Assign(types[2], storage);
+  return true;
+}
+
+RELAY_REGISTER_OP("memory.alloc_storage")
+    .describe(R"code(Explicitly allocate storage to be used by tensors.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("size", "Tensor", "The size of the storage to allocate.")
+    .add_argument("alignment", "Tensor", "The alignment of the storage.")
+    .add_type_rel("AllocStorage", AllocStorageRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attr<FTVMCompute>("FTVMCompute",
+                           [](const Attrs& attrs, const Array<Tensor>& inputs,
+                              const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                             return {topi::identity(inputs[0])};
+                           });
+
+TVM_REGISTER_API("relay.op.memory._make.alloc_tensor")
+    .set_body_typed<Expr(Expr, Expr, DataType, Array<IndexExpr> assert_shape)>(
+        [](Expr storage, tvm::relay::Expr shape, DataType dtype, Array<IndexExpr> assert_shape) {
+          auto attrs = make_node<AllocTensorAttrs>();
+          attrs->dtype = dtype;
+          if (assert_shape.defined()) {
+            attrs->assert_shape = assert_shape;
+          } else {
+            attrs->const_shape = Downcast<Constant>(shape);
+          }
+          static const Op& op = Op::Get("memory.alloc_tensor");
+          return CallNode::make(op, {storage, shape}, Attrs(attrs), {});
+        });
+
+std::vector<int64_t> FromConstShape(Constant konst) {
+  runtime::NDArray shape = konst->data;
+  std::vector<int64_t> raw_shape;
+  DLTensor tensor = shape.ToDLPack()->dl_tensor;
+  CHECK_EQ(tensor.ndim, 1u);
+  CHECK_EQ(tensor.dtype.code, 0U)
+    << "found " << tensor.dtype.code;
+
+  CHECK(tensor.dtype.bits == 64 || tensor.dtype.bits == 32)
+    << "found " << static_cast<int>(tensor.dtype.bits);
+
+  if (tensor.dtype.bits == 32) {
+    const int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
+    for (auto i = 0; i < tensor.shape[0]; i++) {
+      raw_shape.push_back(int_ptr[i]);
+    }
+  } else if (tensor.dtype.bits == 64) {
+    const int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
+    for (auto i = 0; i < tensor.shape[0]; i++) {
+      raw_shape.push_back(int_ptr[i]);
+    }
+  }
+
+  return raw_shape;
+}
+
+bool AllocTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3u);
+  auto alloc_attrs = attrs.as<AllocTensorAttrs>();
+  CHECK(alloc_attrs != nullptr) << "must be alloc_tensor attributes";
+  // First argument should be storage.
+  auto mod = reporter->GetModule();
+  CHECK(mod.defined());
+  auto storage_name = mod->GetGlobalTypeVar("Storage");
+  auto storage = relay::TypeCallNode::make(storage_name, {});
+  reporter->Assign(types[0], storage);
+  // Second argument should be shape tensor.
+  auto tt = types[1].as<TensorTypeNode>();
+  CHECK(tt != nullptr) << "must be tensor type";
+  auto rank = tt->shape[0].as<tvm::IntImm>();
+  CHECK(rank != nullptr);
+  auto dims = rank->value;
+
+  // Constant node case.
+  Type alloc_type;
+  if (alloc_attrs->const_shape.defined()) {
+    auto con = alloc_attrs->const_shape;
+    auto sh = FromConstShape(con);
+    Array<IndexExpr> out_shape;
+    for (auto i = 0u; i < dims; i++) {
+      out_shape.push_back(tvm::Integer(sh[i]));
+    }
+    alloc_type = TensorTypeNode::make(out_shape, alloc_attrs->dtype);
+  } else {
+    CHECK(alloc_attrs->assert_shape.defined())
+        << "the assert_shape must be set when const_shape is not";
+    alloc_type = TensorTypeNode::make(alloc_attrs->assert_shape, alloc_attrs->dtype);
+    return true;
+  }
+
+  reporter->Assign(types[2], alloc_type);
+  return true;
+}
+
+RELAY_REGISTER_OP("memory.alloc_tensor")
+    .describe(R"code(Explicitly allocate storage to be used by tensors.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("storage", "Storage", "The storage to allocate from.")
+    .add_argument("shape", "Tensor", "The shape of the tensor to allocate.")
+    .add_type_rel("AllocTensor", AllocTensorRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attr<FTVMCompute>("FTVMCompute",
+                           [](const Attrs& attrs, const Array<Tensor>& inputs,
+                              const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                             return {topi::identity(inputs[0])};
+                           });
+
+bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4u);
+  auto func_type = types[0].as<FuncTypeNode>();
+  CHECK(func_type != nullptr) << "input must be operator with known type";
+  auto input_type = types[1].as<TupleTypeNode>();
+  auto output_type = types[2].as<TupleTypeNode>();
+  CHECK(input_type != nullptr)
+      << "internal invariant violated: invoke_tvm_op inputs must be a tuple";
+  CHECK(output_type != nullptr)
+      << "internal invariant violated: invoke_tvm_op outputs must be a tuple";
+  Type ex_output;
+  if (func_type->ret_type.as<TensorTypeNode>()) {
+    ex_output = TupleTypeNode::make({func_type->ret_type});
+  } else {
+    CHECK(func_type->ret_type.as<TupleTypeNode>()) << "should be tuple type";
+    ex_output = func_type->ret_type;
+  }
+  auto ex_input = TupleTypeNode::make(func_type->arg_types);
+  reporter->Assign(ex_input, GetRef<Type>(input_type));
+  reporter->Assign(ex_output, GetRef<Type>(output_type));
+  reporter->Assign(types[3], TupleTypeNode::make({}));
+  return true;
+}
+
+TVM_REGISTER_API("relay.op.memory._make.invoke_tvm_op")
+    .set_body_typed<Expr(Expr, Expr, Expr)>(
+        [](Expr func, Expr inputs, Expr outputs) {
+          return CallNode::make(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
+        });
+
+RELAY_REGISTER_OP("memory.invoke_tvm_op")
+    .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .add_argument("op", "Function", "The operation to call")
+    .add_argument("ins", "Tuple", "The input tensors.")
+    .add_argument("outs", "Tuple", "The output tensors.")
+    .add_type_rel("InvokeTVMOP", InvokeTVMOPRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attr<FTVMCompute>("FTVMCompute",
+                           [](const Attrs& attrs, const Array<Tensor>& inputs,
+                              const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                             return {topi::identity(inputs[0])};
+                           });
+
+bool KillRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+             const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2u);
+  // TODO(@jroesch): should only support tensors.
+  reporter->Assign(types[1], TupleTypeNode::make({}));
+  return true;
+}
+
+RELAY_REGISTER_OP("memory.kill")
+    .describe(R"code(Mark a tensor for release to the allocator.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .add_argument("to_free", "Tensor", "The tensor to free.")
+    .add_type_rel("Kill", KillRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attr<FTVMCompute>("FTVMCompute",
+                           [](const Attrs& attrs, const Array<Tensor>& inputs,
+                              const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                             return {topi::identity(inputs[0])};
+                           });
+
+TVM_REGISTER_API("relay.op.memory._make.shape_func")
+    .set_body_typed<Expr(Expr, Expr, Expr, Array<tvm::Integer>)>(
+      [](Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) {
+      static const Op& op = Op::Get("memory.shape_func");
+      auto attrs = make_node<ShapeFuncAttrs>();
+      attrs->is_input = is_input;
+      return CallNode::make(op, {func, inputs, outputs}, Attrs(attrs), {});
+    });
+
+static void FlattenTypeAux(const Type& type, std::vector<TensorType>* out) {
+  if (auto tt = type.as<TensorTypeNode>()) {
+    out->push_back(GetRef<TensorType>(tt));
+  } else if (auto tuple_ty = type.as<TupleTypeNode>()) {
+    for (auto field : tuple_ty->fields) {
+      FlattenTypeAux(field, out);
+    }
+  } else {
+    LOG(FATAL) << "unsupported " << type;
+  }
+}
+
+std::vector<TensorType> FlattenType(const Type& type) {
+  std::vector<TensorType> out;
+  FlattenTypeAux(type, &out);
+  return out;
+}
+
+Expr PackByType(const Type& t, const Array<Expr>& exprs) {
+  LOG(FATAL) << "NYI";
+  return Expr();
+}
+
+bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4u);
+  auto shape_func_attrs = attrs.as<ShapeFuncAttrs>();
+  CHECK(shape_func_attrs != nullptr) << "Internal compiler error";
+
+  auto func_type = types[0].as<FuncTypeNode>();
+  CHECK(func_type != nullptr);
+
+  auto tuple = TupleTypeNode::make(func_type->arg_types);
+  auto in_types = FlattenType(tuple);
+  auto out_types = FlattenType(func_type->ret_type);
+
+  Array<Type> shape_func_ins, shape_func_outs;
+  for (size_t i = 0; i < in_types.size(); i++) {
+    auto in_type = in_types[i];
+
+    if (shape_func_attrs->is_input[i]) {
+      shape_func_ins.push_back(in_type);
+    } else {
+      auto shape = RankShape(in_type->shape);
+      shape_func_ins.push_back(TensorTypeNode::make(shape, Int(64)));
+    }
+  }
+
+  for (auto out_type : out_types) {
+    auto rank_shape = RankShape(out_type->shape);
+    shape_func_outs.push_back(TensorTypeNode::make(rank_shape, Int(64)));
+  }
+
+  auto input_type = TupleTypeNode::make(shape_func_ins);
+  auto output_type = TupleTypeNode::make(shape_func_outs);
+
+  reporter->Assign(types[1], input_type);
+  reporter->Assign(types[2], output_type);
+  reporter->Assign(types[3], TupleTypeNode::make({}));
+
+  return true;
+}
+
+RELAY_REGISTER_OP("memory.shape_func")
+    .describe(R"code(Get the shape of a tensor.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .add_argument("tensor", "Tensor", "The tensor to retrieve the shape for.")
+    .add_type_rel("ShapeFuncRel", ShapeFuncRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attr<FTVMCompute>("FTVMCompute",
+                           [](const Attrs& attrs, const Array<Tensor>& inputs,
+                              const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                             return {topi::identity(inputs[0])};
+                           });
+
+}  // namespace relay
+}  // namespace tvm
index 9dc249a..281e51e 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -30,6 +30,8 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <vector>
+#include <string>
+#include <unordered_map>
 #include "type_relations.h"
 #include "../pass/alter_op_layout.h"
 
@@ -105,6 +107,50 @@ namespace relay {
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout",         \
                                    BinaryBroadcastLayout)
 
+
+/*! \brief A helper class for matching and rewriting operators. */
+template<typename R>
+class OpMatch {
+ public:
+  using MatchFunc =
+      std::function<R(const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_args)>;
+
+  /*! \brief Match an operator with the given name.
+   *  \param op_name The name of the operator to match.
+   *  \param func The function to execute when it matches.
+   *  \return A self-reference for builder style API.
+   */
+  inline OpMatch& Match(const std::string& op_name, MatchFunc func) {
+    auto op = Op::Get(op_name);
+    match_map_.insert({op, func});
+    return *this;
+  }
+
+  /*! \brief Rewrite a call operation based on the operator and the registered
+   *  match functions.
+   * \param call The call to rewrite.
+   * \return The result of rewriting.
+   */
+  inline R operator()(const Call& call) {
+    auto it = match_map_.find(Downcast<Op>(call->op));
+    if (it != match_map_.end()) {
+      return it->second(call->args, call->attrs, call->type_args);
+    } else {
+      if (default_ != nullptr) {
+        return default_(call->args, call->attrs, call->type_args);
+      } else {
+        LOG(FATAL) << "unexpected operation " << call->op;
+      }
+    }
+  }
+
+ private:
+  /*! \brief The match function map. */
+  std::unordered_map<Op, MatchFunc, NodeHash, NodeEqual> match_map_;
+  /*! \brief An optional default case. */
+  MatchFunc default_;
+};
+
 }  // namespace relay
 }  // namespace tvm
 
index 1979d06..e78a166 100644 (file)
@@ -286,8 +286,8 @@ bool ShapeOfRel(const Array<Type>& types,
   CHECK(tt != nullptr);
   const auto* param = attrs.as<ShapeOfAttrs>();
   CHECK(param != nullptr);
-  auto vector_out = tvm::Integer(tt->shape.size());
-  reporter->Assign(types[1], TensorTypeNode::make({ vector_out }, param->dtype));
+  auto rank_shape = RankShape(tt->shape);
+  reporter->Assign(types[1], TensorTypeNode::make(rank_shape, param->dtype));
   return true;
 }
 
index f71b85d..d280b83 100644 (file)
@@ -144,5 +144,13 @@ bool BroadcastCompRel(const Array<Type>& types,
   return false;
 }
 
+Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
+  if (shape.size() == 0) {
+    return {};
+  } else {
+    return { tvm::Integer(shape.size()) };
+  }
+}
+
 }  // namespace relay
 }  // namespace tvm
index 7e3a418..2244b9c 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -80,6 +80,8 @@ bool BroadcastCompRel(const Array<Type>& types,
                       const Attrs& attrs,
                       const TypeReporter& reporter);
 
+Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);
+
 }  // namespace relay
 }  // namespace tvm
 
index 21992ab..6ad04b0 100644 (file)
@@ -28,6 +28,7 @@
  *  3. Collect the device allocation of each expression.
  */
 
+#include <tvm/expr.h>
 #include <tvm/relay/attrs/device_copy.h>
 #include <tvm/relay/attrs/annotation.h>
 #include <tvm/relay/expr.h>
index 5825c1e..7724950 100644 (file)
@@ -27,6 +27,7 @@
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/transform.h>
+#include "./pattern_util.h"
 
 namespace tvm {
 namespace relay {
@@ -73,13 +74,12 @@ bool ConstantCheck(const Expr& e) {
 TVM_REGISTER_API("relay._analysis.check_constant")
 .set_body_typed(ConstantCheck);
 
-
 // TODO(tvm-team) consider combine dead-code with constant folder.
 // or make a more powerful partial evaluator.
 class ConstantFolder : public ExprMutator {
  public:
-  explicit ConstantFolder(FInterpreter executor)
-      : executor_(executor) {
+  explicit ConstantFolder(FInterpreter executor, Module module)
+      : executor_(executor), module_(module) {
   }
 
   Expr VisitExpr_(const LetNode* op) final {
@@ -123,6 +123,15 @@ class ConstantFolder : public ExprMutator {
     if (call->op.same_as(Op::Get("shape_of"))) {
       return EvaluateShapeOf(res, origin_args, call->attrs);
     }
+
+    // We should think about potentially constant evaluation over these ops too.
+    if (call->op.same_as(Op::Get("memory.invoke_tvm_op")) ||
+        call->op.same_as(Op::Get("memory.shape_func")) ||
+        call->op.same_as(Op::Get("memory.alloc_tensor")) ||
+        call->op.same_as(Op::Get("memory.alloc_storage"))) {
+      return GetRef<Call>(call);
+    }
+
     bool all_const_args = true;
     for (Expr arg : call->args) {
       if (!checker_.Check(arg)) {
@@ -151,10 +160,16 @@ class ConstantFolder : public ExprMutator {
   FInterpreter executor_;
   // Internal constant checker
   ConstantChecker checker_;
+  // Module
+  Module module_;
 
   // Convert value to expression.
   Expr ValueToExpr(Value value) {
     if (const auto* val = value.as<TensorValueNode>()) {
+      for (auto dim : val->data.Shape()) {
+        CHECK_GT(dim, 0)
+          << "invalid dimension after constant eval";
+      }
       return ConstantNode::make(val->data);
     } else if (const auto* val = value.as<TupleValueNode>()) {
       Array<Expr> fields;
@@ -171,18 +186,33 @@ class ConstantFolder : public ExprMutator {
   Expr ConstEvaluate(Expr expr) {
     std::vector<transform::Pass> passes = {transform::FuseOps(0),
                                            transform::InferType()};
-    auto mod = ModuleNode::FromExpr(expr);
+    Function func;
+    if (expr.as<FunctionNode>()) {
+      func = Downcast<Function>(expr);
+    } else {
+      // TODO(@jroesch): fix this
+      func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
+    }
+    auto mod = ModuleNode::make(
+      {},
+      module_->type_definitions,
+      module_->Imports());
+    auto global = GlobalVarNode::make("main");
+    mod->Add(global, func);
     auto seq = transform::Sequential(passes);
     mod = seq(mod);
     auto entry_func = mod->Lookup("main");
     expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
     return ValueToExpr(executor_(expr));
   }
-  // Evaluate shape_of op
+
+  // Evaluate a call to the shape_of operator for tensors with constant
+  // shapes.
   Expr EvaluateShapeOf(Expr expr, Array<Expr> args, Attrs attrs) {
     Expr input = args[0];
     const auto* param = attrs.as<ShapeOfAttrs>();
     CHECK(param != nullptr);
+
     tvm::Array<IndexExpr> ishape;
     if (const ConstantNode* op = input.as<ConstantNode>()) {
       ishape = op->tensor_type()->shape;
@@ -191,33 +221,48 @@ class ConstantFolder : public ExprMutator {
     } else {
       return expr;
     }
+
     // Get the constant shape
     DLContext ctx;
     ctx.device_type = kDLCPU;
     ctx.device_id = 0;
-    auto val = runtime::NDArray::Empty(
-        {(int64_t)ishape.size()}, Type2TVMType(Int(32)), ctx);
-    int32_t* dims = static_cast<int32_t*>(val->data);
-    using ::tvm::ir::IntImm;
-    for (size_t i = 0; i < ishape.size(); ++i) {
-      if (const IntImm* dim = ishape[i].as<IntImm>()) {
-        dims[i] = dim->value;
-      } else {
-        return expr;
+    runtime::NDArray value;
+    auto cdtype = Type2TVMType(Int(32));
+    if (ishape.size() == 0) {
+      value = runtime::NDArray::Empty({}, cdtype, ctx);
+    } else {
+      CHECK_NE(ishape.size(), 0);
+      std::vector<int64_t> cshape = { static_cast<int64_t>(ishape.size()) };
+      value = runtime::NDArray::Empty(cshape, cdtype, ctx);
+      int32_t* dims = static_cast<int32_t*>(value->data);
+      using ::tvm::ir::IntImm;
+      for (size_t i = 0; i < ishape.size(); ++i) {
+        if (const IntImm* dim = ishape[i].as<IntImm>()) {
+          dims[i] = dim->value;
+        } else {
+          return expr;
+        }
       }
     }
-    Expr shape = ValueToExpr(TensorValueNode::make(val));
+
+    Constant shape = Downcast<Constant>(ValueToExpr(TensorValueNode::make(value)));
+
+    if (shape->data.Shape().size() == 0 && GetScalarFromConstant<int32_t>(shape) == 0) {
+      auto ndarray = runtime::NDArray::Empty({}, cdtype, ctx);
+      shape = ConstantNode::make(ndarray);
+    }
+
     // Cast the constant into correct dtype
     auto cast_attrs = make_node<CastAttrs>();
     cast_attrs->dtype = param->dtype;
     static const Op& cast_op = Op::Get("cast");
-    Expr ret = CallNode::make(cast_op, {shape}, Attrs(cast_attrs), {});
+    Expr ret = CallNode::make(cast_op, { shape }, Attrs(cast_attrs), {});
     return ConstEvaluate(ret);
   }
 };
 
 
-Expr FoldConstant(const Expr& expr) {
+Expr FoldConstant(const Expr& expr, const Module& mod) {
   DLContext ctx;
   ctx.device_type = kDLCPU;
   ctx.device_id = 0;
@@ -227,7 +272,7 @@ Expr FoldConstant(const Expr& expr) {
   With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
 
   return ConstantFolder(CreateInterpreter(
-      Module(nullptr), ctx, target)).Mutate(expr);
+      mod, ctx, target), mod).Mutate(expr);
 }
 
 namespace transform {
@@ -235,7 +280,7 @@ namespace transform {
 Pass FoldConstant() {
   runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
     [=](Function f, Module m, PassContext pc) {
-      return Downcast<Function>(FoldConstant(f));
+      return Downcast<Function>(FoldConstant(f, m));
   };
   return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
 }
index acee2c1..226ca6d 100644 (file)
@@ -862,6 +862,13 @@ class FuseMutator : private ExprMutator {
   Expr VisitExpr_(const CallNode* call) {
     static const Op& stop_fusion = Op::Get("annotation.stop_fusion");
     if (call->op.as<OpNode>()) {
+      static auto fnoncomputational =
+        Op::GetAttr<TNonComputational>("TNonComputational");
+
+      if (fnoncomputational.get(Downcast<Op>(call->op), false)) {
+        return ExprMutator::VisitExpr_(call);
+      }
+
       // If it is a primitive op call
       // then we must have a group assignment for it already.
       CHECK(gmap_.count(call));
index d268862..dbecc6a 100644 (file)
@@ -314,7 +314,7 @@ Module FunctionPassNode::operator()(const Module& mod,
              << pass_info->opt_level;
 
   // Execute the pass function and return a new module.
-  Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions);
+  Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports());
   std::vector<std::pair<GlobalVar, Function> > updates;
   for (const auto& it : updated_mod->functions) {
     auto updated_func = SkipFunction(it.second)
index cb0ea1b..c1d1a66 100644 (file)
@@ -311,8 +311,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
       Match match = GetRef<Match>(op);
       Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_);
       if (unmatched_cases.size() != 0) {
-        LOG(FATAL) << "Match clause " << match <<  " does not handle the following cases: "
-                   << unmatched_cases;
+        RelayErrorStream ss;
+        ss << "match expression does not handle the following cases: ";
+        int i = 0;
+        for (auto cs : unmatched_cases) {
+          ss << "case " << i << ": \n" << PrettyPrint(cs);
+        }
+        this->ReportFatalError(
+          match,
+          ss);
       }
     }
 
index 6035790..9b8cf88 100644 (file)
@@ -530,8 +530,10 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
 };
 
 // constructor
-TypeSolver::TypeSolver(const GlobalVar& current_func, const Module& module,
-                       ErrorReporter* err_reporter)
+TypeSolver::TypeSolver(
+  const GlobalVar& current_func,
+  const Module& module,
+  ErrorReporter* err_reporter)
     : reporter_(make_node<Reporter>(this)),
       current_func(current_func),
       err_reporter_(err_reporter),
index 32032b5..4c4554c 100644 (file)
@@ -287,9 +287,13 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
     }
     case Opcode::AllocTensor: {
       // Number of fields = 5 + instr.alloc_tensor.ndim
+      fields.push_back(instr.alloc_tensor.storage);
+
       // Save `DLDataType` and the dst register.
       const auto& dtype = instr.alloc_tensor.dtype;
-      fields.assign({dtype.code, dtype.bits, dtype.lanes});
+      fields.push_back(dtype.code);
+      fields.push_back(dtype.bits);
+      fields.push_back(dtype.lanes);
 
       // The number of dimensions is not needed for constructing an
       // `AllocTensor` instruction as it equals to the length of the `shape`
@@ -305,10 +309,22 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
       break;
     }
     case Opcode::AllocTensorReg: {
-      // Number of fields = 5
+      // Number of fields = 6
+      fields.push_back(instr.alloc_tensor_reg.storage);
       fields.push_back(instr.alloc_tensor_reg.shape_register);
       // Save `DLDataType` and the dst register.
-      const auto& dtype = instr.alloc_tensor.dtype;
+      const auto& dtype = instr.alloc_tensor_reg.dtype;
+      fields.push_back(dtype.code);
+      fields.push_back(dtype.bits);
+      fields.push_back(dtype.lanes);
+      fields.push_back(instr.dst);
+      break;
+    }
+    case Opcode::AllocStorage: {
+      fields.push_back(instr.alloc_storage.allocation_size);
+      fields.push_back(instr.alloc_storage.alignment);
+      // Save `DLDataType` and the dst register.
+      const auto& dtype = instr.alloc_storage.dtype_hint;
       fields.push_back(dtype.code);
       fields.push_back(dtype.bits);
       fields.push_back(dtype.lanes);
@@ -521,35 +537,39 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
       return Instruction::InvokePacked(packed_index, arity, output_size, args);
     }
     case Opcode::AllocTensor: {
-      // Number of fields = 5 + instr.alloc_tensor.ndim
-      DCHECK_GE(instr.fields.size(), 5U);
-      DCHECK_EQ(instr.fields.size(), 5U + static_cast<size_t>(instr.fields[3]));
+      // Number of fields = 6 + instr.alloc_tensor.ndim
+      DCHECK_GE(instr.fields.size(), 6U);
+      DCHECK_EQ(instr.fields.size(), 6U + static_cast<size_t>(instr.fields[4]));
+
+      RegName storage_reg = instr.fields[0];
 
       DLDataType dtype;
-      dtype.code = instr.fields[0];
-      dtype.bits = instr.fields[1];
-      dtype.lanes = instr.fields[2];
+      dtype.code = instr.fields[1];
+      dtype.bits = instr.fields[2];
+      dtype.lanes = instr.fields[3];
 
-      Index ndim = instr.fields[3];
-      RegName dst = instr.fields[4];
+      Index ndim = instr.fields[4];
+      RegName dst = instr.fields[5];
 
-      std::vector<Index> shape = ExtractFields(instr.fields, 5, ndim);
+      std::vector<Index> shape = ExtractFields(instr.fields, 6, ndim);
 
-      return Instruction::AllocTensor(shape, dtype, dst);
+      return Instruction::AllocTensor(storage_reg, shape, dtype, dst);
     }
     case Opcode::AllocTensorReg: {
       // Number of fields = 5
-      DCHECK_EQ(instr.fields.size(), 5U);
-      Index shape_register = instr.fields[0];
+      DCHECK_EQ(instr.fields.size(), 6U);
+
+      RegName storage_reg = instr.fields[0];
+      Index shape_register = instr.fields[1];
 
       DLDataType dtype;
-      dtype.code = instr.fields[1];
-      dtype.bits = instr.fields[2];
-      dtype.lanes = instr.fields[3];
+      dtype.code = instr.fields[2];
+      dtype.bits = instr.fields[3];
+      dtype.lanes = instr.fields[4];
 
-      RegName dst = instr.fields[4];
+      RegName dst = instr.fields[5];
 
-      return Instruction::AllocTensorReg(shape_register, dtype, dst);
+      return Instruction::AllocTensorReg(storage_reg, shape_register, dtype, dst);
     }
     case Opcode::AllocADT: {
       // Number of fields = 3 + instr.num_fields
@@ -575,6 +595,24 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
 
       return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst);
     }
+    case Opcode::AllocStorage: {
+      DCHECK_GE(instr.fields.size(), 6U);
+      Index allocation_size = instr.fields[0];
+      Index alignment = instr.fields[1];
+
+      DLDataType dtype;
+      dtype.code = instr.fields[2];
+      dtype.bits = instr.fields[3];
+      dtype.lanes = instr.fields[4];
+
+      RegName dst = instr.fields[5];
+
+      return Instruction::AllocStorage(
+        allocation_size,
+        alignment,
+        dtype,
+        dst);
+    }
     case Opcode::If: {
       // Number of fields = 4
       DCHECK_EQ(instr.fields.size(), 4U);
index f32d232..1c7e029 100644 (file)
@@ -32,6 +32,30 @@ namespace tvm {
 namespace runtime {
 namespace vm {
 
+static void BufferDeleter(NDArray::Container* ptr) {
+  CHECK(ptr->manager_ctx != nullptr);
+  Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
+  MemoryManager::Global()->GetAllocator(buffer->ctx)->
+      Free(*(buffer));
+  delete buffer;
+  delete ptr;
+}
+
+void StorageObj::Deleter(NDArray::Container* ptr) {
+  // When invoking AllocNDArray we don't own the underlying allocation
+  // and should not delete the buffer, but instead let it be reclaimed
+  // by the storage object's destructor.
+  //
+  // We did bump the reference count by 1 to keep alive the StorageObj
+  // allocation in case this NDArray is the sole owner.
+  //
+  // We decrement the object allowing for the buffer to release our
+  // reference count from allocation.
+  StorageObj* storage = reinterpret_cast<StorageObj*>(ptr->manager_ctx);
+  storage->DecRef();
+  delete ptr;
+}
+
 inline void VerifyDataType(DLDataType dtype) {
   CHECK_GE(dtype.lanes, 1);
   if (dtype.code == kDLFloat) {
@@ -50,6 +74,22 @@ inline size_t GetDataAlignment(const DLTensor& arr) {
   return align;
 }
 
+NDArray StorageObj::AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDataType dtype) {
+  // TODO(@jroesch): generalize later to non-overlapping allocations.
+  CHECK_EQ(offset, 0u);
+  VerifyDataType(dtype);
+  NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx);
+  container->deleter = StorageObj::Deleter;
+  size_t needed_size = GetDataSize(container->dl_tensor);
+  // TODO(@jroesch): generalize later to non-overlapping allocations.
+  CHECK(needed_size == this->buffer.size)
+    << "size mistmatch required " << needed_size << " found " << this->buffer.size;
+  this->IncRef();
+  container->manager_ctx = reinterpret_cast<void*>(this);
+  container->dl_tensor.data = this->buffer.data;
+  return NDArray(container);
+}
+
 MemoryManager* MemoryManager::Global() {
   static MemoryManager memory_manager;
   return &memory_manager;
@@ -66,15 +106,6 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) {
   return allocators_.at(ctx).get();
 }
 
-static void BufferDeleter(NDArray::Container* ptr) {
-  CHECK(ptr->manager_ctx != nullptr);
-  Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
-  MemoryManager::Global()->GetAllocator(buffer->ctx)->
-      Free(*(buffer));
-  delete buffer;
-  delete ptr;
-}
-
 NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) {
   VerifyDataType(dtype);
   NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx);
index 988df84..dce596c 100644 (file)
@@ -27,6 +27,7 @@
 
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/object.h>
 #include <functional>
 #include <memory>
 #include <mutex>
@@ -108,6 +109,38 @@ class MemoryManager {
   std::unordered_map<TVMContext, std::unique_ptr<Allocator> > allocators_;
 };
 
+/*! \brief An object representing a storage allocation. */
+class StorageObj : public Object {
+ public:
+  /*! \brief The index into the VM function table. */
+  Buffer buffer;
+
+  /*! \brief Allocate an NDArray from a given piece of storage. */
+  NDArray AllocNDArray(size_t offset,
+                       std::vector<int64_t> shape,
+                       DLDataType dtype);
+
+  /*! \brief The deleter for an NDArray when allocated from underlying storage. */
+  static void Deleter(NDArray::Container* ptr);
+
+  ~StorageObj() {
+    auto alloc = MemoryManager::Global()->GetAllocator(buffer.ctx);
+    alloc->Free(buffer);
+  }
+
+  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+  static constexpr const char* _type_key = "vm.Storage";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StorageObj, Object);
+};
+
+/*! \brief reference to storage. */
+class Storage : public ObjectRef {
+ public:
+  explicit Storage(Buffer buffer);
+
+  TVM_DEFINE_OBJECT_REF_METHODS_MUT(Storage, ObjectRef, StorageObj);
+};
+
 }  // namespace vm
 }  // namespace runtime
 }  // namespace tvm
index ab0e062..05935b7 100644 (file)
@@ -25,6 +25,8 @@
 #include <dmlc/memory_io.h>
 #include <tvm/logging.h>
 #include <tvm/runtime/vm.h>
+#include <tvm/runtime/memory.h>
+#include <tvm/runtime/object.h>
 
 #include <algorithm>
 #include <chrono>
@@ -42,6 +44,17 @@ namespace tvm {
 namespace runtime {
 namespace vm {
 
+
+inline Storage make_storage(size_t size, size_t alignment, TVMType dtype_hint, TVMContext ctx) {
+  // We could put cache in here, from ctx to storage allocator.
+  auto storage_obj = SimpleObjAllocator().make<StorageObj>();
+  auto alloc = MemoryManager::Global()->GetAllocator(ctx);
+  DCHECK(alloc != nullptr)
+    << "allocator must not null";
+  storage_obj->buffer = alloc->Alloc(size, alignment, dtype_hint);
+  return Storage(storage_obj);
+}
+
 Instruction::Instruction() {}
 
 template <typename T>
@@ -65,12 +78,14 @@ Instruction::Instruction(const Instruction& instr) {
       this->result = instr.result;
       return;
     case Opcode::AllocTensor:
+      this->alloc_tensor.storage = instr.alloc_tensor.storage;
       this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
       this->alloc_tensor.shape = Duplicate<int64_t>(instr.alloc_tensor.shape,
                                                     instr.alloc_tensor.ndim);
       this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
       return;
     case Opcode::AllocTensorReg:
+      this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage;
       this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
       this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
       return;
@@ -119,6 +134,9 @@ Instruction::Instruction(const Instruction& instr) {
     case Opcode::Goto:
       this->pc_offset = instr.pc_offset;
       return;
+    case Opcode::AllocStorage:
+      this->alloc_storage = instr.alloc_storage;
+      return;
     default:
       std::ostringstream out;
       out << "Invalid instruction " << static_cast<int>(instr.op);
@@ -150,12 +168,14 @@ Instruction& Instruction::operator=(const Instruction& instr) {
       this->result = instr.result;
       return *this;
     case Opcode::AllocTensor:
+      this->alloc_tensor.storage = this->alloc_tensor.storage;
       this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
       this->alloc_tensor.shape = Duplicate<int64_t>(instr.alloc_tensor.shape,
                                                     instr.alloc_tensor.ndim);
       this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
       return *this;
     case Opcode::AllocTensorReg:
+      this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage;
       this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
       this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
       return *this;
@@ -206,6 +226,9 @@ Instruction& Instruction::operator=(const Instruction& instr) {
     case Opcode::Goto:
       this->pc_offset = instr.pc_offset;
       return *this;
+    case Opcode::AllocStorage:
+      this->alloc_storage = instr.alloc_storage;
+      return *this;
     default:
       std::ostringstream out;
       out << "Invalid instruction " << static_cast<int>(instr.op);
@@ -224,6 +247,7 @@ Instruction::~Instruction() {
     case Opcode::GetTag:
     case Opcode::Goto:
     case Opcode::LoadConsti:
+    case Opcode::AllocStorage:
     case Opcode::Fatal:
       return;
     case Opcode::AllocTensor:
@@ -279,10 +303,14 @@ Instruction Instruction::InvokePacked(Index packed_index,
   return instr;
 }
 
-Instruction Instruction::AllocTensor(std::vector<int64_t> shape, DLDataType dtype, Index dst) {
+Instruction Instruction::AllocTensor(
+  RegName storage,
+  const std::vector<int64_t>& shape,
+  DLDataType dtype, Index dst) {
   Instruction instr;
   instr.op = Opcode::AllocTensor;
   instr.dst = dst;
+  instr.alloc_tensor.storage = storage;
   instr.alloc_tensor.ndim = shape.size();
   instr.alloc_tensor.shape = new int64_t[shape.size()];
   for (size_t i = 0; i < shape.size(); ++i) {
@@ -292,15 +320,32 @@ Instruction Instruction::AllocTensor(std::vector<int64_t> shape, DLDataType dtyp
   return instr;
 }
 
-Instruction Instruction::AllocTensorReg(RegName shape_register, DLDataType dtype, Index dst) {
+Instruction Instruction::AllocTensorReg(
+  RegName storage,
+  RegName shape_register,
+  DLDataType dtype, Index dst) {
   Instruction instr;
   instr.op = Opcode::AllocTensorReg;
   instr.dst = dst;
+  instr.alloc_tensor_reg.storage = storage;
   instr.alloc_tensor_reg.shape_register = shape_register;
   instr.alloc_tensor_reg.dtype = dtype;
   return instr;
 }
 
+Instruction Instruction::AllocStorage(RegName size,
+                                      Index alignment,
+                                      TVMType dtype_hint,
+                                      Index dst) {
+  Instruction instr;
+  instr.op = Opcode::AllocStorage;
+  instr.dst = dst;
+  instr.alloc_storage.allocation_size = size;
+  instr.alloc_storage.alignment = alignment;
+  instr.alloc_storage.dtype_hint = dtype_hint;
+  return instr;
+}
+
 Instruction Instruction::AllocADT(Index tag, Index num_fields,
                                        const std::vector<RegName>& datatype_fields, Index dst) {
   Instruction instr;
@@ -472,7 +517,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
       break;
     }
     case Opcode::AllocTensor: {
-      os << "alloc_tensor $" << instr.dst << " ["
+      os << "alloc_tensor $" << instr.dst << " $"
+         << instr.alloc_tensor.storage << " ["
          << StrJoin<int64_t>(instr.alloc_tensor.shape, 0,
                              instr.alloc_tensor.ndim)
          << "] ";
@@ -481,6 +527,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
     }
     case Opcode::AllocTensorReg: {
       os << "alloc_tensor_reg $" << instr.dst << " $"
+         << instr.alloc_tensor_reg.storage << " $"
          << instr.alloc_tensor_reg.shape_register << " ";
       DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
       break;
@@ -534,6 +581,14 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
       os << "goto " << instr.pc_offset;
       break;
     }
+    case Opcode::AllocStorage: {
+      os << "alloc_storage " <<
+        instr.dst << " " <<
+        instr.alloc_storage.allocation_size << " " <<
+        instr.alloc_storage.alignment << " " <<
+        TVMType2String(instr.alloc_storage.dtype_hint);
+      break;
+    }
     default:
       LOG(FATAL) << "should never hit this case" << static_cast<int>(instr.op);
       break;
@@ -827,17 +882,21 @@ void VirtualMachine::RunLoop() {
         goto main_loop;
       }
       case Opcode::InvokePacked: {
+        DLOG(INFO) << "InvokedPacked "
+          << "arity=" << instr.arity;
         const auto& func = packed_funcs[instr.packed_index];
         const auto& arity = instr.arity;
         std::vector<ObjectRef> args;
         for (Index i = 0; i < arity; ++i) {
-          args.push_back(ReadRegister(instr.packed_args[i]));
+          DLOG(INFO) <<
+            "arg" << i << " $" << instr.packed_args[i];
+          auto arg = ReadRegister(instr.packed_args[i]);
+          args.push_back(arg);
         }
+
+        // We no longer need to write the registers back, we write directly
+        // through the registers mutably.
         InvokePacked(instr.packed_index, func, arity, instr.output_size, args);
-        for (Index i = 0; i < instr.output_size; ++i) {
-          WriteRegister(instr.packed_args[instr.arity - instr.output_size + i],
-                        args[instr.arity - instr.output_size + i]);
-        }
         pc++;
         goto main_loop;
       }
@@ -901,12 +960,15 @@ void VirtualMachine::RunLoop() {
       }
       case Opcode::AllocTensor: {
         auto shape = std::vector<int64_t>(instr.alloc_tensor.ndim);
+
         for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) {
           shape[i] = instr.alloc_tensor.shape[i];
         }
-        // TODO(wweic) ctx could be obtained from the ctxs list.
-        auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
-        auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]);
+
+        auto storage_obj = ReadRegister(instr.alloc_tensor.storage);
+        auto storage = Downcast<Storage>(storage_obj);
+        auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype);
+
         auto obj = Tensor(data);
         WriteRegister(instr.dst, obj);
         pc++;
@@ -916,19 +978,22 @@ void VirtualMachine::RunLoop() {
         DLContext cpu_ctx;
         cpu_ctx.device_type = kDLCPU;
         cpu_ctx.device_id = 0;
-
         auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
         const auto* tensor = shape_tensor_obj.as<TensorObj>();
         CHECK(tensor != nullptr);
         NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
-
-        int64_t* dims = static_cast<int64_t*>(shape_tensor->data);
+        const DLTensor* dl_tensor = shape_tensor.operator->();
+        CHECK_EQ(dl_tensor->dtype.code, 0u);
+        CHECK_LE(dl_tensor->dtype.bits, 64);
+        int64_t* dims = reinterpret_cast<int64_t*>(dl_tensor->data);
         auto num_dims = shape_tensor->shape[0];
-        auto shape = std::vector<int64_t>(shape_tensor->shape[0]);
+        auto shape = std::vector<int64_t>(num_dims);
         shape.assign(dims, dims + num_dims);
-        // TODO(wweic) ctx could be obtained from the ctxs list.
-        auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
-        auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]);
+
+        auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage);
+        auto storage = Downcast<Storage>(storage_obj);
+        auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype);
+
         auto obj = Tensor(data);
         WriteRegister(instr.dst, obj);
         pc++;
@@ -953,6 +1018,20 @@ void VirtualMachine::RunLoop() {
         pc++;
         goto main_loop;
       }
+      case Opcode::AllocStorage: {
+        auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
+        auto alignment = LoadScalarInt(instr.alloc_storage.alignment);
+
+        DLOG(INFO) <<
+          "AllocStorage: allocation_size=" << size <<
+          "alignment=" << alignment <<
+          "dtype_hint=" << TVMType2String(instr.alloc_storage.dtype_hint);
+
+        auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs[0]);
+        WriteRegister(instr.dst, storage);
+        pc++;
+        goto main_loop;
+      }
       case Opcode::Ret: {
         // If we have hit the point from which we started
         // running, we should return to the caller breaking
diff --git a/tests/python/relay/test_memory_alloc.py b/tests/python/relay/test_memory_alloc.py
new file mode 100644 (file)
index 0000000..5c1bbc7
--- /dev/null
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay import memory_alloc
+
+def check_vm_alloc(func, check_fn):
+    mod = relay.Module()
+    mod['main'] = func
+    ex = relay.create_executor('vm', mod)
+    args = []
+    for param in func.params:
+        param = param.type_annotation
+        sh = [int(sh) for sh in param.shape]
+        data = np.random.rand(*sh).astype(param.dtype)
+        args.append(tvm.nd.array(data))
+    result = ex.evaluate(mod['main'])(*args)
+    py_res = check_fn(*[arg.asnumpy() for arg in args])
+    np.testing.assert_allclose(result.asnumpy(), py_res)
+
+def storage_type(mod):
+    return relay.TypeCall(mod.get_global_type_var("Storage"), [])
+
+def test_tyck_alloc_storage():
+    mod = relay.Module()
+    mod.import_from_std("core.rly")
+
+def test_tyck_alloc_tensor():
+    mod = relay.Module()
+    mod.import_from_std("core.rly")
+    sto = relay.Var("x", storage_type(mod))
+    sh = relay.const(np.array([1, 2]), dtype="int64")
+    at = relay.op.memory.alloc_tensor(sto, sh)
+    mod['main'] = relay.Function([sto], at)
+    relay.transform.InferType()(mod)
+
+
+def check_add(x):
+    return x + x
+
+def test_add():
+    x = relay.var('x', shape=(2,))
+    z = x + x
+    func = relay.Function([x,], z)
+    check_vm_alloc(func, check_add)
+
+
+def check_add_sub(x, y):
+    z = x + x
+    return z - y
+
+def test_add_sub():
+    x = relay.var('x', shape=(10,))
+    y = relay.var('y', shape=(10,))
+    z = x + x
+    z = z - y
+    func = relay.Function([x, y], z)
+    check_vm_alloc(func, check_add_sub)
+
+if __name__ == "__main__":
+    test_tyck_alloc_tensor()
+    test_add()
+    test_add_sub()
index 0146480..0327c14 100644 (file)
@@ -107,9 +107,9 @@ def test_serializer():
     assert any(item.startswith('fused_multiply') for item in prim_ops)
 
     code = exe.bytecode
-    assert "main 5 2 5" in code
-    assert "f1 2 1 3" in code
-    assert "f2 2 1 3" in code
+    assert "main 8 2 8" in code
+    assert "f1 5 1 6" in code
+    assert "f2 5 1 6" in code
 
     code, lib = exe.save()
     assert isinstance(code, bytearray)