[Relay][Any] Add shape func for dynamic shape (#3606)
authorHaichen Shen <shenhaichen@gmail.com>
Sun, 1 Sep 2019 01:50:22 +0000 (18:50 -0700)
committerJared Roesch <roeschinc@gmail.com>
Sun, 1 Sep 2019 01:50:22 +0000 (18:50 -0700)
* init shape func in interpreter and vm compiler

* Update interpreter

* fix

* lint

* lint

* fix

* remove hack

* update

* fix

* fix

* update

* address comments & update for shape_of

* fix lint

* update

* fix hybrid

* lint

* fix bug & add take shape func

* lint

* lint

* update

* fix flaky test

* add todo

29 files changed:
include/tvm/ir.h
include/tvm/relay/op_attr_types.h
python/tvm/hybrid/parser.py
python/tvm/hybrid/util.py
python/tvm/relay/backend/vm.py
python/tvm/relay/op/_tensor.py
python/tvm/relay/op/_transform.py
python/tvm/relay/op/op.py
src/op/hybrid_op.cc
src/pass/storage_rewrite.cc
src/relay/backend/compile_engine.cc
src/relay/backend/compile_engine.h
src/relay/backend/interpreter.cc
src/relay/backend/vm/compiler.cc
src/relay/backend/vm/compiler.h
src/relay/ir/op.cc
src/relay/op/tensor/transform.cc
src/relay/op/tensor/unary.cc
src/relay/op/type_relations.cc
src/relay/pass/fuse_ops.cc
src/relay/pass/pass_manager.cc
src/runtime/vm/vm.cc
src/schedule/schedule_dataflow_rewrite.cc
tests/python/relay/test_any.py
tests/python/relay/test_vm_serialization.py
tests/python/unittest/test_hybrid_script.py
topi/include/topi/detail/broadcast.h
topi/include/topi/transform.h
topi/tests/python/test_topi_conv2d_winograd.py

index 03559d3..a9c6c4b 100644 (file)
@@ -704,6 +704,10 @@ class Reduce : public ExprNode {
 class Any : public ExprNode {
  public:
   void VisitAttrs(AttrVisitor* v) final {}
+  /*! \brief Convert to var. */
+  Var ToVar() const {
+    return Variable::make(Int(32), "any_dim");
+  }
 
   TVM_DLL static Expr make();
 
index c1a0f83..741e8b4 100644 (file)
@@ -75,6 +75,11 @@ using TOpIsStateful = bool;
 using TNonComputational = bool;
 
 /*!
+ * \brief Mark the operator whether output shape is data dependant.
+ */
+using TShapeDataDependant = bool;
+
+/*!
  * \brief Computation description interface.
  *
  * \note This function have a special convention
@@ -186,7 +191,7 @@ using Shape = Array<IndexExpr>;
 using FShapeFunc = runtime::TypedPackedFunc<
   Array<Tensor>(const Attrs& attrs,
                 const Array<Tensor>& inputs,
-                const Array<Shape>& out_shapes)>;
+                const Array<IndexExpr>& out_ndims)>;
 
 }  // namespace relay
 }  // namespace tvm
index 40ea171..171a2f8 100644 (file)
@@ -25,7 +25,7 @@ import numbers
 
 from enum import Enum
 
-from .util import _internal_assert, _apply_indices
+from .util import _internal_assert
 from . import calls
 from . import util
 from .preprocessor import determine_variable_usage
@@ -35,7 +35,6 @@ from ..container import Array
 from ..tensor import Tensor, Operation
 from .. import _api_internal as _tvm_internal
 from .. import expr as _expr
-from .. import stmt as _stmt
 from .. import make as _make
 from .. import api  as _api
 from .. import ir_pass as _ir_pass
@@ -43,16 +42,15 @@ from .. import ir_pass as _ir_pass
 
 def concat_list_to_block(lst):
     """Concatenate a list of Python IR nodes to HalideIR Block"""
+    if not lst:
+        return util.make_nop()
     n = len(lst)
     if n == 1:
         return lst[0]
     body = lst[n - 1]
     for i in range(1, n):
         stmt = lst[n - 1 - i]
-        if isinstance(stmt, _stmt.AssertStmt):
-            body = _make.AssertStmt(stmt.condition, stmt.message, body)
-        else:
-            body = _make.Block(stmt, body)
+        body = _make.Block(stmt, body)
     return body
 
 
@@ -100,8 +98,8 @@ class HybridParser(ast.NodeVisitor):
         ast.LtE     : operator.le,
         ast.Eq      : operator.eq,
         ast.NotEq   : operator.ne,
-        ast.And   : _all,
-        ast.Or    : _any,
+        ast.And     : _all,
+        ast.Or      : _any,
     }
 
 
@@ -179,6 +177,9 @@ class HybridParser(ast.NodeVisitor):
         to_pop = []
         for key, val in self.usage.items():
             _, level, _ = val
+            if key not in self.symbols:
+                # don't realize the symbols that are never visited
+                continue
             if level != node:
                 continue
             _internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key)
@@ -363,44 +364,25 @@ class HybridParser(ast.NodeVisitor):
 
 
     def visit_Attribute(self, node):
-        _internal_assert(isinstance(node.value, ast.Name), \
-                         "For atrribute access, only both names are supported so far!")
         buf = self.visit(node.value)
         return getattr(buf, node.attr)
 
     def visit_Subscript(self, node):
         args = self.visit(node.slice)
-        if isinstance(node.value, ast.Name):
-            if node.value.id in self.closure_vars:
-                args = ast.literal_eval(str(args))
-                return _api.convert(_apply_indices(self.closure_vars[node.value.id], args))
-
-            buf = self.visit(node.value)
-            if isinstance(buf, Array):
-                for i in args:
-                    if isinstance(i, numbers.Integral):
-                        buf = buf[i]
-                    else:
-                        _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
-                                         "All indices are supposed to be constants")
-                        buf = buf[i.value]
-
-                return buf
-
-            if isinstance(node.ctx, ast.Load):
-                return _make.Call(buf.dtype, buf.name, args, \
-                                  _expr.Call.Halide, buf.op, buf.value_index)
-
-            return buf, args
-
-        shape = self.visit(node.value)
-        _internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!")
-        args = args[0]
-        #TODO: maybe support non-constant value later?
-        _internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \
-                         "So far only constant shape access supported!")
-        return shape[args.value]
-
+        arr = self.visit(node.value)
+        if isinstance(arr, Array):
+            for i in args:
+                if isinstance(i, numbers.Integral):
+                    arr = arr[i]
+                else:
+                    _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
+                                     "All indices are supposed to be constants")
+                    arr = arr[i.value]
+            return arr
+        if isinstance(node.ctx, ast.Load):
+            return _make.Call(arr.dtype, arr.name, args,
+                              _expr.Call.Halide, arr.op, arr.value_index)
+        return arr, args
 
     def visit_With(self, node):
         if sys.version_info[0] < 3:
@@ -417,7 +399,7 @@ class HybridParser(ast.NodeVisitor):
 
 
     def visit_If(self, node):
-        cond = self.visit(node.test)
+        cond = _ir_pass.CanonicalSimplify(self.visit(node.test))
 
         # Return no IfThenElse if proven
         if isinstance(cond, _expr.UIntImm):
@@ -508,11 +490,11 @@ class HybridParser(ast.NodeVisitor):
         _name = node.target.id
 
         if isinstance(for_type, tuple):
-            low = _ir_pass.Simplify(low)
-            ext = _ir_pass.Simplify(ext)
+            low = _ir_pass.CanonicalSimplify(low)
+            ext = _ir_pass.CanonicalSimplify(ext)
             _internal_assert(isinstance(low, _expr.ConstExpr) and
                              isinstance(ext, _expr.ConstExpr), \
-                             "Const range should start from a const" + \
+                             "Const range should start from a const " + \
                              "and iterate const times")
 
             low, ext = low.value, ext.value
index 058c5aa..0dd1fa1 100644 (file)
@@ -101,9 +101,3 @@ def _is_tvm_arg_types(args):
         _internal_assert(isinstance(elem, np_arg_types), \
                          "Expect a numpy type but %s get!" % str(type(elem)))
     return False
-
-def _apply_indices(value, indices):
-    """Apply multidimensional index"""
-    if indices:
-        return _apply_indices(value[indices[0]], indices[1:])
-    return value
index d4cc68b..745ef63 100644 (file)
@@ -177,6 +177,10 @@ class VMCompiler(object):
             The VM runtime.
         """
         target = _update_target(target)
+        target_host = None if target_host == "" else target_host
+        if not target_host:
+            target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
+        target_host = tvm.target.create(target_host)
         self._compile(mod, target, target_host)
         return VirtualMachine(self._get_vm())
 
index 176def3..2e34233 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=invalid-name, unused-argument
+#pylint: disable=invalid-name, unused-argument, len-as-condition
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 import topi
-from .op import register_compute, register_schedule, register_pattern
+from .op import register_compute, register_schedule, register_pattern, register_shape_func
 from .op import schedule_injective, OpPattern
+from ...hybrid import script
 
 schedule_broadcast = schedule_injective
 schedule_elemwise = schedule_injective
@@ -104,3 +105,49 @@ def clip_compute(attrs, inputs, output_type, target):
     return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
 
 register_schedule("clip", schedule_elemwise)
+
+# shape func
+@script
+def _broadcast_shape_func(x, y, ndim):
+    out = output_tensor((ndim,), "int64")
+    if len(x.shape) == 0:
+        for i in const_range(ndim):
+            out[i] = y[i]
+    elif len(y.shape) == 0:
+        for i in const_range(ndim):
+            out[i] = x[i]
+    else:
+        ndim1 = x.shape[0]
+        ndim2 = y.shape[0]
+        for i in const_range(1, min(ndim1, ndim2)+1):
+            if x[ndim1-i] == y[ndim2-i]:
+                out[ndim-i] = x[ndim1-i]
+            elif x[ndim1-i] == 1:
+                out[ndim-i] = y[ndim2-i]
+            else:
+                assert y[ndim2 - i] == 1, "Incompatible broadcast type %s and %s" % (
+                    x[ndim1-i], y[ndim2-i])
+                out[ndim-i] = x[ndim1-i]
+        for i in const_range(min(ndim1, ndim2)+1, ndim+1):
+            if ndim1 >= ndim2:
+                out[ndim-i] = x[ndim1-i]
+            else:
+                out[ndim-i] = y[ndim2-i]
+    return out
+
+def broadcast_shape_func(attrs, inputs, out_ndims):
+    return [_broadcast_shape_func(*inputs, out_ndims[0])]
+
+register_shape_func("add", False, broadcast_shape_func)
+register_shape_func("subtract", False, broadcast_shape_func)
+register_shape_func("multiply", False, broadcast_shape_func)
+register_shape_func("divide", False, broadcast_shape_func)
+register_shape_func("mod", False, broadcast_shape_func)
+register_shape_func("logical_and", False, broadcast_shape_func)
+register_shape_func("logical_or", False, broadcast_shape_func)
+register_shape_func("equal", False, broadcast_shape_func)
+register_shape_func("not_equal", False, broadcast_shape_func)
+register_shape_func("less", False, broadcast_shape_func)
+register_shape_func("less_equal", False, broadcast_shape_func)
+register_shape_func("greater", False, broadcast_shape_func)
+register_shape_func("greater_equal", False, broadcast_shape_func)
index a4c9375..7f29e85 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """Backend compiler related feature registration"""
-# pylint: disable=invalid-name,unused-argument
+# pylint: disable=invalid-name,unused-argument, len-as-condition
 from __future__ import absolute_import
+from topi.util import get_const_int, get_const_tuple
 from . import op as _reg
 from ._reduce import _schedule_reduce
 from .op import OpPattern
+from ...hybrid import script
+from ...api import convert
 
 schedule_injective = _reg.schedule_injective
 schedule_broadcast = _reg.schedule_injective
@@ -58,3 +61,145 @@ _reg.register_schedule("one_hot", schedule_injective)
 # layout_transform
 _reg.register_schedule("layout_transform", schedule_injective)
 _reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
+
+# shape func
+@script
+def _arange_shape_func(start, stop, step):
+    out = output_tensor((1,), "int64")
+    out[0] = int64(ceil_div((float32(stop[0]) - float32(start[0])), float32(step[0])))
+    return out
+
+@_reg.register_shape_func("arange", True)
+def arange_shape_func(attrs, inputs, _):
+    return [_arange_shape_func(*inputs)]
+
+@script
+def _concatenate_shape_func(inputs, axis):
+    ndim = inputs[0].shape[0]
+    out = output_tensor((ndim,), "int64")
+    for i in const_range(ndim):
+        if i != axis:
+            out[i] = inputs[0][i]
+            for j in const_range(1, len(inputs)):
+                assert out[i] == inputs[j][i], \
+                    "Dims mismatch in the inputs of concatenate."
+        else:
+            out[i] = int64(0)
+            for j in const_range(len(inputs)):
+                out[i] += inputs[j][i]
+    return out
+
+@_reg.register_shape_func("concatenate", False)
+def concatenate_shape_func(attrs, inputs, _):
+    axis = get_const_int(attrs.axis)
+    return [_concatenate_shape_func(inputs, convert(axis))]
+
+@script
+def _reshape_shape_func(data_shape, newshape, ndim):
+    out = output_tensor((ndim,), "int64")
+    src_idx = 0
+    dst_idx = 0
+    infer_idx = -1
+    copy = False
+    skip = 0
+    for i in const_range(len(newshape)):
+        if skip > 0:
+            skip -= 1
+        elif newshape[i] > 0:
+            out[dst_idx] = int64(newshape[i])
+            src_idx += 1
+            dst_idx += 1
+        elif newshape[i] == 0:
+            out[dst_idx] = data_shape[src_idx]
+            src_idx += 1
+            dst_idx += 1
+        elif newshape[i] == -1:
+            assert infer_idx < 0, "One and only one dim can be inferred"
+            out[dst_idx] = int64(1)
+            infer_idx = i
+            dst_idx += 1
+        elif newshape[i] == -2:
+            copy = True
+        elif newshape[i] == -3:
+            assert data_shape.shape[0] - src_idx > 1, \
+                "Not enough dims in input shape for -3"
+            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
+            src_idx += 2
+            dst_idx += 1
+        elif newshape[i] == -4:
+            assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
+            if newshape[i+1] == -1:
+                assert newshape[i+2] != -1, "Split dims cannot both be -1."
+                out[dst_idx] = data_shape[src_idx] / int64(newshape[i+2])
+                out[dst_idx+1] = int64(newshape[i+2])
+            else:
+                out[dst_idx] = int64(newshape[i+1])
+                if newshape[i+2] == -1:
+                    out[dst_idx+1] = data_shape[src_idx] / int64(newshape[i+1])
+                else:
+                    out[dst_idx+1] = int64(newshape[i+2])
+            assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
+                "Product of split dims doesn't match to input dim"
+            src_idx += 1
+            dst_idx += 2
+            skip = 2
+        else:
+            assert False, "Invalid special values in new shape"
+    if len(data_shape.shape) > 0:
+        # if data is not constant, we can then handle -1 and -2
+        if copy:
+            for i in range(src_idx, data_shape.shape[0]):
+                out[dst_idx] = data_shape[i]
+                dst_idx += 1
+        if infer_idx >= 0:
+            old_size = int64(1)
+            for i in const_range(data_shape.shape[0]):
+                old_size *= data_shape[i]
+            new_size = int64(1)
+            for i in const_range(out.shape[0]):
+                new_size *= out[i]
+            out[infer_idx] = old_size / new_size
+    return out
+
+@_reg.register_shape_func("reshape", False)
+def reshape_shape_func(attrs, inputs, out_ndims):
+    newshape = get_const_tuple(attrs.newshape)
+    return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])]
+
+@script
+def _take_no_axis_shape_func(indices_shape, out_ndim):
+    out = output_tensor((out_ndim,), "int64")
+    for i in const_range(out_ndim):
+        out[i] = indices_shape[i]
+    return out
+
+@script
+def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):
+    out = output_tensor((out_ndim,), "int64")
+    for i in const_range(axis):
+        out[i] = data_shape[i]
+    if len(indices_shape.shape) == 0:
+        # indices is constant
+        for i in const_range(axis+1, len(data_shape)):
+            out[i-1] = data_shape[i]
+    else:
+        for i in const_range(len(indices_shape)):
+            out[axis+i] = indices_shape[i]
+        for i in const_range(axis+1, len(data_shape)):
+            out[len(indices_shape)+i-1] = data_shape[i]
+    return out
+
+@_reg.register_shape_func("take", False)
+def take_shape_func(attrs, inputs, out_ndims):
+    """
+    Shape function for take op.
+    """
+    if attrs.axis is None:
+        return [_take_no_axis_shape_func(inputs[1], out_ndims[0])]
+    else:
+        axis = get_const_int(attrs.axis)
+        data_ndim = int(inputs[0].shape[0])
+        if axis < 0:
+            axis += data_ndim
+        assert 0 <= axis < data_ndim
+        return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
index e07d153..fcbc3fd 100644 (file)
@@ -48,6 +48,22 @@ class Op(Expr):
         """
         return _OpGetAttr(self, attr_name)
 
+    def set_attr(self, attr_name, value, plevel=10):
+        """Set attribute about the operator.
+
+        Parameters
+        ----------
+        attr_name : str
+            The attribute name
+
+        value : object
+            The attribute value
+
+        plevel : int
+            The priority level
+        """
+        _OpSetAttr(self, attr_name, value, plevel)
+
 
 def get(op_name):
     """Get the Op for a given name
@@ -219,6 +235,26 @@ def register_gradient(op_name, fgradient=None, level=10):
     """
     return register(op_name, "FPrimalGradient", fgradient, level)
 
+def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
+    """Register operator shape function for an op.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the op.
+
+    data_dependant : bool
+        Whether the shape function depends on input data.
+
+    shape_func : function (attrs: Attrs, inputs: List[Tensor], out_ndims: List[IndexExpr])
+                 -> shape_tensors: List<Tensor>
+        The function for computing the dynamic output shapes
+
+    level : int
+        The priority level
+    """
+    get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
+    return register(op_name, "FShapeFunc", shape_func, level)
 
 _init_api("relay.op", __name__)
 
index f9814e8..934e894 100644 (file)
@@ -82,7 +82,25 @@ Operation HybridOpNode::make(std::string name,
 }
 
 Array<Tensor> HybridOpNode::InputTensors() const {
-  return inputs;
+  // Because input tensors could be potentially inlined into hybrid scripts,
+  // we need to check if all input tensors are used in the body.
+  std::unordered_set<Tensor> orig_inputs;
+  for (auto t : inputs) {
+    orig_inputs.insert(t);
+  }
+  std::unordered_set<Tensor> visited;
+  Array<Tensor> curr_inputs;
+  ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const NodeRef& n) {
+      const ir::Call *call = n.as<ir::Call>();
+      if (call != nullptr && call->func.defined()) {
+        Tensor t = Operation(call->func.node_).output(call->value_index);
+        if (orig_inputs.count(t) && !visited.count(t)) {
+          curr_inputs.push_back(t);
+          visited.insert(t);
+        }
+      }
+  });
+  return curr_inputs;
 }
 
 Operation HybridOpNode::ReplaceInputs(
@@ -111,7 +129,8 @@ void HybridOpNode::PropBoundToInputs(
     arith::Analyzer* analyzer,
     const std::unordered_map<const Variable*, IntSet> &dom_map,
     std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
-  for (Tensor t : this->inputs) {
+  auto curr_inputs = InputTensors();
+  for (Tensor t : curr_inputs) {
     auto it = out_dom_map->find(t);
     if (it == out_dom_map->end()) continue;
     TensorDom &dom = it->second;
@@ -180,11 +199,10 @@ Stmt HybridOpNode::BuildProvide(
       outputs[i]->dtype);
     f_push_bind(buffer, stage->op.output(i));
   }
-  for (int i = static_cast<int>(inputs.size()) - 1; i >= 0; --i) {
-    Buffer buffer = decl_buffer(
-      inputs[i]->shape,
-      inputs[i]->dtype);
-    f_push_bind(buffer, inputs[i]);
+  auto curr_inputs = InputTensors();
+  for (int i = static_cast<int>(curr_inputs.size()) - 1; i >= 0; --i) {
+    Buffer buffer = decl_buffer(curr_inputs[i]->shape, curr_inputs[i]->dtype);
+    f_push_bind(buffer, curr_inputs[i]);
   }
 
   std::unordered_map<Tensor, Tensor> rmap;
@@ -203,7 +221,7 @@ Stmt HybridOpNode::BuildProvide(
    *      tensors have the same names as the operation produces them.
    *   2. Once OpNode is wrapped up by an Operation node, it is finalized.
    *      Later access will be from a const OpNode*.
-   * This is a chiken-egg paradox. It is impossible to put the output
+   * This is a chicken-egg paradox. It is impossible to put the output
    * tensors into the function body without forming the op node. The
    * function body is immutable after the node is formed.
    *
index eba1cee..e118202 100644 (file)
@@ -41,7 +41,7 @@ namespace ir {
 using runtime::StorageRank;
 using runtime::StorageScope;
 
-// Find a linear pattern of storage acess
+// Find a linear pattern of storage access
 // Used for liveness analysis.
 // Composite scopes(loop/thread_launch/IfThen) is represented by two points:
 // before_scope -> scope_body -> after_scope
@@ -193,6 +193,10 @@ class LinearAccessPatternFinder final : public IRVisitor {
     VisitNewScope(op);
   }
 
+  void Visit_(const AssertStmt* op) final {
+    VisitNewScope(op);
+  }
+
   // linearized access sequence.
   std::vector<StmtEntry> linear_seq_;
   // The storage scope of each buffer
index ab90631..c88703e 100644 (file)
@@ -35,7 +35,9 @@
 #include <limits>
 #include <mutex>
 #include <functional>
+#include <vector>
 #include <unordered_map>
+#include "../ir/type_functor.h"
 #include "compile_engine.h"
 
 namespace tvm {
@@ -48,6 +50,43 @@ CCacheKey CCacheKeyNode::make(Function source_func, Target target) {
   return CCacheKey(n);
 }
 
+struct IsDynamicVisitor : public TypeVisitor {
+  bool is_dyn{false};
+  void VisitType_(const TensorTypeNode* tt) {
+    for (auto dim : tt->shape) {
+      if (dim.as<Any>()) {
+        is_dyn = true;
+        break;
+      }
+    }
+  }
+};
+
+bool IsDynamic(const Type& ty) {
+  IsDynamicVisitor v;
+  v.VisitType(ty);
+  return v.is_dyn;
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+  // for now, we always use int32 shape when possible
+  // even if the result of shape inference becomes int64.
+  Array<IndexExpr> res;
+  for (IndexExpr val : shape) {
+    const int64_t* pval = as_const_int(val);
+    if (pval != nullptr) {
+      CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+      CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+      res.push_back(ir::IntImm::make(Int(32), *pval));
+    } else if (val->is_type<ir::Any>()) {
+      res.push_back(val.as<ir::Any>()->ToVar());
+    } else {
+      res.push_back(val);
+    }
+  }
+  return res;
+}
+
 // The getter to get schedule from compile engine.
 // Get schedule from functor.
 class ScheduleGetter :
@@ -56,23 +95,6 @@ class ScheduleGetter :
   explicit ScheduleGetter(Target target)
       : target_(target) {}
 
-  Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
-    // for now, we always use int32 shape when possible
-    // even if the result of shape inference becomes int64.
-    Array<IndexExpr> res;
-    for (IndexExpr val : shape) {
-      const int64_t* pval = as_const_int(val);
-      if (pval != nullptr) {
-        CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
-        CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
-        res.push_back(ir::IntImm::make(Int(32), *pval));
-      } else {
-        res.push_back(val);
-      }
-    }
-    return res;
-  }
-
   std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
     static auto fschedule =
         Op::GetAttr<FTVMSchedule>("FTVMSchedule");
@@ -90,6 +112,7 @@ class ScheduleGetter :
         const auto* tuple_type = param->type_as<TupleTypeNode>();
         for (Type field : tuple_type->fields) {
           const auto* ttype = field.as<TensorTypeNode>();
+          // TODO(@icemelon): Allow recursive tuple
           CHECK(ttype != nullptr);
           tvm::Tensor tensor = tvm::placeholder(
               GetShape(ttype->shape), ttype->dtype);
@@ -283,6 +306,255 @@ class ScheduleGetter :
   Array<Operation> scalars_;
 };
 
+// Creates shape function from functor.
+class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
+ public:
+  MakeShapeFunc() {}
+
+  std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
+    for (auto param : prim_func->params) {
+      param_states_[param] = kNoNeed;
+      Array<tvm::Tensor> data_inputs;
+      Array<tvm::Tensor> shape_inputs;
+
+      auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) {
+        // Add data placeholder
+        Shape shape = GetShape(ttype->shape);
+        tvm::Tensor data_tensor = tvm::placeholder(shape, ttype->dtype);
+        data_inputs.push_back(data_tensor);
+        // Add shape placeholder
+        int64_t ndim = shape.size();
+        Shape sshape;
+        if (ndim > 0) {
+          sshape.push_back(tvm::Integer(ndim));
+        }
+        tvm::Tensor shape_tensor = tvm::placeholder(sshape, Int(64));
+        shape_inputs.push_back(shape_tensor);
+      };
+
+      if (const auto *ttype = param->checked_type().as<TensorTypeNode>()) {
+        add_placeholder(ttype);
+      } else {
+        // flatten tuple of tensor type.
+        const auto *tuple_type = param->type_as<TupleTypeNode>();
+        // TODO(@icemelon): Support recursive tuple
+        CHECK(tuple_type);
+        for (Type field : tuple_type->fields) {
+          const auto *ttype = field.as<TensorTypeNode>();
+          CHECK(ttype);
+          add_placeholder(ttype);
+        }
+      }
+      param_data_[param] = data_inputs;
+      param_shapes_[param] = shape_inputs;
+    }
+    readable_name_stream_ << "shape_func";
+    auto cache_node = make_node<CachedFuncNode>();
+    cache_node->outputs = VisitExpr(prim_func->body);
+    auto candidate_name = readable_name_stream_.str();
+    constexpr static size_t kMaxFuncNameLength = 80;
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name <<  candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+    cache_node->func_name = candidate_name;
+
+    // set inputs
+    for (auto param : prim_func->params) {
+      int state = param_states_[param];
+      cache_node->shape_func_param_states.push_back(IntImm::make(Int(32), state));
+      if (state & kNeedInputData) {
+        for (auto t : param_data_[param]) {
+          cache_node->inputs.push_back(t);
+        }
+      }
+      if (state & kNeedInputShape) {
+        for (auto t : param_shapes_[param]) {
+          cache_node->inputs.push_back(t);
+        }
+      }
+    }
+
+    CachedFunc cfunc(cache_node);
+    // generate schedule for shape func
+    Array<Operation> out_ops;
+    for (auto t : cache_node->outputs) {
+      out_ops.push_back(t->op);
+    }
+    auto schedule = create_schedule(out_ops);
+    tvm::schedule::AutoInlineInjective(schedule);
+    for (const auto& scalar : scalars_) {
+      auto scalar_op = scalar->op;
+      if (schedule->Contain(scalar_op)) {
+        schedule[scalar_op].compute_inline();
+      }
+    }
+    return std::make_pair(schedule, cfunc);
+  }
+
+  Array<Tensor> VisitExpr(const Expr& expr) {
+    auto it = memo_.find(expr);
+    if (it != memo_.end()) {
+      return it->second;
+    } else {
+      Array<Tensor> res = ExprFunctor::VisitExpr(expr);
+      if (expr.as<VarNode>() == nullptr) {
+        // Do not memoize vars because shape functions could use either the data
+        // or the shape of a var each time.
+        memo_[expr] = res;
+      }
+      return res;
+    }
+  }
+
+  Array<Tensor> VisitExpr_(const VarNode* var_node) final {
+    auto var = GetRef<Var>(var_node);
+    auto it = param_states_.find(var);
+    if (it == param_states_.end()) {
+      LOG(FATAL) << "Free variable " << var->name_hint();
+      return {};
+    } else {
+      CHECK(data_dependants_.size());
+      bool data_dependant = data_dependants_.back();
+      if (data_dependant) {
+        param_states_[var] |= kNeedInputData;
+        return param_data_[var];
+      } else {
+        param_states_[var] |= kNeedInputShape;
+        return param_shapes_[var];
+      }
+    }
+  }
+
+  Array<Tensor> VisitExpr_(const ConstantNode* op) final {
+    CHECK(data_dependants_.size());
+    CHECK(op->is_scalar());
+    bool data_dependant = data_dependants_.back();
+    if (data_dependant) {
+      void* data = op->data->data;
+      DataType dtype = TVMType2Type(op->data->dtype);
+      Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
+          if (dtype == Int(32)) {
+            return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+          } else if (dtype == Int(64)) {
+            return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+          } else if (dtype == Float(32)) {
+            return make_const(dtype, static_cast<const float*>(data)[0]);
+          } else if (dtype == Float(64)) {
+            return make_const(dtype, static_cast<const double*>(data)[0]);
+          } else if (dtype == Bool()) {
+            return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+          } else {
+            LOG(FATAL) << "not handled";
+            return tvm::Expr();
+          }
+      }, "data_const", topi::kBroadcast);
+      scalars_.push_back(value);
+      return {value};
+    } else {
+      Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
+          return make_const(Int(64), 0);
+      }, "shape_const", topi::kBroadcast);
+      scalars_.push_back(value);
+      return {value};
+    }
+  }
+
+  Array<Tensor> VisitExpr_(const CallNode* call_node) final {
+    static auto fshape_func = Op::GetAttr<FShapeFunc>("FShapeFunc");
+    static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>(
+        "TShapeDataDependant");
+    CHECK(call_node->op.as<OpNode>())
+      << "Primitive function only allows call into primitive ops";
+    Op op = Downcast<Op>(call_node->op);
+    CHECK(data_dependants_.empty() || !data_dependants_.back())
+      << "Error in op fusion: output of the shape func is fed to a "
+      << "data-dependant shape func";
+    CHECK_GT(fshape_func.count(op), 0)
+      << "Internal error, cannot find ShapeFunc for " << op->name;
+    CHECK_GT(tshape_data_dependant.count(op), 0)
+      << "Internal error, cannot find TShapeDataDependant for " << op->name;
+
+    data_dependants_.push_back(tshape_data_dependant[op]);
+    // Visit all inputs
+    Array<Tensor> inputs;
+    int count_tuple = 0;
+    for (Expr arg : call_node->args) {
+      if (arg->checked_type().as<TupleTypeNode>()) {
+        ++count_tuple;
+      }
+      for (Tensor tensor : VisitExpr(arg)) {
+        inputs.push_back(tensor);
+      }
+    }
+    if (count_tuple) {
+      CHECK_EQ(call_node->args.size(), 1U)
+        << "Only allow function with a single tuple input";
+    }
+    // Get output ndims
+    auto ret_type = call_node->checked_type();
+    Array<IndexExpr> out_ndims;
+    if (const auto* ttype = ret_type.as<TensorTypeNode>()) {
+      out_ndims.push_back(IntImm::make(Int(32), ttype->shape.size()));
+    } else {
+      auto rtype = ret_type.as<TupleTypeNode>();
+      // TODO(@icemelon): Allow recursive tuple
+      CHECK(rtype);
+      for (size_t i = 0; i < rtype->fields.size(); ++i) {
+        auto ttype = rtype->fields[i].as<TensorTypeNode>();
+        CHECK(ttype);
+        out_ndims.push_back(IntImm::make(Int(32), ttype->shape.size()));
+      }
+    }
+    // Call shape function
+    auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims);
+    data_dependants_.pop_back();
+    readable_name_stream_ << "_" << op->name;
+    return outputs;
+  }
+
+  Array<Tensor> VisitExpr_(const FunctionNode* op) final {
+    LOG(FATAL) << "Do not support sub function";
+    return Array<Tensor>();
+  }
+
+  Array<Tensor> VisitExpr_(const LetNode* op) final {
+    Array<Tensor> val = VisitExpr(op->value);
+    CHECK(!memo_.count(op->var));
+    memo_[op->var] = val;
+    return VisitExpr(op->body);
+  }
+
+  Array<Tensor> VisitExpr_(const TupleNode* op) final {
+    Array<Tensor> fields;
+    for (Expr field : op->fields) {
+      CHECK(field->checked_type().as<TensorTypeNode>())
+        << "Only allow Tuple of Tensor";
+      Array<Tensor> res = VisitExpr(field);
+      CHECK_EQ(res.size(), 1);
+      fields.push_back(res[0]);
+    }
+    return fields;
+  }
+
+ private:
+  /*! \brief String stream for function name */
+  std::ostringstream readable_name_stream_;
+  /*! \brief Map from parameter to its shape function usage state */
+  std::unordered_map<Expr, int, NodeHash, NodeEqual> param_states_;
+  /*! \brief Map from parameter to list of data placeholder */
+  std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> param_data_;
+  /*! \brief Map from parameter to list of shape placeholder */
+  std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> param_shapes_;
+  /*! \brief Memoized visit result */
+  std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
+  /*! \brief Stack of data dependencies for shape function */
+  std::vector<bool> data_dependants_;
+  /*! \brief Scalars used in the shape function */
+  Array<Tensor> scalars_;
+};
 
 class CompileEngineImpl : public CompileEngineNode {
  public:
@@ -304,6 +576,11 @@ class CompileEngineImpl : public CompileEngineNode {
     }
     return value->packed_func;
   }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
   void Clear() final {
     cache_.clear();
   }
@@ -379,6 +656,40 @@ class CompileEngineImpl : public CompileEngineNode {
     value->cached_func = CachedFunc(cache_node);
     return value;
   }
+  // implement lowered shape func
+  CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = shape_func_cache_.find(key);
+    if (it != shape_func_cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_node<CCacheValueNode>());
+      value->use_count = 0;
+      shape_func_cache_[key] = value;
+    }
+    // Enforce use the target.
+    With<Target> target_scope(key->target);
+
+    CHECK(!value->cached_func.defined());
+    auto spair = MakeShapeFunc().Create(key->source_func);
+    auto cache_node = make_node<CachedFuncNode>(
+            *(spair.second.operator->()));
+    cache_node->func_name = GetUniqueName(cache_node->func_name);
+    cache_node->target = key->target;
+
+    Array<Tensor> all_args = cache_node->inputs;
+    for (Tensor arg : cache_node->outputs) {
+      all_args.push_back(arg);
+    }
+    tvm::BuildConfig bcfg = BuildConfig::Create();
+    std::unordered_map<Tensor, Buffer> binds;
+    cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
+    value->cached_func = CachedFunc(cache_node);
+    return value;
+  }
   /*!
    * \brief Get unique name from name.
    * \param name The orginal name.
@@ -408,6 +719,8 @@ class CompileEngineImpl : public CompileEngineNode {
   std::unordered_map<std::string, int> name_map_;
   /*! \brief internal compiler cache */
   std::unordered_map<CCacheKey, CCacheValue> cache_;
