[Relay][RFC] Implement type checking for Any (#3221)
authorJared Roesch <roeschinc@gmail.com>
Wed, 10 Jul 2019 17:16:45 +0000 (10:16 -0700)
committerThierry Moreau <moreau@uw.edu>
Wed, 10 Jul 2019 17:16:45 +0000 (10:16 -0700)
* Implement type checking for Any

Remove code generation related changes

Remove compile changes

Remove more

Remove unification hack

Add some code back that was needed, and clean up test

Refactor test cases

WIP

Implement TypeHint AST

Add test case which should fail

Remove unification changes, and fix bug with let rec

Restore unification for shapes

Improve error reporting while debugging

All examples type check

All examples type check

WIP

First version that works with hints, needs clean up

Remove dead code

Tweaks

Remove type hint

Remove unecessary type hint stuff

Remove more type hints

Clean up

Expose Any expression node

Address CR

Fix

Fix solver

Kill unecessary code

Fix

PyLint

Fix

Relocate loops

Fix license and test

Lint again

Lint again

Fix loops

Fix docstring

Fix template error

Fix compiler issue

Fix compile err

Remove more runtime changes

Restore buffer

Fix segfault

Fix

Fix arange

* Address feedback

* Fix typo

* Fix arange

* Fix op level3

* Fix issue with Python wrapper

35 files changed:
include/tvm/ir.h
include/tvm/relay/attrs/transform.h
include/tvm/relay/error.h
include/tvm/relay/expr.h
include/tvm/relay/op_attr_types.h
include/tvm/relay/type.h
include/tvm/runtime/ndarray.h
python/tvm/_ffi/base.py
python/tvm/api.py
python/tvm/relay/__init__.py
python/tvm/relay/expr.py
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/loops.py [new file with mode: 0644]
python/tvm/relay/op/transform.py
python/tvm/relay/scope_builder.py
python/tvm/relay/ty.py
src/codegen/llvm/codegen_llvm.cc
src/lang/buffer.cc
src/lang/ir.cc
src/lang/tensor.cc
src/relay/backend/build_module.cc
src/relay/backend/compile_engine.cc
src/relay/ir/error.cc
src/relay/ir/expr.cc
src/relay/ir/pretty_printer.cc
src/relay/ir/type.cc
src/relay/op/tensor/transform.cc
src/relay/op/type_relations.cc
src/relay/pass/type_infer.cc
src/relay/pass/type_solver.cc
src/relay/pass/type_solver.h
src/runtime/ndarray.cc
tests/python/relay/test_any.py [new file with mode: 0644]
tests/python/relay/test_op_level3.py

index e0c6297d5d032bc1063631ef3b51bb0415fc490a..7524109ec48b127dc3fa0d1c05d6b9c82549f6a3 100644 (file)
@@ -138,6 +138,15 @@ struct Reduce : public ExprNode<Reduce> {
   static constexpr const char* _type_key = "Reduce";
 };
 
+/*! \brief Any shape. */
+struct Any : public ExprNode<Any> {
+  TVM_DLL static Expr make();
+
+  void VisitAttrs(AttrVisitor* v) final {}
+  static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
+  static constexpr const char* _type_key = "Any";
+};
+
 /*!
  * \brief Auxiliary data structure used in IR Pass to indicate a tensor.
  */
index 1247884f0df8dd189f1d3e62b1083139bbdcfdaf..d09441d73effcb39b67e76201caddc0c2cb74f44 100644 (file)
@@ -123,19 +123,19 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
 
 /*! \brief Attributes used in arange operators */
 struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
-  tvm::Expr start;
-  tvm::Expr stop;
-  tvm::Expr step;
+  Expr start;
+  Expr stop;
+  Expr step;
   DataType dtype;
 
   TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
-    TVM_ATTR_FIELD(start).set_default(make_const(Float(32), 0))
+    TVM_ATTR_FIELD(start)
         .describe("Start of interval. The interval includes this value.");
     TVM_ATTR_FIELD(stop)
         .describe("Stop of interval. The interval does not include this value.");
-    TVM_ATTR_FIELD(step).set_default(make_const(Float(32), 1))
+    TVM_ATTR_FIELD(step)
         .describe("Spacing between values.");
-    TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
+    TVM_ATTR_FIELD(dtype)
         .describe("Target data type.");
   }
 };  // struct ArangeAttrs