+  /*! \brief internal compiler cache for shape funcs */
+  std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
 };
 
 /*! \brief The global compile engine */
index 9765cf9..84197db 100644 (file)
 namespace tvm {
 namespace relay {
 
+/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */
+enum ShapeFuncParamState {
+  kNoNeed = 0,
+  kNeedInputData = 1,
+  kNeedInputShape = 2,
+  kNeedBoth = 3,
+};
+
 /*! \brief Node container to represent a cached function. */
 struct CachedFuncNode : public Node {
   /* \brief compiled target */
@@ -48,6 +56,8 @@ struct CachedFuncNode : public Node {
   tvm::Array<Tensor> outputs;
   /*! \brief The lowered functions to support the function. */
   tvm::Array<tvm::LoweredFunc> funcs;
+  /*! \brief Parameter usage states in the shape function. */
+  tvm::Array<Integer> shape_func_param_states;
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("target", &target);
@@ -55,6 +65,7 @@ struct CachedFuncNode : public Node {
     v->Visit("inputs", &inputs);
     v->Visit("outputs", &outputs);
     v->Visit("funcs", &funcs);
+    v->Visit("shape_func_param_states", &shape_func_param_states);
   }
 
   static constexpr const char* _type_key = "relay.CachedFunc";
@@ -170,6 +181,12 @@ class CompileEngineNode : public Node {
    * \return The result.
    */
   virtual PackedFunc JIT(const CCacheKey& key) = 0;
+  /*!
+   * \brief Lower the shape function.
+   * \param key The key to the cached function.
+   * \return The result.
+   */
+  virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0;
   /*! \brief clear the cache. */
   virtual void Clear() = 0;
 
@@ -180,7 +197,7 @@ class CompileEngineNode : public Node {
   TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node);
 };
 
-/*! \brier cache entry used in compile engine */
+/*! \brief cache entry used in compile engine */
 class CompileEngine : public NodeRef {
  public:
   CompileEngine() {}
@@ -193,6 +210,13 @@ class CompileEngine : public NodeRef {
   TVM_DLL static const CompileEngine& Global();
 };
 
+/*!
+ * \brief Check if the type is dynamic.
+ * \param ty The type to be checked.
+ * \return The result.
+ */
+bool IsDynamic(const Type& ty);
+
 // implementations
 inline size_t CCacheKeyNode::Hash() const {
   if (hash_ != 0) return hash_;
index 913d7ad..dedff7a 100644 (file)
@@ -280,7 +280,7 @@ class Interpreter :
     return TupleValueNode::make(values);
   }
 
-  // TODO(@jroesch): this doesn't support mututal letrec
+  // TODO(@jroesch): this doesn't support mutual letrec
   inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
     tvm::Map<Var, Value> captured_mod;
     Array<Var> free_vars = FreeVars(func);
@@ -310,7 +310,125 @@ class Interpreter :
     return MakeClosure(func);
   }
 
-  Value InvokePrimitiveOp(Function func,
+  Array<Shape> ComputeDynamicShape(const Function& func,
+                                   const Array<Value>& args) {
+    auto key = CCacheKeyNode::make(func, Target::Create("llvm"));
+    auto cfunc = engine_->LowerShapeFunc(key);
+    size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
+
+    std::vector<TVMValue> values(arity);
+    std::vector<int> codes(arity);
+    TVMArgsSetter setter(values.data(), codes.data());
+    std::vector<NDArray> inputs(cfunc->inputs.size());
+    std::vector<NDArray> outputs(cfunc->outputs.size());
+
+    DLContext cpu_ctx;
+    cpu_ctx.device_type = kDLCPU;
+    cpu_ctx.device_id = 0;
+
+    auto fset_input = [&](size_t i, Value val, bool need_shape) {
+        const TensorValueNode* tv = val.as<TensorValueNode>();
+        CHECK(tv != nullptr) << "expect Tensor argument";
+        if (need_shape) {
+          int64_t ndim = tv->data.Shape().size();
+          NDArray shape_arr;
+          if (ndim == 0) {
+            shape_arr = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx);
+          } else {
+            shape_arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
+            int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
+            for (auto j = 0; j < ndim; ++j) {
+              data[j] = tv->data.Shape()[j];
+            }
+          }
+          inputs[i] = shape_arr;
+          setter(i, shape_arr);
+        } else {
+          auto arr = tv->data.CopyTo(cpu_ctx);
+          inputs[i] = arr;
+          setter(i, arr);
+        }
+    };
+
+    size_t arg_counter = 0;
+    for (size_t i = 0; i < args.size(); ++i) {
+      auto arg = args[i];
+      auto param = func->params[i];
+      int state = cfunc->shape_func_param_states[i]->value;
+      if (arg.as<TensorValueNode>()) {
+        if (state & kNeedInputData) {
+          fset_input(arg_counter++, arg, false);
+        }
+        if (state & kNeedInputShape) {
+          fset_input(arg_counter++, arg, true);
+        }
+      } else {
+        const TupleValueNode* tuple = arg.as<TupleValueNode>();
+        CHECK(tuple != nullptr);
+        if (state & kNeedInputData) {
+          for (size_t i = 0; i < tuple->fields.size(); ++i) {
+            fset_input(arg_counter++, tuple->fields[i], false);
+          }
+        }
+        if (state & kNeedInputShape) {
+          for (size_t i = 0; i < tuple->fields.size(); ++i) {
+            fset_input(arg_counter++, tuple->fields[i], true);
+          }
+        }
+      }
+    }
+    CHECK_EQ(arg_counter, cfunc->inputs.size())
+      << "Shape function input sizes mismatch";
+
+    auto fset_shape_output = [&](size_t i, Type val_type) {
+        // TODO(@icemelon): allow recursive tuple
+        const TensorTypeNode* rtype = val_type.as<TensorTypeNode>();
+        CHECK(rtype != nullptr);
+        int64_t ndim = rtype->shape.size();
+        auto arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
+        outputs[i] = arr;
+        setter(arg_counter + i, arr);
+    };
+
+    auto ret_type = func->body->checked_type();
+    size_t out_cnt = 0;
+    if (auto rtype = ret_type.as<TupleTypeNode>()) {
+      out_cnt = rtype->fields.size();
+      for (size_t i = 0; i < out_cnt; ++i) {
+        fset_shape_output(i, rtype->fields[i]);
+      }
+    } else {
+      out_cnt = 1;
+      auto tt = Downcast<TensorType>(ret_type);
+      fset_shape_output(0, tt);
+    }
+    CHECK_EQ(cfunc->outputs.size(), out_cnt)
+      << "Shape function output sizes mismatch";
+
+    PackedFunc shape_func;
+    TVMRetValue rv;
+    if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
+      tvm::runtime::Module m = (*f)(cfunc->funcs, cfunc->target);
+      shape_func = m.GetFunction(cfunc->func_name);
+    } else {
+      LOG(FATAL) << "relay.backend.build is not registered";
+    }
+    shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
+
+    // Get output shapes
+    Array<Shape> out_shapes;
+    for (auto out_tensor : outputs) {
+      int64_t* shape_data = reinterpret_cast<int64_t*>(out_tensor->data);
+      Shape out_shape;
+      for (int i = 0; i < out_tensor->shape[0]; ++i) {
+        out_shape.push_back(tvm::Integer(shape_data[i]));
+      }
+      out_shapes.push_back(out_shape);
+    }
+    return out_shapes;
+  }
+
+  Value InvokePrimitiveOp(const Function& func,
                           const Array<Value>& args) {
     auto call_node = func->body.as<CallNode>();
 
@@ -394,17 +512,46 @@ class Interpreter :
       return out_tensor;
     };
 
+    Array<Shape> out_shapes;
+    auto ret_type = func->body->checked_type();
+    bool is_dyn = IsDynamic(func->checked_type());
+    if (call_node->op == Op::Get("shape_of")) {
+      // The output shape of shape_of must be static since Relay doesn't support
+      // dynamic rank tensors.
+      is_dyn = false;
+    }
+
+    if (is_dyn) {
+      CHECK(func->IsPrimitive());
+      out_shapes = ComputeDynamicShape(func, args);
+    }
+
     PackedFunc packed_func = engine_->JIT(CCacheKeyNode::make(func, target_));
     TVMRetValue rv;
     if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
+      CHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
       Array<Value> fields;
       for (size_t i = 0; i < rtype->fields.size(); ++i) {
-        fields.push_back(fset_output(i, rtype->fields[i]));
+        if (is_dyn) {
+          auto sh = out_shapes[i];
+          auto tt = Downcast<TensorType>(rtype->fields[i]);
+          fields.push_back(fset_output(i, TensorTypeNode::make(sh, tt->dtype)));
+        } else {
+          fields.push_back(fset_output(i, rtype->fields[i]));
+        }
       }
       packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
       return TupleValueNode::make(fields);
     } else {
-      Value out_tensor = fset_output(0, func->body->checked_type());
+      Value out_tensor;
+      if (is_dyn) {
+        CHECK_EQ(out_shapes.size(), 1);
+        auto sh = out_shapes[0];
+        auto tt = Downcast<TensorType>(ret_type);
+        out_tensor = fset_output(0, TensorTypeNode::make(sh, tt->dtype));
+      } else {
+        out_tensor = fset_output(0, ret_type);
+      }
       packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
       return out_tensor;
     }
index 17de083..5e5bc1a 100644 (file)
  * \brief A compiler from relay::Module to the VM byte code.
  */
 
+#include <tvm/operation.h>
 #include <tvm/relay/error.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/interpreter.h>
 #include <tvm/logging.h>
 #include <tvm/relay/transform.h>
 #include <tvm/runtime/vm.h>
+#include <topi/tags.h>
+#include <algorithm>
 #include <iostream>
 #include <memory>
 #include <set>
@@ -61,42 +64,6 @@ using namespace relay::transform;
 // (@jroesch): VM passes, eventually declare as passes.
 bool IsClosure(const Function& func);
 
-// Compute the constant pool, i.e a mapping from Constant node to constant index.
-struct ConstantPool : ExprVisitor {
-  std::set<GlobalVar> visited;
-  Module module;
-  ConstMap const_map;
-  ConstTensorShapeMap const_tensor_shape_map;
-
-  size_t index;
-
-  explicit ConstantPool(const Module& mod) : module(mod), const_map(), index(0) {}
-
-  void VisitExpr_(const GlobalVarNode* var_node) {
-    auto gvar = GetRef<GlobalVar>(var_node);
-    if (visited.find(gvar) == visited.end()) {
-      visited.insert(gvar);
-      this->VisitExpr(this->module->Lookup(gvar));
-    }
-  }
-
-  void VisitExpr_(const ConstantNode* const_node) {
-    auto konst = GetRef<Constant>(const_node);
-    auto it = this->const_map.find(konst);
-    if (it == this->const_map.end()) {
-      this->const_map.insert({konst, index++});
-    }
-  }
-};
-
-std::tuple<ConstMap, ConstTensorShapeMap> LayoutConstantPool(const Module& module) {
-  auto cp = ConstantPool(module);
-  for (auto& func : module->functions) {
-    cp.VisitExpr(func.first);
-  }
-  return std::make_tuple(cp.const_map, cp.const_tensor_shape_map);
-}
-
 void InstructionPrint(std::ostream& os, const Instruction& instr);
 
 // Represent a runtime object that's going to be matched by pattern match expressions
@@ -220,12 +187,13 @@ TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause>
 
 class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
  public:
-  VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets)
+  VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
       : last_register_(0),
         registers_num_(0),
         engine_(CompileEngine::Global()),
         context_(context),
-        targets_(targets) {}
+        targets_(targets),
+        target_host_(target_host) {}
 
   VMFunction Compile(const GlobalVar& var, const Function& func) {
     size_t i = 0;
@@ -288,10 +256,9 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
   }
 
   void VisitExpr_(const ConstantNode* const_node) {
-    auto rconst = GetRef<Constant>(const_node);
-    auto it = this->context_->const_map.find(rconst);
-    CHECK(it != this->context_->const_map.end());
-    Emit(Instruction::LoadConst(it->second, NewRegister()));
+    size_t konst_idx = context_->constants.size();
+    context_->constants.push_back(const_node->data);
+    Emit(Instruction::LoadConst(konst_idx, NewRegister()));
   }
 
   void VisitExpr_(const VarNode* var_node) {
@@ -326,7 +293,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
   }
 
   void VisitExpr_(const LetNode* let_node) {
-    DLOG(INFO) << let_node->value;
+    DLOG(INFO) << AsText(let_node->value);
     this->VisitExpr(let_node->value);
     var_register_map_.insert({let_node->var, this->last_register_});
     this->VisitExpr(let_node->body);
@@ -393,29 +360,206 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
     this->last_register_ = true_register;
   }
 
-  Instruction AllocTensorFromType(const TensorTypeNode* ttype) {
-    TVMType dltype = Type2TVMType(ttype->dtype);
-    auto tensor_type = GetRef<TensorType>(ttype);
+  Index EmitGetShape(const TensorTypeNode* ttype, Index reg) {
+    bool const_shape = true;
     std::vector<int64_t> shape;
-    for (auto dim : tensor_type->shape) {
-      shape.push_back(Downcast<tvm::Integer>(dim)->value);
+    for (auto dim : ttype->shape) {
+      if (auto kdim = dim.as<IntImm>()) {
+        shape.push_back(kdim->value);
+      } else {
+        const_shape = false;
+      }
+    }
+    if (const_shape) {
+      int64_t ndim = shape.size();
+      DLContext cpu_ctx;
+      cpu_ctx.device_type = kDLCPU;
+      cpu_ctx.device_id = 0;
+      NDArray shape_tensor;
+      if (ndim == 0) {
+        shape_tensor = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx);
+      } else {
+        shape_tensor = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
+        int64_t* dims = reinterpret_cast<int64_t*>(shape_tensor->data);
+        for (size_t i = 0; i < shape.size(); ++i) {
+          dims[i] = shape[i];
+        }
+      }
+      size_t konst_idx = context_->constants.size();
+      context_->constants.push_back(shape_tensor);
+      Emit(Instruction::LoadConst(konst_idx, NewRegister()));
+      return last_register_;
+    }
+    // For dynamic shape, we need insert shape_of op to get its shape at runtime
+    auto attrs = make_node<ShapeOfAttrs>();
+    attrs->dtype = Int(64);
+    static const Op& op = Op::Get("shape_of");
+    auto input = VarNode::make("input", GetRef<Type>(ttype));
+    auto expr = CallNode::make(op, {input}, Attrs(attrs), {});
+    auto func = FunctionNode::make({input}, expr, IncompleteTypeNode::make(Kind::kType), {});
+    auto mod = ModuleNode::make({}, {});
+    auto main_gv = GlobalVarNode::make("main");
+    mod->Add(main_gv, func);
+    func = mod->Lookup(main_gv);
+
+    // shape_of op has to be run on the host target
+    // TODO(@icemelon9): handle heterogeneous target, such as cuda
+    auto key = CCacheKeyNode::make(func, target_host_);
+    auto cfunc = engine_->Lower(key);
+    auto op_index = -1;
+    if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
+      op_index = context_->cached_funcs.size();
+      context_->cached_funcs.push_back(cfunc);
+      context_->seen_funcs[cfunc->funcs[0]] = op_index;
+    } else {
+      op_index = context_->seen_funcs[cfunc->funcs[0]];
+    }
+    std::vector<Index> arg_regs{reg};
+    int64_t ndim = ttype->shape.size();
+    if (ndim == 0) {
+      Emit(Instruction::AllocTensor({}, Int(64), NewRegister()));
+    } else {
+      Emit(Instruction::AllocTensor({ndim}, Int(64), NewRegister()));
+    }
+    Index shape_reg = last_register_;
+    arg_regs.push_back(shape_reg);
+    Emit(Instruction::InvokePacked(op_index, 2, 1, arg_regs));
+    return shape_reg;
+  }
+
+  std::vector<Index> EmitShapeFunc(const Type& ret_type, const Function& func,
+                                   const std::vector<Index>& unpacked_arg_regs) {
+    // Find the mapping from params to registers
+    int idx = 0;
+    std::vector<std::vector<Index>> param_regs;
+    std::vector<std::vector<const TensorTypeNode*>> param_types;
+    for (auto param : func->params) {
+      auto ty = param->checked_type();
+      std::vector<Index> regs;
+      std::vector<const TensorTypeNode*> types;
+      if (auto ttype = ty.as<TensorTypeNode>()) {
+        regs.push_back(unpacked_arg_regs[idx++]);
+        types.push_back(ttype);
+      } else if (const auto tuple_ty = ret_type.as<TupleTypeNode>()) {
+        for (size_t j = 0; j < tuple_ty->fields.size(); ++j, ++idx) {
+          regs.push_back(unpacked_arg_regs[idx]);
+          auto ttype = tuple_ty->fields[j].as<TensorTypeNode>();
+          CHECK(ttype);
+          types.push_back(ttype);
+        }
+      } else {
+        LOG(FATAL) << "unsupported parameter type " << ty;
+      }
+      param_regs.push_back(regs);
+      param_types.push_back(types);
+    }
+
+    // Lower shape function
+    auto key = CCacheKeyNode::make(func, target_host_);
+    auto cfunc = engine_->LowerShapeFunc(key);
+    int op_index = -1;
+    if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) {
+      op_index = context_->cached_funcs.size();
+      context_->cached_funcs.push_back(cfunc);
+      context_->seen_funcs[cfunc->funcs[0]] = op_index;
+    } else {
+      op_index = context_->seen_funcs[cfunc->funcs[0]];
+    }
+
+    // Prepare input and output registers
+    std::vector<Index> shape_func_args;
+    std::vector<Index> shape_regs;
+    for (size_t i = 0; i < func->params.size(); ++i) {
+      int state = cfunc->shape_func_param_states[i]->value;
+      if (state & kNeedInputData) {
+        for (auto reg : param_regs[i]) {
+          // TODO(@icemelon9): Need to copy data here for heterogeneous exec
+          shape_func_args.push_back(reg);
+        }
+      }
+      if (state & kNeedInputShape) {
+        for (size_t j = 0; j < param_regs[i].size(); ++j) {
+          shape_func_args.push_back(EmitGetShape(param_types[i][j], param_regs[i][j]));
+        }
+      }
+    }
+    for (auto t : cfunc->outputs) {
+      int64_t ndim = t->shape[0].as<IntImm>()->value;
+      Emit(Instruction::AllocTensor({ndim}, t->dtype, NewRegister()));
+      shape_func_args.push_back(last_register_);
+      shape_regs.push_back(last_register_);
+    }
+
+    int arity = shape_func_args.size();
+    int ret_count = shape_regs.size();
+    Emit(Instruction::InvokePacked(op_index, arity, ret_count, shape_func_args));
+
+    // Alloc return tensors given the shape regs
+    std::vector<DataType> ret_dtypes;
+    if (const auto* tuple_type = ret_type.as<TupleTypeNode>()) {
+      for (auto field : tuple_type->fields) {
+        const TensorTypeNode* tty = field.as<TensorTypeNode>();
+        CHECK(tty);
+        ret_dtypes.push_back(tty->dtype);
+      }
+    } else {
+      auto tty = ret_type.as<TensorTypeNode>();
+      CHECK(tty);
+      ret_dtypes.push_back(tty->dtype);
+    }
+    std::vector<Index> ret_regs;
+    for (size_t i = 0; i < shape_regs.size(); ++i) {
+      Emit(Instruction::AllocTensorReg(shape_regs[i], ret_dtypes[i], NewRegister()));
+      ret_regs.push_back(last_register_);
+    }
+    return ret_regs;
+  }
+
+  std::vector<Index> AllocReturnType(const Type& ret_type, const Function& func,
+                                     const std::vector<Index>& unpacked_arg_regs) {
+    auto op = func->body.as<CallNode>()->op;
+    // 1. If either func param types or ret type is dynamic, we need to insert
+    // shape func to perform type checking at runtime.
+    // 2. We skip the shape_of function since currently Relay doesn't support
+    // dynamic rank tensor.
+    if (op != Op::Get("shape_of") && IsDynamic(func->checked_type())) {
+      return EmitShapeFunc(ret_type, func, unpacked_arg_regs);
     }
-    return Instruction::AllocTensor(shape, dltype, NewRegister());
+    std::vector<Index> ret_regs;
+    auto alloc_tensor = [&](const TensorTypeNode* ttype) {
+      const TensorType& tensor_type = GetRef<TensorType>(ttype);
+      std::vector<int64_t> shape;
+      for (auto dim : tensor_type->shape) {
+        shape.push_back(Downcast<tvm::Integer>(dim)->value);
+      }
+      Emit(Instruction::AllocTensor(shape, Type2TVMType(tensor_type->dtype), NewRegister()));
+      ret_regs.push_back(last_register_);
+    };
+    if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
+      alloc_tensor(ttype);
+    } else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
+      for (auto field : ttype->fields) {
+        alloc_tensor(field.as<TensorTypeNode>());
+      }
+    } else {
+      LOG(FATAL) << "Unsupported return value type";
+    }
+    return ret_regs;
   }
 
   void EmitInvokePrimitive(const Function& func,
-                           const std::vector<Index>& args_registers,
+                           const std::vector<Index>& arg_registers,
                            const Type& ret_type) {
     std::vector<Index> unpacked_arg_regs;
     std::vector<Instruction> allocs;
 
     // Arity calculation must flatten tuples.
     size_t arity = 0;
-    CHECK_EQ(func->params.size(), args_registers.size());
+    CHECK_EQ(func->params.size(), arg_registers.size());
     for (size_t i = 0; i < func->params.size(); i++) {
       auto ty = func->params[i]->checked_type();
       if (ty.as<TensorTypeNode>()) {
-        unpacked_arg_regs.push_back(args_registers[i]);
+        unpacked_arg_regs.push_back(arg_registers[i]);
         arity += 1;
       } else if (auto tuple_ty = ty.as<TupleTypeNode>()) {
         for (size_t f = 0; f < tuple_ty->fields.size(); f++) {
@@ -424,7 +568,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
             << "only supports non-nested tuples currently "
             << "found " << field;
           auto dst =  NewRegister();
-          Emit(Instruction::GetField(args_registers[i], f, dst));
+          Emit(Instruction::GetField(arg_registers[i], f, dst));
           unpacked_arg_regs.push_back(dst);
         }
         arity += tuple_ty->fields.size();
@@ -433,30 +577,11 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       }
     }
 
-    size_t return_val_count = 0;
-    if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
-      // Allocate space for the return tensor.
-      auto alloc = AllocTensorFromType(ttype);
-      allocs.push_back(alloc);
-      return_val_count = 1;
-    } else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
-      std::vector<Index> fields_registers;
-
-      for (size_t i = 0; i < ttype->fields.size(); ++i) {
-        auto f = ttype->fields[i];
-        auto f_type = f.as<TensorTypeNode>();
-        allocs.push_back(AllocTensorFromType(f_type));
-        fields_registers.push_back(allocs.back().dst);
-      }
-      return_val_count = ttype->fields.size();
-    } else {
-      LOG(FATAL) << "Unsupported return value type";
-    }
-
-    arity += return_val_count;
-    for (auto& alloc : allocs) {
-      Emit(alloc);
-      unpacked_arg_regs.push_back(alloc.dst);
+    auto ret_regs = AllocReturnType(ret_type, func, unpacked_arg_regs);
+    size_t return_count = ret_regs.size();
+    arity += return_count;
+    for (auto reg : ret_regs) {
+      unpacked_arg_regs.push_back(reg);
     }
 
     // Next generate the invoke instruction.
@@ -477,22 +602,22 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
     CHECK_EQ(cfunc->funcs.size(), 1);
     auto op_index = -1;
     if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
-      op_index = context_->lowered_funcs.size();
-      context_->lowered_funcs.push_back(cfunc->funcs[0]);
+      op_index = context_->cached_funcs.size();
+      context_->cached_funcs.push_back(cfunc);
       context_->seen_funcs[cfunc->funcs[0]] = op_index;
     } else {
       op_index = context_->seen_funcs[cfunc->funcs[0]];
     }
 