index 5189fd982d379d1aa31d1089f5976d737da92e09..ef3387b1893b65e17cf5ca6e38cec39877d808c1 100644 (file)
@@ -64,9 +64,10 @@ struct RelayErrorStream {
 
 struct Error : public dmlc::Error {
   Span sp;
-  explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {}
-  Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
-  Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
+  explicit Error(const std::string& msg) : dmlc::Error(msg), sp(nullptr) {}
+  Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp(nullptr) {} // NOLINT(*)
+  Error(const Error& err) : dmlc::Error(err.what()), sp(nullptr) {}
+  Error() : dmlc::Error(""), sp(nullptr) {}
 };
 
 /*! \brief An abstraction around how errors are stored and reported.
@@ -118,7 +119,8 @@ class ErrorReporter {
    * \param err The error message to report.
    */
   inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
-    this->ReportAt(global, node, Error(err));
+    std::string err_msg = err.str();
+    this->ReportAt(global, node, Error(err_msg));
   }
 
   /*! \brief Report an error against a program, using the full program
index cb4f4ddece9980ef523781d42d1c1aef2cb8b1e9..c5cd6bb9e4abeb326b38a26fb1ddb952b85a5688 100644 (file)
@@ -561,6 +561,9 @@ inline const TTypeNode* ExprNode::type_as() const {
   return node;
 }
 
+/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
+std::string PrettyPrint(const NodeRef& node);
+
 /*!
  * \brief Render the node as a string in the Relay text format.
  * \param node The node to be rendered.
index ca7f6e5d39080ab88e3bbcb6b44a3efe0148097f..7709a790f5c58d645a3c4e1922ad579a08eeb7e2 100644 (file)
@@ -158,6 +158,22 @@ using FForwardRewrite = runtime::TypedPackedFunc<
 using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
                                                                   const Expr& output_grad)>;
 
+/*!
+ * \brief The codegeneration strategy for dynamic dimensions.
+ */
+enum AnyCodegenStrategy {
+  /*! \brief The default strategy of using completely variable dimensions. */
+  kVariableDimensions
+};
+
+/* \brief A runtime representation of shape. */
+using Shape = Array<IndexExpr>;
+
+using FShapeFunc = runtime::TypedPackedFunc<
+  Array<Tensor>(const Attrs& attrs,
+                const Array<Tensor>& inputs,
+                const Array<Shape>& out_shapes)>;
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_ATTR_TYPES_H_
index e42ef1f65ba2ec7925004ba1f6fbf04b00aaec6b..d509fde2a8755264f778a62459ecd1332720d20b 100644 (file)
@@ -35,6 +35,8 @@
 namespace tvm {
 namespace relay {
 
+using Any = tvm::ir::Any;
+
 /*! \brief Base type of the Relay type hiearchy. */
 class TypeNode : public RelayNode {
  public:
@@ -384,6 +386,7 @@ class TypeReporterNode : public Node {
    *  But it is possible for the solver to resolve src by dst as well.
    */
   TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
+
   /*!
    * \brief assert shape expression comparison.
    * \note Use assert only if any of the condition input is symbolic.
index aea551ee7d6997742417f0b2e90d705b4bb626c2..993295179842397587af8f509627c96aa623958a 100644 (file)
@@ -190,6 +190,8 @@ class NDArray {
   TVM_DLL static void CopyFromTo(
       DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr);
 
+  TVM_DLL std::vector<int64_t> Shape() const;
+
   // internal namespace
   struct Internal;
  protected:
index e8435081c9ed047d0cee3219fe65984d8ab7f8c0..c61c5c44544236a99e3f77e70fe6ea3dd7a144e5 100644 (file)
@@ -294,7 +294,7 @@ def get_last_ffi_error():
     """
     c_err_msg = py_str(_LIB.TVMGetLastError())
     py_err_msg, err_type = c2pyerror(c_err_msg)
-    if err_type.startswith("tvm.error."):
+    if err_type is not None and err_type.startswith("tvm.error."):
         err_type = err_type[10:]
     return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)
 
index e4777b6e39649c64844606e1f8b9d0c4fddd706e..7743ff7fa69093ee91477a9efe500ad32154c478 100644 (file)
@@ -479,7 +479,8 @@ def extern(shape,
             raise ValueError("nested tag is not allowed for now")
         tag = _tag.TagScope.get_current().tag
     shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
-    shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
+    if shape == () or isinstance(shape[0], (_expr.Expr, _Integral)):
+        shape = [shape]
     if in_buffers is not None:
         in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
         if len(inputs) != len(in_buffers):
index dfac85bb1ed28b90c95414ad3dbd25f93a3a5845..509196f635b984d49d96273cd137a3b5b50fd8f5 100644 (file)
@@ -63,6 +63,7 @@ TupleType = ty.TupleType
 TensorType = ty.TensorType
 Kind = ty.Kind
 TypeVar = ty.TypeVar
+ShapeVar = ty.ShapeVar
 TypeConstraint = ty.TypeConstraint
 FuncType = ty.FuncType
 TypeRelation = ty.TypeRelation
@@ -71,6 +72,7 @@ scalar_type = ty.scalar_type
 RefType = ty.RefType
 GlobalTypeVar = ty.GlobalTypeVar
 TypeCall = ty.TypeCall
+Any = ty.Any
 
 # Expr
 Expr = expr.Expr
index 8e7f95c4dc2605dc3a3f293481fda9382881a3ca..88779dfd76e0f8fb10402cbbf8a013851aaa1ff7 100644 (file)
@@ -570,6 +570,7 @@ def const(value, dtype=None):
     """
     if isinstance(value, (_base.numeric_types, (bool, list))):
         value = _np.array(value, dtype=dtype)
+
     if not dtype:
         # when dtype is None: int maps to "int32", float maps to "float32"
         map_dtype = {
@@ -578,6 +579,7 @@ def const(value, dtype=None):
             }.get(value.dtype, None)
         if map_dtype:
             value = value.astype(map_dtype)
+
     if isinstance(value, (_np.ndarray, _np.generic)):
         value = _nd.array(value)
 
index e40f1dea61a9f5cdc8e2b13a4e2245f5f6b7f36e..3ddf47a119839a2ecc66f1aa6c03c7a4d9d991fb 100644 (file)
@@ -491,9 +491,9 @@ def _mx_arange(inputs, attrs):
         raise tvm.error.OpAttributeUnimplemented(
             'Attribute "repeat" is not supported in operator arange.')
     new_attrs = {}
-    new_attrs["start"] = attrs.get_float("start", 0)
-    new_attrs["stop"] = attrs.get_float("stop")
-    new_attrs["step"] = attrs.get_float("step", 1)
+    new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
+    new_attrs["stop"] = _expr.const(attrs.get_float("stop"))
+    new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
     new_attrs["dtype"] = attrs.get_str("dtype", "float32")
     return _op.arange(**new_attrs)
 
index 59e0983e95985182b91b90cdae7c4eb79b9de409..6fff40a922940ad6cdd0c4e8f8746689e94bb5e7 100644 (file)
@@ -1059,9 +1059,9 @@ def _range():
         return AttrCvt(
             op_name="arange",
             ignores=['Tidx'],
-            extras={'start': start,
-                    "stop": limit,
-                    'step': delta,
+            extras={'start': _expr.const(start),
+                    "stop": _expr.const(limit),
+                    'step': _expr.const(delta),
                     'dtype': dtype})([], attr)
     return _impl
 
@@ -1269,8 +1269,8 @@ def _batch_to_space_nd():
             crop = crops[axis - 1]
             if crop != [0, 0]:
                 indices = tvm.relay.arange(
-                    crop[0],
-                    reshaped_permuted_shape[axis] - crop[1],
+                    _expr.const(crop[0]),
+                    _expr.const(reshaped_permuted_shape[axis] - crop[1]),
                     dtype='int32'
                 )
                 cropped = tvm.relay.take(cropped, indices=indices, axis=axis)
diff --git a/python/tvm/relay/loops.py b/python/tvm/relay/loops.py
new file mode 100644 (file)
index 0000000..8e066ab
--- /dev/null
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""
+Utilities for building Relay loops.
+"""
+from .scope_builder import ScopeBuilder
+from . import expr as _expr
+
+def while_loop(cond, loop_vars, loop_bodies):
+    """
+    Construct a while loop.
+
+    Parameters
+    ----------
+
+    cond: Callable[Tuple[relay.Expr], relay.Expr]
+        The condition of the loop.
+
+    loop_vars:  Tuple[relay.Expr]
+        The variables being looped over.
+        The initial values of the loop, will be used to
+        construct the loop variables.
+
+    loop_bodies: Callable[Tuple[relay.Expr], Tuple[relay.Expr]]
+        The body of the loop, should be a function which
+        given loop variables produces the output result
+        also as a tuple
+
+    Returns
+    -------
+    loop: relay.Expr
+        The loop expression.
+    """
+    sb = ScopeBuilder()
+    loop = _expr.Var("while_loop")
+    fresh_vars = []
+
+    for i, loop_var in enumerate(loop_vars):
+        name = loop_var.name_hint if isinstance(loop_var, _expr.Var) else "arg{}".format(i)
+        new_var = _expr.var(name, type_annotation=sb.type_of(loop_var))
+        fresh_vars.append(new_var)
+
+    with sb.if_scope(cond(*fresh_vars)):
+        sb.ret(loop(*loop_bodies(*fresh_vars)))
+    with sb.else_scope():
+        sb.ret(_expr.Tuple(fresh_vars))
+
+    func = _expr.Function(fresh_vars, sb.get())
+    let = _expr.Let(loop, func, loop)
+    return let
index bac60a058fca0ec7a00a512c81bdf34b274f34f0..5137a9c469a41f1aa7d6d83885417ff0abdf388d 100644 (file)
@@ -17,7 +17,7 @@
 """Transform operators."""
 
 from . import _make
-from ..expr import TupleWrapper
+from ..expr import TupleWrapper, const
 
 
 def cast(data, dtype):
@@ -272,7 +272,7 @@ def full_like(data, fill_value):
     return _make.full_like(data, fill_value)
 
 
-def arange(start, stop=None, step=1, dtype="float32"):
+def arange(start, stop=None, step=None, dtype="float32"):
     """Return evenly spaced values within a given interval.
 
     .. note::
@@ -310,9 +310,13 @@ def arange(start, stop=None, step=1, dtype="float32"):
         relay.arange(1, 5) = [1, 2, 3, 4]
         relay.arange(1, 5, 1.5) = [1, 2.5, 4]
     """
+    if step is None:
+        step = const(1, dtype)
+
     if stop is None:
         stop = start
-        start = 0
+        start = const(0, dtype=dtype)
+
     return _make.arange(start, stop, step, dtype)
 
 
index 337044098cd5a90f89fdc937373949958eec27fa..dfe3db187e07d72cc5818599b2418181d6fa3dcf 100644 (file)
@@ -42,7 +42,6 @@ class WithScope(object):
         else:
             self._exit_cb()
 
-
 def _make_lets(bindings, ret_value):
     """Make a nested let expressions.
 
@@ -176,6 +175,24 @@ class ScopeBuilder(object):
                 false_branch)
         return WithScope(None, _on_exit)
 
+
+    def type_of(self, expr):
+        """
+        Compute the type of an expression.
+
+        Parameters
+        ----------
+        expr: relay.Expr
+            The expression to compute the type of.
+        """
+        if isinstance(expr, _expr.Var):
+            return expr.type_annotation
+
+        ity = _ty.IncompleteType()
+        var = _expr.var("unify", ity)
+        self.let(var, expr)
+        return ity
+
     def ret(self, value):
         """Set the return value of this scope.
 
index b1477b75d278b87f34f346fe31405fbf6bfee3bb..2f3b7e91aaf79229f6bd9897bd9a27688c8a971f 100644 (file)
@@ -20,6 +20,7 @@ from enum import IntEnum
 from .base import RelayNode, register_relay_node
 from . import _make
 
+Any = _make.Any
 
 class Type(RelayNode):
     """The base type for all Relay types."""
@@ -137,6 +138,19 @@ class TypeVar(Type):
         """
         self.__init_handle_by_constructor__(_make.TypeVar, var, kind)
 
+def ShapeVar(name):
+    """A helper which constructs a type var of which the shape kind.
+
+    Parameters
+    ----------
+    name : str
+
+    Returns
+    -------
+    type_var : tvm.relay.TypeVar
+        The shape variable.
+    """
+    return TypeVar(name, kind=Kind.ShapeVar)
 
 @register_relay_node
 class GlobalTypeVar(Type):
index fde0486483b2b82ba1b7121725a5e1e39e76787a..f90829fa5e536963bea74a5f5534b406a26f8c86 100644 (file)
@@ -970,7 +970,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
              op->call_type == Call::PureExtern) {
     return CreateCallExtern(op);
   } else {
-    LOG(FATAL) << "Unknown call type ";
+    LOG(FATAL) << "Unknown call type " <<
+      "name= " << op->name <<
+      " call_type= " << op->call_type;
     return nullptr;
   }
 }
index cb5c86710fabbbb7bf048a48ecd62e7764bf7bbc..8324938bd7f96f09420c0997013a75d8165ac8da 100644 (file)
@@ -246,13 +246,20 @@ inline Expr MergeMulMod(const Expr &base) {
 inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
   Expr base = n->elem_offset;
   if (n->strides.size() == 0) {
-    CHECK_EQ(n->shape.size(), index.size());
-    if (index.size() > 0) {
-      Expr offset = index[0];
-      for (size_t i = 1; i < index.size(); ++i) {
-        offset = MergeMulMod(offset * n->shape[i] + index[i]);
+    // Scalar case
+    if (n->shape.size() == 0 && index.size() == 1) {
+      auto is_int = index[0].as<IntImm>();
+      CHECK(is_int && is_int->value == 0);
+      base = base + index[0];
+    } else {
+      CHECK_EQ(n->shape.size(), index.size());
+      if (index.size() > 0) {
+        Expr offset = index[0];
+        for (size_t i = 1; i < index.size(); ++i) {
+          offset = MergeMulMod(offset * n->shape[i] + index[i]);
+        }
+        base = base + offset;
       }
-      base = base + offset;
     }
   } else {
     CHECK_EQ(n->strides.size(), index.size());
index 612a5e908b545298f7bba0a05ff4472462c23bab..4eeddd91d80c46c2c16d00e7bd915127a533c74a 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -35,13 +35,24 @@ namespace Internal {
 
 using tvm::ir::CommReducerNode;
 using tvm::ir::Reduce;
+using tvm::ir::Any;
 using tvm::ir::AttrStmt;
 
 template<>
 void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
-  LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor";
+  LOG(FATAL) << "Reduce does not work with old Visitor, use IRFunctor style visitor";
+}
+
+template<>
+void ExprNode<Any>::accept(IRVisitor *v, const Expr&) const {
+  LOG(FATAL) << "Any does not work with old Visitor, use IRFunctor style visitor";
 }
 
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
+  p->stream << "?";
+});
+
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
   p->stream << "reduce(combiner="
@@ -116,8 +127,14 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source,
   return Expr(n);
 }
 
+Expr Any::make() {
+  auto n = make_node<Any>();
+  return Expr(n);
+}
+
 TVM_REGISTER_NODE_TYPE(CommReducerNode);
 TVM_REGISTER_NODE_TYPE(Reduce);
+TVM_REGISTER_NODE_TYPE(Any);
 TVM_REGISTER_NODE_TYPE(AttrStmt);
 
 TVM_REGISTER_NODE_TYPE(FloatImm);
index d885d7103606434a570a1783725eb9a975e19592..c2f80d10f790f93e893e553ff03e2c84c792c2ac 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -38,9 +38,12 @@ Expr Tensor::operator()(Array<Var> indices) const {
 
 Expr Tensor::operator()(Array<Expr> indices) const {
   using HalideIR::Internal::Call;
-  CHECK_EQ(ndim(), indices.size())
-      << "Tensor dimension mismatch in read"
-      << "ndim = " << ndim() << ", indices.size=" << indices.size();
+  if (ndim() != 0) {
+    CHECK_EQ(ndim(), indices.size())
+        << "Tensor dimension mismatch in read"
+        << "ndim = " << ndim() << ", indices.size=" << indices.size();
+  }
+
   auto n = Call::make(
       (*this)->dtype, (*this)->op->name, indices, Call::Halide,
       (*this)->op, (*this)->value_index);
index 7de77c8bcfd4b244664b8a122c8c7708d7325c97..f7e106756818ef5def5c303cf08d6fb4d2d96c02 100644 (file)
@@ -417,10 +417,10 @@ class RelayBuildModule : public runtime::ModuleNode {
   }
 
   /*!
-   * \brief Build relay function to runtime module
+   * \brief Compile a Relay function to runtime module.
    *
-   * \param func Relay Function
-   * \param params parameters
+   * \param func The Relay function.
+   * \param params The parameters.
    */
   void BuildRelay(
       Function func,
@@ -444,8 +444,13 @@ class RelayBuildModule : public runtime::ModuleNode {
     ret_.graph_json = graph_codegen_->GetJSON();
     ret_.params = graph_codegen_->GetParams();
 
-    ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_,
-                          BuildConfig::Current());
+    auto lowered_funcs = graph_codegen_->GetLoweredFunc();
+    if (lowered_funcs.size() != 0) {
+      ret_.mod = tvm::build(
+        lowered_funcs,
+        target_host_,
+        BuildConfig::Current());
+    }
   }
 
  protected:
index 83e4a36ff4f93878de68d94a3bfc9a8b4ba96401..ab906310aaa3ff0cdd89b5774e453d7341a4fa17 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- *
+ * 
  *   http://www.apache.org/licenses/LICENSE-2.0
- *
+ * 
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
index 5e621316a136358563022db85066d45d77325cc6..5ed51f5fd281f150cd50f80b07fdd5cd23936b01 100644 (file)
@@ -67,6 +67,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
     std::stringstream err_msg;
 
     err_msg << rang::fg::red;
+    err_msg << " ";
     for (auto index : error_indicies) {
       err_msg << this->errors_[index].what() << "; ";
     }
@@ -88,7 +89,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
   // First we output a header for the errors.
   annotated_prog <<
   rang::style::bold << std::endl <<
-  "Error(s) have occurred. We have annotated the program with them:"
+  "Error(s) have occurred. The program has been annotated with them:"
   << std::endl << std::endl << rang::style::reset;
 
   // For each global function which contains errors, we will
index a638006f78e547aac9c7789c849368b773dbc010..9589a0a9f5b88be37d13fcda423520bfcddeefbd 100644 (file)
@@ -287,6 +287,8 @@ RefCreate RefCreateNode::make(Expr value) {
   return RefCreate(n);
 }
 
+TVM_REGISTER_NODE_TYPE(RefCreateNode);
+
 TVM_REGISTER_API("relay._make.RefCreate")
 .set_body_typed(RefCreateNode::make);
 
@@ -301,6 +303,8 @@ RefRead RefReadNode::make(Expr ref) {
   return RefRead(n);
 }
 
+TVM_REGISTER_NODE_TYPE(RefReadNode);
+
 TVM_REGISTER_API("relay._make.RefRead")
 .set_body_typed(RefReadNode::make);
 
@@ -316,6 +320,8 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) {
   return RefWrite(n);
 }
 
+TVM_REGISTER_NODE_TYPE(RefWriteNode);
+
 TVM_REGISTER_API("relay._make.RefWrite")
 .set_body_typed(RefWriteNode::make);
 
index cdc56db845b20172b448603f6cf9167f0f857821..09196b49a61719a0c4f18890fec738bd88dbac2d 100644 (file)
@@ -686,7 +686,9 @@ class PrettyPrinter :
   Doc PrintAttr(const NodeRef& value, bool meta = false) {
     if (value.defined()) {
       Doc printed_attr;
-      if (meta) {
+      if (value.as<tvm::ir::Any>()) {
+        printed_attr << "?";
+      } else if (meta) {
         printed_attr = meta_.GetMetaNode(value);
       } else {
         printed_attr = VisitAttr(value);
@@ -846,6 +848,12 @@ std::string PrettyPrint_(const NodeRef& node,
   return doc.str();
 }
 
+std::string PrettyPrint(const NodeRef& node) {
+  Doc doc;
+  doc << PrettyPrinter(false, runtime::TypedPackedFunc<std::string(Expr)>()).PrintFinal(node);
+  return doc.str();
+}
+
 std::string AsText(const NodeRef& node,
                        bool show_meta_data,
                        runtime::TypedPackedFunc<std::string(Expr)> annotate) {
index 35a12052949e229c5429cca67715a4074c8f5369..2604896c86055fce694f1b1f22f94584c0aa1c06 100644 (file)
@@ -228,5 +228,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
   p->stream << "RefTypeNode(" << node->value << ")";
 });
 
+TVM_REGISTER_API("relay._make.Any")
+.set_body_typed<IndexExpr()>([]() { return Any::make(); });
+
+
 }  // namespace relay
 }  // namespace tvm
index da93860251905aa127bd29202c34df964d3ea8b5..59424884ccfe8245dd87019a278f9c34129a1bac 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -23,6 +23,7 @@
  * \brief Transform operators.
  */
 #include <tvm/relay/op.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/expr_operator.h>
 #include <tvm/ir.h>
@@ -184,40 +185,77 @@ bool ConcatenateRel(const Array<Type>& types,
                     const TypeReporter& reporter) {
   // types: [data, result]
   CHECK_EQ(types.size(), 2);
+  /* If we receive a tuple we can continue, if we receive
+   * anything but an incomplete type we should signal an
+   * error.
+  */
   const auto* tensor_tuple = types[0].as<TupleTypeNode>();
   if (tensor_tuple == nullptr) {
-    CHECK(types[0].as<IncompleteTypeNode>())
-        << "cast: expect input type to be TupleType but get "
-        << types[0];
+    throw relay::Error(
+        RELAY_ERROR(
+          "concatenate requires a tuple of tensors as the first argument, found "
+        << PrettyPrint(types[0])));
+  } else if (types[0].as<IncompleteTypeNode>() != nullptr) {
     return false;
   }
+
   const auto* param = attrs.as<ConcatenateAttrs>();
+  if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
+    return false;
+  }
   const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
   // Sanity check: ndim and dtype.
   const int ndim = static_cast<int>(first->shape.size());
   const DataType dtype = first->dtype;
+
   for (const Type& ele : tensor_tuple->fields) {
+    if (ele.as<IncompleteTypeNode>()) {
+      return false;
+    }
+
     const auto& e = Downcast<TensorType>(ele);
+
     int e_ndim = static_cast<int>(e->shape.size());
     const DataType& e_dtype = e->dtype;
-    CHECK_EQ(e_ndim, ndim) << "relay.concatenate requires all tensors have the same ndim";
-    CHECK_EQ(e_dtype, dtype) << "relay.concatenate requires all tensors have the same dtype";
+    if (e_ndim != ndim) {
+      throw relay::Error("relay.concatenate requires all tensors have the same ndim");
+    }
+    if (e_dtype != dtype) {
+      throw relay::Error("relay.concatenate requires all tensors have the same dtype");
+    }
   }
   // Sanity check: axis
   int axis = param->axis;
-  CHECK(-ndim <= axis && axis < ndim)
-    << "concatenate only accepts `axis` in [-ndim, ndim)"
-    << ", but got axis = " << axis
-    << ", and ndim = " << ndim;
+  if (!(-ndim <= axis && axis < ndim)) {
+    throw relay::Error(RELAY_ERROR(
+      "concatenate only accepts `axis` in [-ndim, ndim)" <<
+      ", but got axis = " << axis <<
+      ", and ndim = " << ndim));
+  }
   axis = axis < 0 ? ndim + axis : axis;
   // Calculate shape
   std::vector<IndexExpr>&& oshape = AsVector(first->shape);
   IndexExpr &concat_dim = oshape[axis];
-  for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
-    const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
-    concat_dim += e->shape[axis];
+  bool has_any = false;
+  if (concat_dim.as<Any>()) {
+    has_any = true;
+  } else {
+    for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
+      const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
+      if (e->shape[axis].as<Any>()) {
+        has_any = true;
+        break;
+      }
+      concat_dim += e->shape[axis];
+    }
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, dtype));
+
+  if (has_any) {
+    concat_dim = Any::make();
+  }
+
+  auto rtype = TensorTypeNode::make(oshape, dtype);
+  reporter->Assign(types[1], rtype);
   return true;
 }
 
@@ -499,6 +537,8 @@ bool ReshapeRel(const Array<Type>& types,
     newshape = param->newshape;
   }
   Array<IndexExpr> oshape;
+  std::unordered_set<size_t> used_input_dims;
+  std::unordered_set<size_t> used_output_dims;
   size_t src_idx = 0;
   int infer_idx = -1;
 
@@ -511,6 +551,8 @@ bool ReshapeRel(const Array<Type>& types,
     } else if (svalue == 0) {
       // keep same
       CHECK_LT(src_idx, data_shape.size());
+      used_input_dims.insert(src_idx);
+      used_output_dims.insert(oshape.size());
       oshape.push_back(data_shape[src_idx++]);
     } else if (svalue == -1) {
       // inference based on rest
@@ -522,31 +564,49 @@ bool ReshapeRel(const Array<Type>& types,
     } else if (svalue == -2) {
       // copy all remaining dims from source
       while (src_idx < data_shape.size()) {
+        used_input_dims.insert(src_idx);
+        used_output_dims.insert(oshape.size());
         oshape.push_back(data_shape[src_idx++]);
       }
     } else if (svalue == -3) {
       // merge two dims from source
       CHECK_LT(src_idx + 1, data_shape.size());
+      used_input_dims.insert(src_idx);
       IndexExpr d1 = data_shape[src_idx++];
+      used_input_dims.insert(src_idx);
       IndexExpr d2 = data_shape[src_idx++];
+      used_output_dims.insert(oshape.size());
       oshape.push_back(d1 * d2);
     } else if (svalue == -4) {
       // split the source dim s into two dims
       // read the left dim and then the right dim (either can be -1)
       CHECK_LT(i + 2, newshape.size());
       CHECK_LT(src_idx, data_shape.size());
+      used_input_dims.insert(src_idx);
       IndexExpr d0 = data_shape[src_idx++];
       Integer d1 = newshape[++i];
       Integer d2 = newshape[++i];
       if (d1->value == -1) {
         CHECK(d2->value != -1)
             << "Split dims cannot both be -1.";
-        oshape.push_back(d0 / d2);
+        used_output_dims.insert(oshape.size());
+        if (d0.as<Any>()) {
+          oshape.push_back(Any::make());
+        } else {
+          oshape.push_back(d0 / d2);
+        }
+        used_output_dims.insert(oshape.size());
         oshape.push_back(d2);
       } else {
+        used_output_dims.insert(oshape.size());
         oshape.push_back(d1);
+        used_output_dims.insert(oshape.size());
         if (d2->value == -1) {
-          oshape.push_back(d0 / d1);
+          if (d0.as<Any>()) {
+            oshape.push_back(Any::make());
+          } else {
+            oshape.push_back(d0 / d1);
+          }
         } else {
           oshape.push_back(d2);
         }
@@ -555,9 +615,30 @@ bool ReshapeRel(const Array<Type>& types,
   }
 
   if (infer_idx >= 0) {
-    IndexExpr new_size = arith::ComputeReduce<tvm::ir::Mul>(oshape, 1);
-    IndexExpr old_size = arith::ComputeReduce<tvm::ir::Mul>(data_shape, 1);
-    oshape.Set(infer_idx, old_size / new_size);
+    IndexExpr infer_dim = 1;
+    for (size_t i = 0; i < data_shape.size(); ++i) {
+      if (used_input_dims.count(i) != 0) {
+        continue;
+      }
+      if (data_shape[i].as<Any>()) {
+        infer_dim = Any::make();
+        break;
+      }
+      infer_dim *= data_shape[i];
+    }
+    if (!infer_dim.as<Any>()) {
+      for (size_t i = 0; i < oshape.size(); ++i) {
+        if (used_output_dims.count(i) != 0) {
+          continue;
+        }
+        if (oshape[i].as<Any>()) {
+          infer_dim = Any::make();
+          break;
+        }
+        infer_dim /= oshape[i];
+      }
+    }
+    oshape.Set(infer_idx, infer_dim);
   }
 
   if (param->reverse) {
@@ -978,21 +1059,51 @@ and type as the input array.
 // arange operator
 TVM_REGISTER_NODE_TYPE(ArangeAttrs);
 
+double ToScalar(const runtime::NDArray& array) {
+  if (array->dtype.code == kDLInt || array->dtype.code == kDLUInt) {
+    return reinterpret_cast<int32_t*>(array->data)[0];
+  } else {
+    return reinterpret_cast<float*>(array->data)[0];
+  }
+}
+
 bool ArangeRel(const Array<Type>& types,
                int num_inputs,
-               const Attrs& attrs,
+               const Attrs& raw_attrs,
                const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 1);
-  const ArangeAttrs* param = attrs.as<ArangeAttrs>();
-  IndexExpr num_elem = tvm::cast(tvm::Int(32), tvm::ceil(
-      tvm::cast(tvm::Float(32), param->stop - param->start) / param->step));
-  if (const tvm::ir::IntImm* val = num_elem.as<tvm::ir::IntImm>()) {
-    CHECK_GT(val->value, 0)
-        << "Invalid arange attributes (start, stop, step): " << param->start
-        << ", " << param->stop << ", " << param->step;
-  }
-  reporter->Assign(types[0], TensorTypeNode::make({num_elem}, param->dtype));
-  return true;
+  CHECK_EQ(types.size(), 4);
+  const ArangeAttrs* attrs = raw_attrs.as<ArangeAttrs>();
+  const ConstantNode *cstart, *cstop, *cstep;
+
+  reporter->Assign(types[0], types[1]);
+  reporter->Assign(types[1], types[2]);
+  reporter->Assign(types[2], TensorTypeNode::make({}, attrs->dtype));
+
+  if ((cstart = attrs->start.as<ConstantNode>()) &&
+      (cstop = attrs->stop.as<ConstantNode>()) &&
+      (cstep = attrs->step.as<ConstantNode>())) {
+    double start = ToScalar(cstart->data);
+    double stop = ToScalar(cstop->data);
+    double step = ToScalar(cstep->data);
+    int32_t num_elem = static_cast<int32_t>(std::ceil((stop - start) / step));
+    CHECK_GT(num_elem, 0)
+        << "Invalid arange attributes (start, stop, step): " << attrs->start
+        << ", " << attrs->stop << ", " << attrs->step;
+    reporter->Assign(types[3], TensorTypeNode::make({num_elem}, attrs->dtype));
+    return true;
+  } else {
+    reporter->Assign(types[3], TensorTypeNode::make({Any::make()}, attrs->dtype));
+    return true;
+  }
+}
+
+inline Tensor DynamicArange(const tvm::Tensor& start, const tvm::Tensor& stop,
+                            const tvm::Tensor& step, tvm::Type dtype, std::string name = "tensor",
+                            std::string tag = topi::kInjective) {
+  tvm::Expr num_elem = tvm::Var("num_elem");
+  return tvm::compute({num_elem}, [&](const Array<tvm::Var>& indices) {
+    return tvm::cast(dtype, start[0] + step[0] * indices[0]);
+  }, name, tag);
 }
 
 Array<Tensor> ArangeCompute(const Attrs& attrs,
@@ -1000,35 +1111,53 @@ Array<Tensor> ArangeCompute(const Attrs& attrs,
                             const Type& out_type,
                             const Target& target) {
   const ArangeAttrs* param = attrs.as<ArangeAttrs>();
-  return { topi::arange(param->start, param->stop, param->step, param->dtype) };
+  Tensor start = inputs[0];
+  Tensor stop =  inputs[1];
+  Tensor step = inputs[2];
+  Array<tvm::Expr> empty = {0};
+  return { DynamicArange(start, stop, step, param->dtype) };
 }
 
-Expr MakeArange(tvm::Expr start,
-                tvm::Expr stop,
-                tvm::Expr step,
+Expr MakeArange(Expr start,
+                Expr stop,
+                Expr step,
                 DataType dtype) {
   auto attrs = make_node<ArangeAttrs>();
-  attrs->start = std::move(start);
-  attrs->stop = std::move(stop);
-  attrs->step = std::move(step);
-  attrs->dtype = std::move(dtype);
+  attrs->start = start;
+  attrs->stop = stop;
+  attrs->step = step;
+  attrs->dtype = dtype;
   static const Op& op = Op::Get("arange");
-  return CallNode::make(op, {}, Attrs(attrs), {});
+  return CallNode::make(op, {start, stop, step}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_API("relay.op._make.arange")
 .set_body_typed(MakeArange);
 
+// An issue with the existing design is that we require dependency
+// to type the operator precisely.
+//
+// Supporting this in general is challenging so we duplicate the
+// secondary arguments as args and attributes.
+//
+// In this way reify the arguments at both the value and type level.
+//
+// In the case our arguments are constant we can immediately recover
+// the type of arange.
+//
+// In general I think we should avoid this pattern, and introduce
+// a secondary shape analysis to recover more precise information.
 RELAY_REGISTER_OP("arange")
 .describe(R"code(Returns evenly spaced values within a given interval.
 
 )code" TVM_ADD_FILELINE)
 .set_attrs_type_key("relay.attrs.ArangeAttrs")
-.set_num_inputs(0)
+.set_num_inputs(3)
 .set_support_level(3)
 .add_type_rel("Arange", ArangeRel)
 .set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+.set_attr<TOpPattern>("TOpPattern", kInjective)
+.set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
 
 // repeat operator
 TVM_REGISTER_NODE_TYPE(RepeatAttrs);
index 5b147a489b445e879a5bda6e8b2a5a43144ab347..d4efe80c533f9cc403cf06745356bfb87486ce23 100644 (file)
@@ -87,6 +87,11 @@ Type ConcreteBroadcast(const TensorType& t1,
       oshape.push_back(s2);
     } else if (EqualConstInt(s2, 1)) {
       oshape.push_back(s1);
+    } else if (s1.as<Any>() && EqualConstInt(s2, 1)) {
+      // TODO(@jroesch): we need to come back to this
+      oshape.push_back(s2);
+    } else if (s2.as<Any>() && EqualConstInt(s1, 1)) {
+      oshape.push_back(s1);
     } else {
       RELAY_ERROR(
           "Incompatible broadcast type "
index 64f125a9050657581ea6c114991cb67d23743ea3..02f6cc3d485799eaa26e333ba896406874f34c9f 100644 (file)
@@ -313,17 +313,24 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
   Type VisitExpr_(const LetNode* let) final {
     // if the definition is a function literal, permit recursion
     bool is_functional_literal = let->value.as<FunctionNode>() != nullptr;
+    Type let_type = IncompleteTypeNode::make(Kind::kType);
+
     if (is_functional_literal) {
-      type_map_[let->var].checked_type = IncompleteTypeNode::make(Kind::kType);
+      let_type = GetType(let->var);
+      type_map_[let->var].checked_type = let_type;
     }
 
-    Type vtype = GetType(let->value);
+
     if (let->var->type_annotation.defined()) {
-      vtype = Unify(vtype, let->var->type_annotation, GetRef<Let>(let));
+      let_type = Unify(let_type, let->var->type_annotation, GetRef<Let>(let));
     }
+
+    Type vtype = GetType(let->value);
+    let_type = Unify(let_type, vtype, GetRef<Let>(let));
+
     CHECK(is_functional_literal || !type_map_.count(let->var));
     // NOTE: no scoping is necessary because var are unique in program
-    type_map_[let->var].checked_type = vtype;
+    type_map_[let->var].checked_type = let_type;
     return GetType(let->body);
   }
 
@@ -473,7 +480,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     }
 
     for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
-      this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]);
+      this->Unify(fn_ty->arg_types[i], arg_types[i], GetRef<Call>(call));
     }
 
     for (auto cs : fn_ty->type_constraints) {
@@ -556,6 +563,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     return FuncTypeNode::make(c->inputs, TypeCallNode::make(c->belong_to, types),
                               td->type_vars, {});
   }
+
+  void Solve() {
+    solver_.Solve();
+
+    if (err_reporter.AnyErrors()) {
+      err_reporter.RenderErrors(mod_);
+    }
+  }
 };
 
 class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
@@ -673,7 +688,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
         update_missing_type_annotation_ &&
         !new_var->type_annotation.defined());
 
-    bool need_update_fn = (
+    bool need_update_fn =(
         std::is_base_of<FunctionNode, T>::value &&
         update_missing_type_annotation_ &&
         !new_fn->ret_type.defined());
@@ -738,16 +753,13 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
 
 
 Expr TypeInferencer::Infer(Expr expr) {
-  // Step 0: Populate the constraints.
+  // Step 1: Populate the constraints.
   GetType(expr);
-  // Step 1: Solve the constraints.
-  solver_.Solve();
 
-  if (err_reporter.AnyErrors()) {
-    err_reporter.RenderErrors(mod_);
-  }
+  // Step 2: Solve the constraints.
+  Solve();
 
-  // Step 2: Attach resolved types to checked_type field.
+  // Step 3: Attach resolved types to checked_type field.
   auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
   CHECK(WellFormed(resolved_expr));
   return resolved_expr;
index 8289130f53d85306fec8f9522121d88cdd1384be..38870762d8408b83817e928925a7d94e608bacef 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -24,6 +24,8 @@
  */
 #include <string>
 #include <memory>
+#include <tuple>
+#include <utility>
 #include "type_solver.h"
 #include "../ir/type_functor.h"
 
@@ -90,7 +92,7 @@ class TypeSolver::OccursChecker : public TypeVisitor {
 
 class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
  public:
-  explicit Unifier(TypeSolver* solver) : solver_(solver) {}
+  explicit Unifier(TypeSolver* solver, const NodeRef& loc) : solver_(solver), loc(loc) {}
 
   Type Unify(const Type& src, const Type& dst) {
     // Known limitation
@@ -102,27 +104,34 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     if (lhs->FindRoot() == rhs->FindRoot()) {
       return lhs->resolved_type;
     }
+
     if (lhs->resolved_type.as<IncompleteTypeNode>()) {
-      CHECK(!CheckOccurs(lhs, rhs->resolved_type))
+      CHECK(!OccursCheck(lhs, rhs->resolved_type))
         << "Incomplete type " << lhs->resolved_type << " occurs in "
         << rhs->resolved_type << ", cannot unify";
+
       solver_->MergeFromTo(lhs, rhs);
       return rhs->resolved_type;
     } else if (rhs->resolved_type.as<IncompleteTypeNode>()) {
-      CHECK(!CheckOccurs(rhs, lhs->resolved_type))
+      CHECK(!OccursCheck(rhs, lhs->resolved_type))
         << "Incomplete type " << rhs->resolved_type << " occurs in "
         << lhs->resolved_type << ", cannot unify";
       solver_->MergeFromTo(rhs, lhs);
       return lhs->resolved_type;
     } else {
       Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type);
-      CHECK(resolved.defined())
-        << "Unable to unify parent types: "
-        << lhs->resolved_type << " and " << rhs->resolved_type;
-      TypeNode* top = solver_->GetTypeNode(resolved);
-      solver_->MergeFromTo(lhs, top);
-      solver_->MergeFromTo(rhs, top);
-      return resolved;
+      if (!resolved.defined()) {
+        solver_->ReportError(RELAY_ERROR("unable to unify: "
+                                         << "`" << PrettyPrint(lhs->resolved_type) << "` and `"
+                                         << PrettyPrint(rhs->resolved_type) << "`"),
+                             this->loc);
+        return lhs->resolved_type;
+      } else {
+        TypeNode* top = solver_->GetTypeNode(resolved);
+        solver_->MergeFromTo(lhs, top);
+        solver_->MergeFromTo(rhs, top);
+        return resolved;
+      }
     }
   }
 
@@ -130,7 +139,9 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
   // there is a recursive equality constraint, which should be rejected.
   // N.b.: A tautology like ?a = ?a is okay and should be checked for
   // *before* calling this method
-  bool CheckOccurs(TypeNode* lhs, const Type& t) {
+  //
+  // See: https://en.wikipedia.org/wiki/Occurs_check
+  bool OccursCheck(TypeNode* lhs, const Type& t) {
     OccursChecker rc(solver_, lhs);
     return rc.Check(t);
   }
@@ -145,6 +156,118 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     return t1;
   }
 
+  IndexExpr GetShape(const IndexExpr& e) {
+    IndexExpr ex = e;
+    while (true) {
+      auto it = solver_->shape_uf_.find(ex);
+      if (it == solver_->shape_uf_.end()) {
+        return ex;
+      } else {
+        ex = (*it).second;
+      }
+    }
+  }
+
+  IndexExpr UnifyDim(const IndexExpr& lhs, const IndexExpr& rhs) {
+    auto ulhs = GetShape(lhs);
+    auto urhs = GetShape(rhs);
+
+    if (ulhs.same_as(urhs)) {
+      return ulhs;
+    }
+    if (ulhs.as<Any>() || urhs.as<Any>()) {
+      return Any::make();
+    }
+
+    auto left_index0 = ulhs.as<tvm::Variable>();
+    auto right_index0 = urhs.as<tvm::IntImm>();
+    if (left_index0 && right_index0) {
+      solver_->shape_uf_.Set(ulhs, urhs);
+      return urhs;
+    }
+
+    auto left_index1 = ulhs.as<tvm::IntImm>();
+    auto right_index1 = urhs.as<tvm::Variable>();
+    if (left_index1 && right_index1) {
+      solver_->shape_uf_.Set(urhs, ulhs);
+      return ulhs;
+    }
+
+    auto left_index2 = ulhs.as<tvm::IntImm>();
+    auto right_index2 = urhs.as<tvm::IntImm>();
+    if (left_index2 && right_index2 && left_index2->value == right_index2->value) {
+      return ulhs;
+    }
+
+    return tvm::Expr();
+  }
+
+  Type VisitType_(const TensorTypeNode* op, const Type& tn) final {
+    const auto* tt_node = tn.as<TensorTypeNode>();
+    if (!tt_node) {
+      return Type(nullptr);
+    }
+
+    auto tt1 = GetRef<TensorType>(op);
+    auto tt2 = GetRef<TensorType>(tt_node);
+
+    if (AlphaEqual(tt1, tt2)) {
+      return std::move(tt1);
+    }
+
+    if (tt1->dtype != tt2->dtype) {
+      return Type(nullptr);
+    }
+
+    tvm::Array<IndexExpr> shape;
+    if (tt1->shape.size() != tt2->shape.size()) {
+      this->solver_->ReportError(
+        RELAY_ERROR(
+          "tensor type `" << PrettyPrint(tt1) <<
+          "` has " <<  tt1->shape.size() <<
+          " dimensions, while `" <<
+          PrettyPrint(tt2) <<
+          "` has " << tt2->shape.size() <<
+          " dimensions"), this->loc);
+      return Type(nullptr);
+    }
+
+    std::vector<std::tuple<size_t, IndexExpr, IndexExpr>> mismatches;
+
+    CHECK_EQ(tt1->shape.size(), tt2->shape.size());
+    for (size_t i = 0; i < tt1->shape.size(); i++) {
+      auto dim = UnifyDim(tt1->shape[i], tt2->shape[i]);
+      if (!dim.defined()) {
+        // NB: We push an arbitrary dimension here so we can continue error propogation.
+        shape.push_back(tt1->shape[i]);
+        tvm::Expr shape1 = tt1->shape[i];
+        tvm::Expr shape2 = tt2->shape[i];
+        std::tuple<int, IndexExpr, IndexExpr> tuple = std::make_tuple(i, shape1, shape2);
+        mismatches.push_back(tuple);
+      } else {
+        shape.push_back(dim);
+      }
+    }
+
+    if (mismatches.size() != 0) {
+      RelayErrorStream err;
+      err << "in particular ";
+      for (auto mismatch : mismatches) {
+        err << "dimension "
+            << std::get<0>(mismatch)
+            << " conflicts "
+            << std::get<1>(mismatch)
+            << " does not match "
+            << std::get<2>(mismatch);
+      }
+      Error error(err);
+      this->solver_->ReportError(error, this->loc);
+      return Type(nullptr);
+    }
+
+    return TensorTypeNode::make(shape, tt1->dtype);
+  }
+
   Type VisitType_(const TupleTypeNode* op, const Type& tn) final {
     const auto* ttn = tn.as<TupleTypeNode>();
     if (!ttn || op->fields.size() != ttn->fields.size()) {
@@ -225,6 +348,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
 
  private:
   TypeSolver* solver_;
+  NodeRef loc;
 };
 
 class TypeSolver::Resolver : public TypeMutator {
@@ -412,14 +536,14 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) {
 }
 
 // Add equality constraint
-Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) {
-  // NB(@jroesch): we should probably pass location into the unifier to do better
-  // error reporting as well.
-  Unifier unifier(this);
+Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef& loc) {
+  Unifier unifier(this, loc);
   return unifier.Unify(dst, src);
 }
 
 void TypeSolver::ReportError(const Error& err, const NodeRef& location)  {
+  CHECK(location.defined());
+  CHECK(current_func.defined());
   err_reporter_->ReportAt(current_func, location, err);
 }
 
@@ -460,7 +584,6 @@ Type TypeSolver::Resolve(const Type& type) {
 }
 
 bool TypeSolver::Solve() {
-  // Update until queue is empty.
   while (!update_queue_.empty()) {
     RelationNode* rnode = update_queue_.front();
     const auto& rel = rnode->rel;
@@ -474,7 +597,7 @@ bool TypeSolver::Solve() {
     }
 
     CHECK(rnode->location.defined())
-      << "undefined location, should be set when constructing relation node";
+        << "undefined location, should be set when constructing relation node";
 
     // We need to set this in order to understand where unification
     // errors generated by the error reporting are coming from.
@@ -494,11 +617,10 @@ bool TypeSolver::Solve() {
       rnode->resolved = false;
     } catch (const dmlc::Error& err) {
       rnode->resolved = false;
-      this->ReportError(
-          RELAY_ERROR(
-            "an internal invariant was violated while " \
-            "typechecking your program " <<
-            err.what()), rnode->location);
+      this->ReportError(RELAY_ERROR("an internal invariant was violated while "
+                                    "typechecking your program "
+                                    << err.what()),
+                        rnode->location);
     }
 
     // Mark inqueue as false after the function call
@@ -516,17 +638,21 @@ TVM_REGISTER_API("relay._analysis._test_type_solver")
 .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
     using runtime::PackedFunc;
     using runtime::TypedPackedFunc;
-    ErrorReporter err_reporter;
-    auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), &err_reporter);
+    ErrorReporter *err_reporter = new ErrorReporter();
+    auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), err_reporter);
 
-    auto mod = [solver](std::string name) -> PackedFunc {
+    auto mod = [solver, err_reporter](std::string name) -> PackedFunc {
       if (name == "Solve") {
         return TypedPackedFunc<bool()>([solver]() {
             return solver->Solve();
           });
       } else if (name == "Unify") {
-        return TypedPackedFunc<Type(Type, Type)>([solver](Type lhs, Type rhs) {
-            return solver->Unify(lhs, rhs, lhs);
+        return TypedPackedFunc<Type(Type, Type)>([solver, err_reporter](Type lhs, Type rhs) {
+            auto res = solver->Unify(lhs, rhs, lhs);
+            if (err_reporter->AnyErrors()) {
+              err_reporter->RenderErrors(ModuleNode::make({}, {}), true);
+            }
+            return res;
           });
       } else if (name == "Resolve") {
         return TypedPackedFunc<Type(Type)>([solver](Type t) {
index 002ccac356f02e24413a66022234840e28c1e34f..28579633c1c671a3fa3d38ac29740a5443ce6ff8 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -89,7 +89,6 @@ class TypeSolver {
    * \param location The location at which the unification problem arose.
    */
   Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location);
-
   /*!
    * \brief Report an error at the provided location.
    * \param err The error to report.
@@ -124,6 +123,7 @@ class TypeSolver {
     TypeNode* parent{nullptr};
     /*! \brief set of relations that is related to this type node */
     std::unordered_set<RelationNode*> rel_set;
+
     /*!
      * \brief Find the root type node, perform path compression
      * \return The root type node.
@@ -159,13 +159,15 @@ class TypeSolver {
     NodeRef location;
   };
 
+  /*! \brief A simple union find between shapes. */
+  tvm::Map<IndexExpr, IndexExpr> shape_uf_;
   /*! \brief List of all allocated type nodes */
   std::vector<TypeNode*> type_nodes_;
   /*! \brief List of all allocated relation nodes */
   std::vector<RelationNode*> rel_nodes_;
   /*! \brief Number of resolved relations */
   size_t num_resolved_rels_{0};
-  /*! \brief map from type node to types. */
+  /*! \brief map from types to type nodes. */
   std::unordered_map<Type, TypeNode*, NodeHash, NodeEqual> tmap_;
   /*! \brief Internal queue to update the relation */
   std::queue<RelationNode*> update_queue_;
@@ -205,6 +207,7 @@ class TypeSolver {
     rel->inqueue = true;
     update_queue_.push(rel);
   }
+
   /*!
    * \brief Merge rhs type node to lhs
    * \param src The source operand
index 39c17b8b3a81ce293fc8f3e945ef32f1b5f3a740..0877ead3b27dc5d710bc703bb1b9ad94052ba3fe 100644 (file)
@@ -184,6 +184,10 @@ void NDArray::CopyFromTo(DLTensor* from,
     from_size, from->ctx, to->ctx, from->dtype, stream);
 }
 
+std::vector<int64_t> NDArray::Shape() const {
+  return data_->shape_;
+}
+
 }  // namespace runtime
 }  // namespace tvm
 
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
new file mode 100644 (file)
index 0000000..8359fc5
--- /dev/null
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import relay
+from tvm.relay import Kind, transform
+from tvm.relay.loops import while_loop
+import numpy as np
+
+def infer_type(expr):
+    mod = relay.Module.from_expr(expr)
+    mod = transform.InferType()(mod)
+    entry = mod[mod.entry_func]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+def int32(val):
+    return relay.const(val, 'int32')
+
+def test_arange_with_dynamic_shape():
+    m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
+    x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32')
+    y0 = relay.shape_of(x)
+    y1 = relay.take(y0, relay.const(0, 'int32'))
+    y2 = relay.op.arange(y1)
+    ex = relay.create_executor()
+    f = relay.Function([x], y2, type_params=[m, n, k])
+    # TODO(@jroesch): Restore after code generation.
+    # data = np.random.rand(10, 5, 3).astype('float32')
+    # result = ex.evaluate(f)(data)
+    # np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
+
+def test_dynamic_concat():
+    """
+    fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) {
+        if (%i < 10) {
+            let %i = reshape(cast(i, "float32"), newshape=(1, ))
+            let %new_st = concatenate((st, i), axis=0)
+            concat_loop(%i + 1, )
+        } else {
+            st
+        }
+    }
+    """
+    # Initial Values.
+    i = relay.var('i', shape=(), dtype='int32')
+    st = relay.var('st', shape=(relay.Any(), 1), dtype='int32')
+
+    def _cond(i, st):
+        return relay.op.min(relay.op.less(i, int32(10)))
+
+    def _body(i, st):
+        i_vec = relay.op.reshape(i, (1,1))
+        ret = relay.op.concatenate([st, i_vec], axis=0)
+        return i + int32(1), ret
+
+    loop = while_loop(_cond, [i, st], _body)
+    start = relay.var('start', shape=(), dtype='int32')
+    body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
+    func = relay.Function([start], relay.TupleGetItem(body, 1))
+    func = infer_type(func)
+    # TODO(@jroesch, @haichen): We should restore this code when codegeneration
+    # is merged
+    # ret_shape = func.checked_type.ret_type.shape
+    # assert len(ret_shape) == 2, "expected 2-dim output"
+    # assert relay.ir_pass.alpha_eq(ret_shape[0], relay.Any())
+    # import pdb; pdb.set_trace()
+    # mod = relay.module.Module()
+    # print(relay.ir_pass.infer_type(func, mod=mod))
+    # ret = relay.Call(loop, [relay.const(0, 'int32'), init])
+    # mod[mod.entry_func] = relay.Function([], ret)
+    # print(relay.ir_pass.infer_type(mod[mod.entry_func], mod=mod))
+
+    # initial = np.array(0.0, dtype='float32').reshape((1,))
+    # iter_stop = np.array(10, dtype='int32')
+    # ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
+    # result = ex.evaluate(mod.entry_func)()
+    # np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
+
+def test_dynamic_concat_with_wrong_annotation():
+    """
+    v0.0.1
+    fn (%start: int32) {
+        %7 = {
+            let %while_loop = fn (%i: int32, %st: Tensor[(1, 1), int32]) {
+            %0 = less(%i, 10)
+            %1 = min(%0)
+            if (%1) {
+                %2 = add(%i, 1)
+                %3 = reshape(%i, newshape=[1, 1])
+                %4 = (%st, %3)
+                /* The result of concat should be 1,1 but it is 2, 1. */
+                %5 = concatenate(%4)
+                %while_loop(%2, %5)
+            } else {
+                (%i, %st)
+            }
+        }
+        %6 = reshape(0, newshape=[1, 1])
+        %while_loop(%start, %6)
+    }
+    %7.1
+    }
+    """
+    # Initial Values.
+    i = relay.var('i', shape=(), dtype='int32')
+    st = relay.var('st', shape=(1, 1), dtype='int32')
+
+    def _cond(i, st):
+        return relay.op.min(relay.op.less(i, int32(10)))
+
+    def _body(i, st):
+        i_vec = relay.op.reshape(i, (1,1))
+        ret = relay.op.concatenate([st, i_vec], axis=0)
+        return i + int32(1), ret
+
+    loop = while_loop(_cond, [i, st], _body)
+    start = relay.var('start', shape=(), dtype='int32')
+    body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
+    func = relay.Function([start], relay.TupleGetItem(body, 1))
+    try:
+        func = infer_type(func)
+        assert False
+    except Exception as e:
+        assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
+
+if __name__ == "__main__":
+    test_arange_with_dynamic_shape()
+    test_dynamic_concat()
+    test_dynamic_concat_with_wrong_annotation()
index e1a760421349ac270fb7b40ef5ea869728513f18..da3de2b22f741d16e07170765d0ffe51285bec33 100644 (file)
@@ -493,17 +493,20 @@ def test_arange():
     def verify_arange(start, stop, step):
         dtype = "float32"
         if start is None and step is None:
-            x = relay.arange(stop)
-            ref_res = np.arange(stop)
+            x = relay.arange(relay.const(stop, dtype=dtype))
+            ref_res = np.arange(stop).astype(dtype)
         elif start is None:
-            x = relay.arange(stop, step=step)
-            ref_res = np.arange(stop, step=step)
+            x = relay.arange(relay.const(stop, dtype=dtype), step=relay.const(step, dtype=dtype))
+            ref_res = np.arange(stop, step=step).astype(dtype)
         elif step is None:
-            x = relay.arange(start, stop)
-            ref_res = np.arange(start, stop)
+            x = relay.arange(relay.const(start, dtype=dtype), relay.const(stop, dtype=dtype))
+            ref_res = np.arange(start, stop).astype(dtype)
         else:
-            x = relay.arange(start, stop, step)
-            ref_res = np.arange(start, stop, step)
+            x = relay.arange(
+                relay.const(start, dtype=dtype),
+                relay.const(stop, dtype=dtype),
+                relay.const(step, dtype=dtype))
+            ref_res = np.arange(start, stop, step).astype(dtype)
 
         func = relay.Function([], x)
         for target, ctx in ctx_list():
@@ -515,11 +518,13 @@ def test_arange():
     verify_arange(None, 20, 2)
     verify_arange(1, 20, None)
     verify_arange(1, 20, 2)
-    verify_arange(1, 20, 1.5)
+    # arange doesnt' support floating point right now, see type relation
+    # verify_arange(1, 20, 1.5)
     verify_arange(1, 20.5, None)
     verify_arange(1, 20, 3)
     verify_arange(20, 1, -1)
-    verify_arange(20, 1, -1.5)
+    # arange doesnt' support floating point right now, see type relation
+    # verify_arange(20, 1, -1.5)
 
 def test_tile():
     def verify_tile(dshape, reps):
@@ -616,6 +621,7 @@ def test_gather_nd():
 
 
 if __name__ == "__main__":
+    test_arange()
     test_cast()
     test_zeros_ones()
     test_unary_identity()