-    Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
+    Emit(Instruction::InvokePacked(op_index, arity, return_count, unpacked_arg_regs));
 
-    if (return_val_count > 1) {
+    if (return_count > 1) {
       // return value is a tuple, we need to create a tuple
       std::vector<Index> fields_registers;
-      for (size_t i = arity - return_val_count; i < arity; ++i) {
+      for (size_t i = arity - return_count; i < arity; ++i) {
         fields_registers.push_back(unpacked_arg_regs[i]);
       }
-      Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister()));
+      Emit(Instruction::AllocDatatype(0, return_count, fields_registers, NewRegister()));
     }
   }
 
@@ -636,6 +761,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
   VMCompilerContext* context_;
   /*! \brief Target devices. */
   TargetsMap targets_;
+  /*! \brief Host target. */
+  Target target_host_;
 };
 
 
@@ -676,28 +803,18 @@ void VMCompiler::Compile(const Module& mod_ref,
   // in the VMFunction table.
   PopulateGlobalMap();
 
-  // Next we populate constant map.
-  auto constant_analysis_result = LayoutConstantPool(context_.module);
-  context_.const_map = std::get<0>(constant_analysis_result);
-  context_.const_tensor_shape_map = std::get<1>(constant_analysis_result);
-
   // Next we get ready by allocating space for
   // the global state.
   vm_->functions.resize(context_.module->functions.size());
-  vm_->constants.resize(context_.const_map.size() + context_.const_tensor_shape_map.size());
 
-  for (auto pair : context_.const_map) {
-    vm_->constants[pair.second] = Object::Tensor(pair.first->data);
-  }
-
-  for (auto pair : context_.const_tensor_shape_map) {
-    vm_->constants[pair.second.first] = Object::Tensor(pair.second.second);
-  }
+  // Next we get ready by allocating space for
+  // the global state.
+  vm_->functions.resize(context_.module->functions.size());
 
   for (auto named_func : context_.module->functions) {
     auto gvar = named_func.first;
     auto func = named_func.second;
-    VMFunctionCompiler func_compiler(&context_, targets_);
+    VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
     auto vm_func = func_compiler.Compile(gvar, func);
 
     size_t func_index = context_.global_map.at(gvar);
@@ -711,6 +828,11 @@ void VMCompiler::Compile(const Module& mod_ref,
   }
 #endif  // USE_RELAY_DEBUG
 
+  // populate constants
+  for (auto data : context_.constants) {
+    vm_->constants.push_back(Object::Tensor(data));
+  }
+
   LibraryCodegen();
 
   for (auto gv : context_.global_map) {
@@ -721,11 +843,13 @@ void VMCompiler::Compile(const Module& mod_ref,
 Module VMCompiler::OptimizeModule(const Module& mod) {
   // TODO(@icemelon9): check number of targets and build config, add more optimization pass
   transform::Sequential seq({transform::SimplifyInference(),
-                             transform::ToANormalForm(),
                              transform::InlinePrimitives(),
+                             // TODO(@wweic): FuseOps pass currently don't handle Let
+                             // For now, we put FuseOps before ToANormalForm to enable it
+                             transform::FuseOps(),
+                             transform::ToANormalForm(),
                              transform::LambdaLift(),
-                             transform::InlinePrimitives(),
-                             transform::FuseOps()});
+                             transform::InlinePrimitives()});
   auto pass_ctx = transform::PassContext::Create();
   tvm::With<relay::transform::PassContext> ctx(pass_ctx);
   return seq(mod);
@@ -741,27 +865,36 @@ void VMCompiler::PopulateGlobalMap() {
 }
 
 void VMCompiler::LibraryCodegen() {
-  auto const& lowered_funcs = context_.lowered_funcs;
-  if (lowered_funcs.size() == 0) {
+  auto const &cached_funcs = context_.cached_funcs;
+  if (cached_funcs.size() == 0) {
     return;
   }
-  // TODO(@icemelon9): support heterogeneous targets
-  Target target;
-  for (auto kv : targets_) {
-    target = kv.second;
+  std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs;
+  for (auto &cfunc : cached_funcs) {
+    std::string target_str = cfunc->target->str();
+    if (tgt_funcs.count(target_str) == 0) {
+      tgt_funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
+    } else {
+      tgt_funcs[target_str].push_back(cfunc->funcs[0]);
+    }
   }
-  if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
-    runtime::Module mod =
-        (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target,
-             target_host_);
+  Map<Target, Array<LoweredFunc>> funcs;
+  for (auto &it : tgt_funcs) {
+    funcs.Set(Target::Create(it.first), it.second);
+  }
+
+  if (const auto *f = runtime::Registry::Get("relay.backend.build")) {
+    // The target is just a dummy arg because funcs already contains corresponding target
+    // therefore target won't be used in the build function
+    runtime::Module mod = (*f)(funcs, Target(), target_host_);
     CHECK(mod.operator->());
     vm_->lib = mod;
   } else {
     LOG(FATAL) << "relay.backend.build is not registered";
   }
   size_t primitive_index = 0;
-  for (auto lfunc : lowered_funcs) {
-    vm_->primitive_map.insert({lfunc->name, primitive_index++});
+  for (auto cfunc : cached_funcs) {
+    vm_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
   }
 }
 
index 4a2de0a..bfe19ac 100644 (file)
@@ -72,12 +72,10 @@ struct VMCompilerContext {
   TagMap tag_map;
   // Map from global var to a unique integer
   GlobalMap global_map;
-  // Map from Const object to its index in const pool
-  ConstMap const_map;
-  // Map from Const tensor shape to its index in const pool
-  ConstTensorShapeMap const_tensor_shape_map;
-  // List of lowered functions
-  std::vector<LoweredFunc> lowered_funcs;
+  // List of constants
+  std::vector<NDArray> constants;
+  // List of cached functions
+  std::vector<CachedFunc> cached_funcs;
   // The functions that have been lowered.
   std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
 };
index b4303e7..76b56ae 100644 (file)
@@ -121,14 +121,25 @@ TVM_REGISTER_API("relay.op._ListOpNames")
 TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get);
 
 TVM_REGISTER_API("relay.op._OpGetAttr")
-    .set_body([](TVMArgs args, TVMRetValue* rv) {
-      Op op = args[0];
-      std::string attr_name = args[1];
-      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
-      if (op_map.count(op)) {
-        *rv = op_map[op];
-      }
-    });
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    Op op = args[0];
+    std::string attr_name = args[1];
+    auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+    if (op_map.count(op)) {
+      *rv = op_map[op];
+    }
+  });
+
+TVM_REGISTER_API("relay.op._OpSetAttr")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    Op op = args[0];
+    std::string attr_name = args[1];
+    runtime::TVMArgValue value = args[2];
+    int plevel = args[3];
+    auto& reg =
+        OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();
+    reg.set_attr(attr_name, value, plevel);
+  });
 
 TVM_REGISTER_API("relay.op._Register")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
index c3975c3..6305d22 100644 (file)
@@ -528,7 +528,11 @@ bool ReshapeRel(const Array<Type>& types,
       used_input_dims.insert(src_idx);
       IndexExpr d2 = data_shape[src_idx++];
       used_output_dims.insert(oshape.size());
-      oshape.push_back(d1 * d2);
+      if (d1.as<Any>() || d2.as<Any>()) {
+        oshape.push_back(Any::make());
+      } else {
+        oshape.push_back(d1 * d2);
+      }
     } else if (svalue == -4) {
       // split the source dim s into two dims
       // read the left dim and then the right dim (either can be -1)
@@ -563,6 +567,8 @@ bool ReshapeRel(const Array<Type>& types,
           oshape.push_back(d2);
         }
       }
+    } else {
+      CHECK(false) << "Unsupported special value: " << svalue;
     }
   }
 
@@ -608,7 +614,15 @@ Array<Tensor> ReshapeCompute(const Attrs& attrs,
                              const Target& target) {
   const auto* out_ttype = out_type.as<TensorTypeNode>();
   CHECK(out_ttype != nullptr);
-  return { topi::reshape(inputs[0], out_ttype->shape) };
+  Array<IndexExpr> newshape;
+  for (auto val : out_ttype->shape) {
+    if (val->is_type<ir::Any>()) {
+      newshape.push_back(val.as<ir::Any>()->ToVar());
+    } else {
+      newshape.push_back(val);
+    }
+  }
+  return { topi::reshape(inputs[0], newshape) };
 }
 
 Expr MakeReshape(Expr data,
@@ -1108,7 +1122,8 @@ RELAY_REGISTER_OP("arange")
 .set_support_level(3)
 .add_type_rel("Arange", ArangeRel)
 .set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective)
+// TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
 .set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
 
 // repeat operator
index d3afb91..826fe69 100644 (file)
@@ -295,7 +295,9 @@ RELAY_REGISTER_OP("shape_of")
 .add_argument("data", "Tensor", "The input tensor.")
 .add_type_rel("ShapeOf", ShapeOfRel)
 .set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<TOpPattern>("TOpPattern", kInjective)
+// Use kOpaque for shape_of op for now since it won't be performance critic,
+// and it makes things easier for dynamic shape func
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                ElemwiseArbitraryLayout)
 .set_support_level(10)
index d4efe80..f71b85d 100644 (file)
@@ -81,16 +81,17 @@ Type ConcreteBroadcast(const TensorType& t1,
   for (; i <= std::min(ndim1, ndim2); ++i) {
     IndexExpr s1 = t1->shape[ndim1 - i];
     IndexExpr s2 = t2->shape[ndim2 - i];
-    if (EqualCheck(s1, s2)) {
-      oshape.push_back(s1);
-    } else if (EqualConstInt(s1, 1)) {
+    if (EqualConstInt(s1, 1)) {
       oshape.push_back(s2);
     } else if (EqualConstInt(s2, 1)) {
       oshape.push_back(s1);
-    } else if (s1.as<Any>() && EqualConstInt(s2, 1)) {
-      // TODO(@jroesch): we need to come back to this
+    } else if (s1.as<Any>()) {
+      // s1 == 1 || s1 == s2
       oshape.push_back(s2);
-    } else if (s2.as<Any>() && EqualConstInt(s1, 1)) {
+    } else if (s2.as<Any>()) {
+      // s2 == 1 || s2 == s1
+      oshape.push_back(s1);
+    } else if (EqualCheck(s1, s2)) {
       oshape.push_back(s1);
     } else {
       RELAY_ERROR(
index cdd2837..9dc180f 100644 (file)
@@ -915,7 +915,7 @@ class FuseMutator : private ExprMutator {
         if (it == gmap_.end()) return "";
         std::ostringstream os;
         auto *group = it->second->FindRoot();
-        os << "group=" << group;
+        os << " /* group=" << group << " */";
         return os.str();
       });
     LOG(INFO) << "Dump of group info:\n" << text;
index cef8e72..153b90c 100644 (file)
@@ -120,7 +120,7 @@ class ModulePassNode : public PassNode {
   /*!
    * \brief Get the pass information/meta data.
    */
-  PassInfo Info() const { return pass_info; }
+  PassInfo Info() const override { return pass_info; }
 
   TVM_DLL static ModulePass make(
       runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
@@ -174,7 +174,7 @@ class FunctionPassNode : public PassNode {
   /*!
    * \brief Get the pass information/meta data.
    */
-  PassInfo Info() const { return pass_info; }
+  PassInfo Info() const override { return pass_info; }
 
   TVM_DLL static FunctionPass make(
       runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
@@ -220,7 +220,7 @@ class SequentialNode : public PassNode {
   /*!
    * \brief Get the pass information/meta data.
    */
-  PassInfo Info() const { return pass_info; }
+  PassInfo Info() const override { return pass_info; }
 
   /*!
    * \brief Check if a pass is enabled.
index 33990ae..02ea3a4 100644 (file)
@@ -451,11 +451,11 @@ std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") {
 void InstructionPrint(std::ostream& os, const Instruction& instr) {
   switch (instr.op) {
     case Opcode::Move: {
-      os << "move $" << instr.dst << " $" << instr.from << std::endl;
+      os << "move $" << instr.dst << " $" << instr.from;
       break;
     }
     case Opcode::Ret: {
-      os << "ret $" << instr.result  << std::endl;
+      os << "ret $" << instr.result;
       break;
     }
     case Opcode::Fatal: {
@@ -469,7 +469,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
          << ", out: $"
          << StrJoin<RegName>(instr.packed_args, instr.arity - instr.output_size,
                              instr.output_size, ", $")
-         << ")" << std::endl;
+         << ")";
       break;
     }
     case Opcode::AllocTensor: {
@@ -478,71 +478,61 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
                              instr.alloc_tensor.ndim)
          << "] ";
       DLDatatypePrint(os, instr.alloc_tensor.dtype);
-      os << std::endl;
       break;
     }
     case Opcode::AllocTensorReg: {
       os << "alloc_tensor_reg $" << instr.dst << " $"
          << instr.alloc_tensor_reg.shape_register << " ";
       DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
-      os << std::endl;
       break;
     }
     case Opcode::AllocDatatype: {
       os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$"
-         << StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]"
-         << std::endl;
+         << StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]";
       break;
     }
     case Opcode::AllocClosure: {
       os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index
          << "]($" << StrJoin<RegName>(instr.free_vars, 0, instr.num_freevar, ",$")
-         << ")"
-         << std::endl;
+         << ")";
       break;
     }
     case Opcode::If: {
       os << "if " << "$" << instr.if_op.test << " " << instr.if_op.target << " "
-         << instr.if_op.true_offset << " " << instr.if_op.false_offset
-         << std::endl;
+         << instr.if_op.true_offset << " " << instr.if_op.false_offset;
       break;
     }
     case Opcode::Invoke: {
       os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($"
          << StrJoin<RegName>(instr.invoke_args_registers, 0, instr.num_args, ",$")
-         << ")"
-         << std::endl;
+         << ")";
       break;
     }
     case Opcode::InvokeClosure: {
       os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($"
          << StrJoin<RegName>(instr.closure_args, 0, instr.num_closure_args, ",$")
-         << ")"
-         << std::endl;
+         << ")";
       break;
     }
     case Opcode::LoadConst: {
-      os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]"
-         << std::endl;
+      os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]";
       break;
     }
     case Opcode::LoadConsti: {
-      os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]"
-         << std::endl;
+      os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]";
       break;
     }
     case Opcode::GetField: {
       os << "get_field $" << instr.dst << " $" << instr.object << "["
-         << instr.field_index << "]"
-         << std::endl;
+         << instr.field_index << "]";
       break;
     }
     case Opcode::GetTag: {
-      os << "get_tag $" << instr.dst << " $" << instr.get_tag.object << std::endl;
+      os << "get_tag $" << instr.dst << " $" << instr.get_tag.object;
       break;
     }
     case Opcode::Goto: {
-      os << "goto " << instr.pc_offset << std::endl;
+      os << "goto " << instr.pc_offset;
       break;
     }
     default:
@@ -559,9 +549,7 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instr) {
 void VMFunctionPrint(std::ostream& os, const VMFunction& vm_func) {
   os << vm_func.name << ": " << std::endl;
   for (size_t i = 0; i < vm_func.instructions.size(); ++i) {
-    os << i << ": ";
-    InstructionPrint(os, vm_func.instructions[i]);
-    os << ";" << std::endl;
+    os << i << ": " << vm_func.instructions[i] << ";" << std::endl;
   }
 }
 
@@ -801,7 +789,7 @@ void VirtualMachine::RunLoop() {
   while (true) {
   main_loop:
     auto const& instr = this->code[this->pc];
-    DLOG(INFO) << "Executing(" << pc << "): ";
+    DLOG(INFO) << "Executing(" << pc << "): " << instr;
 #if USE_RELAY_DEBUG
     InstructionPrint(std::cout, instr);
 #endif  // USE_RELAY_DEBUG
index 760ed0f..31f6169 100644 (file)
@@ -546,6 +546,8 @@ void InjectInline(ScheduleNode* sch) {
 
   std::vector<Array<Expr> > new_body(sch->stages.size());
   std::vector<bool> changed(sch->stages.size(), false);
+  std::vector<Stmt> new_hybrid_body(sch->stages.size());
+  std::vector<bool> hybrid_changed(sch->stages.size(), false);
   // inline all the ops
   for (size_t i = sch->stages.size(); i != 0; --i) {
     Stage stage = sch->stages[i - 1];
@@ -568,6 +570,7 @@ void InjectInline(ScheduleNode* sch) {
       for (size_t j = i; j < sch->stages.size(); ++j) {
         Stage s = sch->stages[j];
         const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
+        const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
         if (compute) {
           if (!new_body[j].size()) {
             new_body[j] = compute->body;
@@ -606,6 +609,15 @@ void InjectInline(ScheduleNode* sch) {
               }
             }
           }
+        } else if (hybrid) {
+          if (!new_hybrid_body[j].defined()) {
+            new_hybrid_body[j] = hybrid->body;
+          }
+          Stmt new_stmt = ir::Inline(new_hybrid_body[j], stage->op, args, body);
+          if (!new_stmt.same_as(new_hybrid_body[j])) {
+            new_hybrid_body[j] = new_stmt;
+            hybrid_changed[j] = true;
+          }
         }
       }
     }
@@ -632,6 +644,17 @@ void InjectInline(ScheduleNode* sch) {
         }
         s->op = op;
       }
+    } else if (hybrid_changed[i]) {
+      const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
+      CHECK(hybrid);
+      Operation op = HybridOpNode::make(
+              hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
+              hybrid->outputs, new_hybrid_body[i]);
+      op = op->ReplaceInputs(op, repl);
+      for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
+        repl[s->op.output(idx)] = op.output(idx);
+      }
+      s->op = op;
     } else {
       Operation op = s->op->ReplaceInputs(s->op, repl);
       if (!op.same_as(s->op)) {
index f0a3fe7..214b88f 100644 (file)
@@ -18,27 +18,156 @@ import numpy as np
 
 import tvm
 from tvm import relay
-from tvm.relay import Kind, transform
 from tvm.relay.loops import while_loop
 from tvm.relay.testing import run_infer_type as infer_type
 
 def int32(val):
     return relay.const(val, 'int32')
 
+def any_dims(ndim):
+    shape = []
+    for _ in range(ndim):
+        shape.append(relay.Any())
+    return tuple(shape)
+
+# TODO(@wweic): because vm doesn't support heterogeneous exec, we can only test
+# shape function on CPU.
+
+def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
+    dtype = 'float32'
+    x = relay.var('x', shape=x_shape, dtype=dtype)
+    y = relay.var('y', shape=y_shape, dtype=dtype)
+    mod = relay.module.Module()
+    mod["main"] = relay.Function([x, y], op(x, y))
+    x_np = np.random.uniform(size=x_np_shape).astype(dtype)
+    y_np = np.random.uniform(size=y_np_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(x_np, y_np)
+        tvm.testing.assert_allclose(result.asnumpy(), np_op(x_np, y_np))
+
+def test_any_broadcast():
+    verify_any_broadcast((relay.Any(),), (3, 2), (1,), (3, 2), relay.add, np.add)
+    verify_any_broadcast((relay.Any(), 2), (1, 2), (1, 2), (1, 2), relay.add, np.add)
+    verify_any_broadcast((relay.Any(), 2), (1, 2), (3, 2), (1, 2), relay.add, np.add)
+    verify_any_broadcast((relay.Any(), 2), (3, 2), (1, 2), (3, 2), relay.add, np.add)
+    verify_any_broadcast((relay.Any(), 2), (3, relay.Any()), (1, 2), (3, 1), relay.add, np.add)
+
+    # The following currently fail because topi compute treats Any as 1
+    # will requires auto_broadcast buffer to solve the problem
+    # TODO(@zhiics): Fix this
+    # verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
+    # verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)
+
+def test_any_concat():
+    x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
+    y = relay.var('y', shape=(1, 2), dtype="float32")
+    z = relay.op.concatenate([x, y], axis=0)
+    mod = relay.module.Module()
+    mod["main"] = relay.Function([x, y], z)
+    x_np = np.random.uniform(size=(3, 2)).astype('float32')
+    y_np = np.random.uniform(size=(1, 2)).astype('float32')
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(x_np, y_np)
+        ref = np.concatenate([x_np, y_np], axis=0)
+        tvm.testing.assert_allclose(result.asnumpy(), ref)
+
+def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape):
+    x = relay.var('x', shape=x_shape, dtype="float32")
+    y = relay.reshape(x, newshape=newshape)
+    mod = relay.module.Module()
+    mod["main"] = relay.Function([x], y)
+    data = np.random.uniform(size=x_np_shape).astype('float32')
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data).asnumpy()
+        assert result.shape == out_shape
+        tvm.testing.assert_allclose(result.flatten(), data.flatten())
+
+def test_any_reshape():
+    verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24))
+    verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12))
+    verify_any_reshape(any_dims(3), (0, -2), (2, 3, 4), (2, 3, 4))
+    verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
+    verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
+
+def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape):
+    mod = relay.Module()
+    data = relay.var('data', shape=data_shape, dtype='float32')
+    indices = relay.var('indices', shape=indices_shape, dtype='int32')
+    y = relay.take(data, indices, axis=axis)
+    mod["main"] = relay.Function([data, indices], y)
+    data_np = np.random.uniform(size=data_np_shape).astype('float32')
+    if axis is None:
+        max_index = data_np.size
+    else:
+        max_index = data_np.shape[axis]
+    indices_np = np.random.randint(max_index, size=indices_np_shape).astype('int32')
+    ref = np.take(data_np, indices_np, axis=axis)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np, indices_np)
+        tvm.testing.assert_allclose(result.asnumpy(), ref)
+
+def test_any_take():
+    verify_any_take(any_dims(2), (1,), 0, (4, 5), (1,))
+    verify_any_take(any_dims(2), (), 0, (4, 5), ())
+    verify_any_take(any_dims(2), (), None, (4, 5), ())
+    verify_any_take(any_dims(3), any_dims(2), 1, (3, 4, 5), (2, 3))
+    verify_any_take(any_dims(2), any_dims(3), None, (4, 5), (2, 3, 4))
+    verify_any_take(any_dims(2), any_dims(4), -1, (4, 5), (2, 3, 4, 5))
+
+def test_any_shape_of():
+    x = relay.var('x', shape=any_dims(2), dtype='float32')
+    y = relay.shape_of(x)
+    mod = relay.module.Module()
+    mod["main"] = relay.Function([x], y)
+    data = np.random.uniform(size=(3, 4)).astype('float32')
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data)
+        tvm.testing.assert_allclose(result.asnumpy(), np.array([3,4]).astype("int64"))
+
+    x = relay.var('x', shape=any_dims(3), dtype='float32')
+    y0 = relay.shape_of(x)
+    y1 = relay.take(y0, relay.const(1, 'int32'))
+    mod = relay.module.Module()
+    mod["main"] = relay.Function([x], y1)
+    data = np.random.uniform(size=(2, 3, 4)).astype('float32')
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data)
+        tvm.testing.assert_allclose(result.asnumpy(), np.array(3).astype("int64"))
+
+def test_fused_ops():
+    x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32')
+    y0 = x + relay.const(1.0, 'float32')
+    y1 = y0 * relay.const(2.0, 'float32')
+    mod = relay.module.Module()
+    mod["main"] = relay.Function([x], y1)
+    data = np.random.uniform(size=(5, 4)).astype('float32')
+    for kind in ["vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data)
+        tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2)
+
 def test_arange_with_dynamic_shape():
     m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
     x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32')
     y0 = relay.shape_of(x)
     y1 = relay.take(y0, relay.const(0, 'int32'))
-    y2 = relay.op.arange(y1)
-    ex = relay.create_executor()
-    f = relay.Function([x], y2, type_params=[m, n, k])
-    # TODO(@jroesch): Restore after code generation.
-    # data = np.random.rand(10, 5, 3).astype('float32')
-    # result = ex.evaluate(f)(data)
-    # np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
-
-def test_dynamic_concat():
+    y2 = relay.op.arange(y1, dtype="int32")
+    y3 = y2 + relay.const(1, dtype="int32")
+    data = np.random.rand(10, 5, 3).astype('float32')
+    mod = relay.module.Module()
+    mod["main"] = relay.Function([x], y3, type_params=[m, n, k])
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data)
+        tvm.testing.assert_allclose(result.asnumpy(), np.array(range(10)).astype("int32")+1)
+
+def test_recursive_concat():
     """
     fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) {
         if (%i < 10) {
@@ -66,26 +195,18 @@ def test_dynamic_concat():
     start = relay.var('start', shape=(), dtype='int32')
     body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
     func = relay.Function([start], relay.TupleGetItem(body, 1))
-    func = infer_type(func)
-    # TODO(@jroesch, @haichen): We should restore this code when codegeneration
-    # is merged
-    # ret_shape = func.checked_type.ret_type.shape
-    # assert len(ret_shape) == 2, "expected 2-dim output"
-    # assert relay.ir_pass.alpha_eq(ret_shape[0], relay.Any())
-    # import pdb; pdb.set_trace()
-    # mod = relay.module.Module()
-    # print(relay.ir_pass.infer_type(func, mod=mod))
-    # ret = relay.Call(loop, [relay.const(0, 'int32'), init])
-    # mod[mod.entry_func] = relay.Function([], ret)
-    # print(relay.ir_pass.infer_type(mod[mod.entry_func], mod=mod))
-
-    # initial = np.array(0.0, dtype='float32').reshape((1,))
-    # iter_stop = np.array(10, dtype='int32')
-    # ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
-    # result = ex.evaluate(mod.entry_func)()
-    # np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
-
-def test_dynamic_concat_with_wrong_annotation():
+    mod = relay.module.Module()
+    mod["main"] = func
+    data = np.array(0.0, dtype='int32')
+    # TODO(@jroesch): After LambdaLift pass, TypeInfer pass will fail
+    # so currently we cannot run this test case on VM
+    for kind in ["debug"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data)
+        ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
+        np.testing.assert_allclose(result.asnumpy(), ref)
+
+def test_recursive_concat_with_wrong_annotation():
     """
     v0.0.1
     fn (%start: int32) {
@@ -133,6 +254,12 @@ def test_dynamic_concat_with_wrong_annotation():
         assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
 
 if __name__ == "__main__":
+    test_any_broadcast()
+    test_any_concat()
+    test_any_reshape()
+    test_any_take()
+    test_any_shape_of()
+    test_fused_ops()
     test_arange_with_dynamic_shape()
-    test_dynamic_concat()
-    test_dynamic_concat_with_wrong_annotation()
+    test_recursive_concat()
+    test_recursive_concat_with_wrong_annotation()
index 9a8ab2d..a32ec27 100644 (file)
@@ -104,9 +104,6 @@ def test_serializer():
     vm = create_vm(mod)
     ser = serializer.Serializer(vm)
 
-    stats = ser.stats
-    assert "scalar" in stats
-
     glbs = ser.globals
     assert len(glbs) == 3
     assert "f1" in glbs
@@ -120,8 +117,8 @@ def test_serializer():
 
     code = ser.bytecode
     assert "main 5 2 5" in code
-    assert "f1 3 1 4" in code
-    assert "f2 3 1 4" in code
+    assert "f1 2 1 3" in code
+    assert "f2 2 1 3" in code
 
     code, lib = ser.serialize()
     assert isinstance(code, bytearray)
index 805cff8..fd40c3f 100644 (file)
@@ -122,11 +122,13 @@ def test_outer_product():
     assert ibody.min.value == 0
     assert ibody.extent.name == 'm'
     #Check loop body
-    jbody = ibody.body
+    jblock = ibody.body
+    assert isinstance(jblock, tvm.stmt.Block)
+    jbody = jblock.first
     assert isinstance(jbody, tvm.stmt.AssertStmt)
     assert isinstance(jbody.message, tvm.expr.StringImm)
     assert jbody.message.value == "index out of range!"
-    jbody = jbody.body
+    jbody = jblock.rest
     assert isinstance(jbody, tvm.stmt.Provide)
     assert jbody.func.name == 'c'
     assert len(jbody.args) == 2
index b5120ee..2439fab 100644 (file)
@@ -52,6 +52,9 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
   tvm::Expr one(1);
   int i;
   for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
+    // TODO(@icemelon9): Need to revisit this part
+    const Variable* var1 = shape1[s1_size - i].as<Variable>();
+    const Variable* var2 = shape2[s2_size - i].as<Variable>();
     bh.all_vars.push_front(tvm::Var());
     if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
       bh.common_shape.push_front(shape1[s1_size - i]);
@@ -64,6 +67,16 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
     } else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
       bh.common_shape.push_front(shape1[s1_size - i]);
       bh.vars1.push_front(bh.all_vars[0]);
+    } else if (var1 && var2) {
+      bh.common_shape.push_front(max(shape1[s1_size - i], shape2[s2_size - i]));
+      bh.vars1.push_front(bh.all_vars[0]);
+      bh.vars2.push_front(bh.all_vars[0]);
+    } else if (var1) {
+      bh.common_shape.push_front(shape2[s2_size - i]);
+      bh.vars2.push_front(bh.all_vars[0]);
+    } else if (var2) {
+      bh.common_shape.push_front(shape1[s1_size - i]);
+      bh.vars1.push_front(bh.all_vars[0]);
     } else {
       CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i]
                    << " and " << shape2[s2_size - i] << " in: "
index 1622b20..af2ed16 100644 (file)
@@ -1148,9 +1148,9 @@ inline Tensor tensordot(const Tensor& A,
   return compute(output_shape, func, name, tag);
 }
 
-inline Tensor arange(const Expr start,
-                     const Expr stop,
-                     const Expr step,
+inline Tensor arange(const Expr& start,
+                     const Expr& stop,
+                     const Expr& step,
                      Type dtype,
                      std::string name = "T_arange",
                      std::string tag = kInjective) {
index a42d61d..908317e 100644 (file)
@@ -82,9 +82,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
             func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
             func(a, w, c)
 
-        rtol = 1e-5
-        if (kernel > 3):
-          rtol = 2e-5
+        rtol = 1e-3
 
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)