[Relay][Op]Support symbolic TopK, Ones, Zeros and Full (#5459)
authorYao Wang <kevinthesunwy@gmail.com>
Tue, 26 May 2020 01:09:44 +0000 (18:09 -0700)
committerGitHub <noreply@github.com>
Tue, 26 May 2020 01:09:44 +0000 (18:09 -0700)
* Support symbolic TopK, Ones, Zeros and Full

* Fix pylint

* Add docstring for topk shape func

* Fix grad

* Fix lazy_gradient_init

* Fix parser

* Fix print ir text

* Fix lint

* Improve pattern_util

* Fix topk

* Fix build

* Use Optional for attribute

* Fix clang-format

* Minot fix

* Fix pylint

* Fix build warning

* Fix parser

* Move ToScalar

* Fix lint

* Fix lint

* Make topk shape func as data independent when k is constant.

* Fix lint

* Minor fix

22 files changed:
include/tvm/relay/attrs/algorithm.h
include/tvm/relay/attrs/transform.h
include/tvm/runtime/ndarray.h
python/tvm/relay/_parser.py
python/tvm/relay/op/_algorithm.py
python/tvm/relay/op/_tensor.py
python/tvm/relay/op/_tensor_grad.py
python/tvm/relay/op/_transform.py
python/tvm/relay/op/algorithm.py
python/tvm/relay/op/strategy/generic.py
python/tvm/relay/op/tensor.py
python/tvm/relay/op/transform.py
src/relay/analysis/util.cc
src/relay/op/algorithm/topk.cc
src/relay/op/image/resize.cc
src/relay/op/tensor/transform.cc
src/relay/op/tensor/transform.h
src/relay/qnn/util.cc
src/relay/transforms/lazy_gradient_init.cc
src/relay/transforms/pattern_util.h
tests/python/relay/test_any.py
topi/python/topi/sort.py

index a7d4708..83b4dda 100644 (file)
@@ -26,6 +26,7 @@
 
 #include <tvm/ir/attrs.h>
 #include <tvm/relay/base.h>
+#include <tvm/relay/expr.h>
 
 #include <string>
 
@@ -52,14 +53,14 @@ struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
 };
 
 struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
-  int k;
+  Optional<Integer> k;
   int axis;
   bool is_ascend;
   std::string ret_type;
   DataType dtype;
 
   TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") {
-    TVM_ATTR_FIELD(k).set_default(1).describe("Number of top elements to select");
+    TVM_ATTR_FIELD(k).describe("Number of top elements to select");
     TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort the input tensor.");
     TVM_ATTR_FIELD(ret_type).set_default("both").describe(
         "The return type [both, values, indices]."
index 7fb7f3a..ccf8e54 100644 (file)
@@ -111,7 +111,7 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
 
 /*! \brief Attributes that specify a tensor */
 struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
-  Array<IndexExpr> shape;
+  Optional<Array<Integer>> shape;
   DataType dtype;
 
   TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") {
index 0171d8a..e69d802 100644 (file)
@@ -462,7 +462,11 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
   int64_t data_byte_size;
   CHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format";
   CHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file format";
-  CHECK(strm->Read(ret->data, data_byte_size)) << "Invalid DLTensor file format";
+  auto read_ret = strm->Read(ret->data, data_byte_size);
+  // Only check non-empty data
+  if (ndim > 0 && shape[0] != 0) {
+    CHECK(read_ret) << "Invalid DLTensor file format";
+  }
   if (!DMLC_IO_NO_ENDIAN_SWAP) {
     dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
   }
index 1d97b55..49f2d4d 100644 (file)
@@ -116,6 +116,8 @@ class FuncOp(OpWrapper):
             attrs = {}
         if self.operator is op.reshape:
             x = self.operator(*args)
+        elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to):
+            x = self.operator(*args, dtype=attrs["dtype"])
         else:
             x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
         if isinstance(x, expr.TupleWrapper):
index e1e6fd3..5a20480 100644 (file)
 # pylint: disable=invalid-name,unused-argument
 from __future__ import absolute_import
 
+from tvm.te.hybrid import script
+from tvm.runtime import convert
+
 from . import strategy
+from . import op as _reg
 from .op import OpPattern, register_pattern
 from .op import register_strategy
 
@@ -29,3 +33,67 @@ register_pattern("argsort", OpPattern.OPAQUE)
 # topk
 register_strategy("topk", strategy.topk_strategy)
 register_pattern("topk", OpPattern.OPAQUE)
+
+@script
+def _topk_shape_func_input_data(data, k, axis):
+    ndim = len(data.shape)
+    val_out = output_tensor((ndim,), "int64")
+    indices_out = output_tensor((ndim,), "int64")
+
+    for i in const_range(ndim):
+        if i != axis:
+            val_out[i] = int64(data.shape[i])
+            indices_out[i] = int64(data.shape[i])
+        else:
+            if k[0] < 1:
+                val_out[i] = int64(data.shape[i])
+                indices_out[i] = int64(data.shape[i])
+            else:
+                val_out[i] = int64(k[0])
+                indices_out[i] = int64(k[0])
+    return val_out, indices_out
+
+@script
+def _topk_shape_func_input_shape(data_shape, k, axis):
+    ndim = data_shape.shape[0]
+    val_out = output_tensor((ndim,), "int64")
+    indices_out = output_tensor((ndim,), "int64")
+
+    for i in const_range(ndim):
+        if i != axis:
+            val_out[i] = int64(data_shape[i])
+            indices_out[i] = int64(data_shape[i])
+        else:
+            if k < 1:
+                val_out[i] = int64(data_shape[i])
+                indices_out[i] = int64(data_shape[i])
+            else:
+                val_out[i] = int64(k)
+                indices_out[i] = int64(k)
+    return val_out, indices_out
+
+@_reg.register_shape_func("topk", True)
+def topk_shape_func(attrs, inputs, _):
+    """
+    Shape func for topk.
+    """
+    axis = attrs.axis
+    if attrs.k is not None:
+        if axis < 0:
+            axis += inputs[0].shape[0]
+        val_out, indices_out = \
+            _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
+    else:
+        if axis < 0:
+            axis += len(inputs[0].shape)
+        val_out, indices_out = \
+            _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))
+    ret_type = attrs.ret_type
+    if ret_type == "both":
+        ret = [val_out, indices_out]
+    elif ret_type == "values":
+        ret = [val_out]
+    else:
+        ret = [indices_out]
+
+    return ret
index e029e0c..cd9e4ed 100644 (file)
 #pylint: disable=invalid-name, unused-argument, len-as-condition
 """Backend compiler related feature registration"""
 
-from tvm.runtime import convert
 from tvm.te.hybrid import script
 import topi
-from topi.util import get_const_tuple
+
 from .op import register_compute, register_shape_func
 from .op import register_broadcast_schedule, register_injective_schedule
 from .op import register_pattern, OpPattern
@@ -93,7 +92,7 @@ register_broadcast_schedule("fast_erf")
 # zeros
 @register_compute("zeros")
 def zeros_compute(attrs, inputs, output_type):
-    assert not inputs
+    assert len(inputs) == 1
     return [topi.full(output_type.shape, output_type.dtype, 0.0)]
 
 register_broadcast_schedule("zeros")
@@ -110,7 +109,7 @@ register_broadcast_schedule("zeros_like")
 # ones
 @register_compute("ones")
 def ones_compute(attrs, inputs, output_type):
-    assert not inputs
+    assert len(inputs) == 1
     return [topi.full(output_type.shape, output_type.dtype, 1.0)]
 
 register_broadcast_schedule("ones")
@@ -132,20 +131,10 @@ def clip_compute(attrs, inputs, output_type):
 
 register_injective_schedule("clip")
 
-@script
-def _cast_shape_function(x):
-    out_ndim = len(x)
-    out = output_tensor((out_ndim,), "int64")
-    for i in const_range(out_ndim):
-        out[i] = x[i]
-    return out
-
-def cast_shape_func(attrs, inputs, out_ndims):
-    return [_cast_shape_function(*inputs)]
-
+# full
 @script
 def _full_shape_func(shape):
-    out_ndim = len(shape)
+    out_ndim = shape.shape[0]
     out = output_tensor((out_ndim,), "int64")
     for i in const_range(out_ndim):
         out[i] = int64(shape[i])
@@ -153,10 +142,15 @@ def _full_shape_func(shape):
 
 def full_shape_func(attrs, inputs, out_ndims):
     """
-    Shape func for zeros, zeros_like, ones, ones_like.
+    Shape func for full.
+    """
+    return [_full_shape_func(inputs[1])]
+
+def no_data_full_shape_func(attrs, inputs, out_ndims):
+    """
+    Shape func for zeros and ones.
     """
-    shape = get_const_tuple(attrs.shape)
-    return [_full_shape_func(convert(shape))]
+    return [_full_shape_func(inputs[0])]
 
 @script
 def _broadcast_shape_func(x, y, ndim):
@@ -198,13 +192,14 @@ def elemwise_shape_func(attrs, inputs, _):
     """
     return [topi.math.identity(inputs[0])]
 
-register_shape_func("cast", False, cast_shape_func)
-register_shape_func("zeros", False, full_shape_func)
+register_shape_func("cast", False, elemwise_shape_func)
+register_shape_func("zeros", True, no_data_full_shape_func)
 register_shape_func("zeros_like", False, elemwise_shape_func)
-register_shape_func("ones", False, full_shape_func)
+register_shape_func("ones", True, no_data_full_shape_func)
 register_shape_func("ones_like", False, elemwise_shape_func)
-register_shape_func("full", False, full_shape_func)
+register_shape_func("full", True, full_shape_func)
 register_shape_func("full_like", False, elemwise_shape_func)
+register_shape_func("broadcast_to", True, full_shape_func)
 
 register_shape_func("add", False, broadcast_shape_func)
 register_shape_func("subtract", False, broadcast_shape_func)
index 8be3358..8ba1020 100644 (file)
@@ -232,14 +232,14 @@ def divide_grad(orig, grad):
 
 @register_gradient("zeros")
 def zeros_grad(orig, grad):
-    """Returns []"""
-    return []
+    """Returns [shape]"""
+    return [orig.args[0]]
 
 
 @register_gradient("ones")
 def ones_grad(orig, grad):
-    """Returns []"""
-    return []
+    """Returns [shape]"""
+    return [orig.args[0]]
 
 
 @register_gradient("zeros_like")
index 43d8d62..e1c2bd7 100644 (file)
@@ -120,6 +120,8 @@ def _concatenate_shape_func(inputs, axis):
 @_reg.register_shape_func("concatenate", False)
 def concatenate_shape_func(attrs, inputs, _):
     axis = get_const_int(attrs.axis)
+    if axis < 0:
+        axis += inputs[0].shape[0]
     return [_concatenate_shape_func(inputs, convert(axis))]
 
 @script
index 17fab80..d31e89a 100644 (file)
@@ -17,7 +17,7 @@
 """Classic algorithm operation"""
 from __future__ import absolute_import as _abs
 from . import _make
-from ..expr import TupleWrapper
+from ..expr import TupleWrapper, const
 
 def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
     """Performs sorting along the given axis and returns an array of indicies
@@ -48,7 +48,8 @@ def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
     return _make.argsort(data, axis, is_ascend, dtype)
 
 
-def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
+def topk(data, k=1, axis=-1, ret_type="both",
+         is_ascend=False, dtype="int32"):
     """Get the top k elements in an input tensor along the given axis.
 
     ret_type specifies the return type, can be one of ("both", "values", "indices").
@@ -58,7 +59,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
     data : relay.Expr
         The input data tensor.
 
-    k : int, optional
+    k : int or relay.Expr, optional
         Number of top elements to select. Return all elements if k < 1.
 
     axis : int, optional
@@ -81,6 +82,8 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
     out : relay.Expr or List[relay.Expr]
         The computed result.
     """
+    if isinstance(k, int):
+        k = const(k, "int64")
     out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
     if ret_type == "both":
         return TupleWrapper(out, 2)
index 6db5b14..99439af 100644 (file)
@@ -598,7 +598,9 @@ def argsort_strategy(attrs, inputs, out_type, target):
 def wrap_compute_topk(topi_compute):
     """Wrap topk compute"""
     def _compute_topk(attrs, inputs, out_type):
-        k = get_const_int(attrs.k)
+        k = inputs[1]
+        if attrs.k is not None:
+            k = attrs.k
         axis = get_const_int(attrs.axis)
         ret_type = attrs.ret_type
         is_ascend = bool(get_const_int(attrs.is_ascend))
index d5ae5cd..c60dbee 100644 (file)
@@ -20,7 +20,7 @@ from tvm.runtime import ndarray as _nd
 from tvm.runtime import TVMContext as _TVMContext
 
 from . import _make
-from ..expr import Tuple
+from ..expr import Tuple, const
 
 
 # We create a wrapper function for each operator in the
@@ -928,7 +928,7 @@ def zeros(shape, dtype):
 
     Parameters
     ----------
-    shape : tuple of int
+    shape : tuple of int or relay.Expr
         The shape of the target.
 
     dtype : data type
@@ -939,6 +939,8 @@ def zeros(shape, dtype):
     result : relay.Expr
         The resulting tensor.
     """
+    if isinstance(shape, (list, tuple)):
+        shape = const(list(shape), "int32")
     return _make.zeros(shape, dtype)
 
 
@@ -963,7 +965,7 @@ def ones(shape, dtype):
 
     Parameters
     ----------
-    shape : tuple of int
+    shape : tuple of int or relay.Expr
         The shape of the target.
 
     dtype : data type
@@ -974,6 +976,8 @@ def ones(shape, dtype):
     result : relay.Expr
         The resulting tensor.
     """
+    if isinstance(shape, (list, tuple)):
+        shape = const(list(shape), "int32")
     return _make.ones(shape, dtype)
 
 
index 2d9e4ba..1da58ae 100644 (file)
@@ -299,7 +299,7 @@ def full(fill_value, shape=(), dtype=""):
     fill_value : relay.Expr
         The value to fill. Must be a scalar.
 
-    shape : tuple of int
+    shape : tuple of int or relay.Expr
         The shape of the target.
 
     dtype : data type, optional (defaults to data type of the fill value)
@@ -310,6 +310,8 @@ def full(fill_value, shape=(), dtype=""):
     result : relay.Expr
         The resulting tensor.
     """
+    if isinstance(shape, (list, tuple)):
+        shape = const(list(shape), "int32")
     return _make.full(fill_value, shape, dtype)
 
 
@@ -527,7 +529,7 @@ def broadcast_to(data, shape):
     data : relay.Expr
         The input tensor.
 
-    shape : shape
+    shape : tuple of int or relay.Expr
         Provide the shape to broadcast to.
 
     Returns
@@ -535,6 +537,8 @@ def broadcast_to(data, shape):
     result : relay.Expr
         The resulting tensor.
     """
+    if isinstance(shape, (list, tuple)):
+        shape = const(list(shape), "int32")
     return _make.broadcast_to(data, shape)
 
 def broadcast_to_like(data, broadcast_type):
index a05bb8f..2853165 100644 (file)
@@ -25,6 +25,7 @@
  */
 #include <tvm/ir/type_functor.h>
 #include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/algorithm.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
@@ -450,6 +451,13 @@ bool IsDataDependant(const CallNode* call) {
         return false;
       }
     }
+  } else if (op->name == "topk") {
+    if (const auto* attrs = call->attrs.as<TopKAttrs>()) {
+      if (attrs->k) {
+        // If k attribute exists, it isn't data dependant.
+        return false;
+      }
+    }
   }
 
   return tshape_data_dependant[op];
index 5ff5904..3db8eee 100644 (file)
  */
 #include <tvm/relay/attrs/algorithm.h>
 #include <tvm/relay/op.h>
+#include <tvm/tir/op.h>
 
 namespace tvm {
 namespace relay {
+using tir::make_const;
 
 TVM_REGISTER_NODE_TYPE(TopKAttrs);
 
@@ -33,7 +35,7 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
              const TypeReporter& reporter) {
   // `types` contains: [data, result]
   const TopKAttrs* param = attrs.as<TopKAttrs>();
-  CHECK_EQ(types.size(), 2);
+  CHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
   CHECK(data);
   int ndim = data->shape.size();
@@ -44,35 +46,44 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK(axis >= 0 && axis < ndim);
   Array<IndexExpr> out_shape;
   for (int i = 0; i < ndim; ++i) {
-    if (i != axis || param->k < 1) {
+    if (i != axis) {
       out_shape.push_back(data->shape[i]);
+    } else if (param->k) {
+      const Integer& ck = param->k.value();
+      if (ck->value < 1) {
+        out_shape.push_back(data->shape[i]);
+      } else {
+        out_shape.push_back(ck);
+      }
     } else {
-      out_shape.push_back(param->k);
+      out_shape.push_back(Any::make());
     }
   }
   auto values_ty = TensorType(out_shape, data->dtype);
   auto indices_ty = TensorType(out_shape, param->dtype);
   if (param->ret_type == "both") {
-    reporter->Assign(types[1], TupleType({values_ty, indices_ty}));
+    reporter->Assign(types[2], TupleType({values_ty, indices_ty}));
   } else if (param->ret_type == "values") {
-    reporter->Assign(types[1], values_ty);
+    reporter->Assign(types[2], values_ty);
   } else if (param->ret_type == "indices") {
-    reporter->Assign(types[1], indices_ty);
+    reporter->Assign(types[2], indices_ty);
   } else {
     LOG(FATAL) << "Unsupported ret type: " << param->ret_type;
   }
   return true;
 }
 
-Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype) {
+Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend, DataType dtype) {
   auto attrs = make_object<TopKAttrs>();
-  attrs->k = k;
+  if (const auto& ck = k.as<ConstantNode>()) {
+    attrs->k = tvm::Integer(reinterpret_cast<int*>(ck->data->data)[0]);
+  }
   attrs->axis = axis;
   attrs->ret_type = ret_type;
   attrs->is_ascend = is_ascend;
   attrs->dtype = dtype;
   static const Op& op = Op::Get("topk");
-  return Call(op, {data}, Attrs(attrs), {});
+  return Call(op, {data, k}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);
@@ -80,9 +91,10 @@ TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);
 RELAY_REGISTER_OP("topk")
     .describe(R"doc(Get the top k elements in an input tensor along the given axis.
 )doc" TVM_ADD_FILELINE)
-    .set_num_inputs(1)
+    .set_num_inputs(2)
     .set_attrs_type<TopKAttrs>()
     .add_argument("data", "Tensor", "Input data.")
+    .add_argument("k", "Tensor", "Number of top elements.")
     .set_support_level(6)
     .add_type_rel("TopK", TopKRel);
 
index 7bddb29..b6d2c71 100644 (file)
@@ -194,12 +194,12 @@ bool CropAndResizeRel(const Array<Type>& types, int num_inputs, const Attrs& att
   const Layout in_layout(param->layout);
   auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
   auto oshape = layout_converter.ForwardShape(data->shape);
-  oshape.Set(0, box_indices->shape[0]);
+  oshape.Set(0, boxes->shape[0]);
   oshape.Set(2, crop_size[0]);
   oshape.Set(3, crop_size[1]);
   auto bshape = layout_converter.BackwardShape(oshape);
   // assign output type
-  reporter->Assign(types[3], TensorType(layout_converter.BackwardShape(oshape), out_dtype));
+  reporter->Assign(types[3], TensorType(bshape, out_dtype));
   return true;
 }
 
index 6ccf585..7282ac7 100644 (file)
@@ -447,44 +447,6 @@ RELAY_REGISTER_OP("transpose")
 /* relay.reshape */
 TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
 
-double ToScalar(const runtime::NDArray& array, int i = 0) {
-  if (array->dtype.code == kDLInt) {
-    if (array->dtype.bits == 8) {
-      return reinterpret_cast<int8_t*>(array->data)[i];
-    } else if (array->dtype.bits == 16) {
-      return reinterpret_cast<int16_t*>(array->data)[i];
-    } else if (array->dtype.bits == 32) {
-      return reinterpret_cast<int32_t*>(array->data)[i];
-    } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<int64_t*>(array->data)[i];
-    }
-  } else if (array->dtype.code == kDLUInt) {
-    if (array->dtype.bits == 8) {
-      return reinterpret_cast<uint8_t*>(array->data)[i];
-    } else if (array->dtype.bits == 16) {
-      return reinterpret_cast<uint16_t*>(array->data)[i];
-    } else if (array->dtype.bits == 32) {
-      return reinterpret_cast<uint32_t*>(array->data)[i];
-    } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<uint64_t*>(array->data)[i];
-    }
-  } else if (array->dtype.code == kDLFloat) {
-#if (__ARM_FP16_FORMAT_IEEE == 1)
-    if (array->dtype.bits == 16) {
-      return reinterpret_cast<__fp16*>(array->data)[i];
-    }
-#endif
-    if (array->dtype.bits == 32) {
-      return reinterpret_cast<float*>(array->data)[i];
-    } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<double*>(array->data)[i];
-    }
-  }
-  LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
-  // make compiler happy
-  return -std::numeric_limits<double>::infinity();
-}
-
 bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
   const auto* param = attrs.as<ReshapeAttrs>();
@@ -663,11 +625,7 @@ Expr MakeReshape(Expr data, Expr newshape) {
   auto attrs = make_object<ReshapeAttrs>();
   if (const ConstantNode* c = newshape.as<ConstantNode>()) {
     CHECK_EQ(c->data->ndim, 1);
-    Array<Integer> newshape;
-    for (int i = 0; i < c->data->shape[0]; i++) {
-      newshape.push_back(Integer(static_cast<int>(ToScalar(c->data, i))));
-    }
-    attrs->newshape = newshape;
+    attrs->newshape = ToVector(c->data);
   }
   attrs->reverse = false;
   static const Op& op = Op::Get("reshape");
@@ -929,9 +887,10 @@ TVM_REGISTER_NODE_TYPE(InitOpAttrs);
 
 bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
              const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
+  CHECK_EQ(types.size(), 3);
   const InitOpAttrs* param = attrs.as<InitOpAttrs>();
   const auto* fill_value = types[0].as<TensorTypeNode>();
+  const auto* fill_shape = types[1].as<TensorTypeNode>();
   if (fill_value == nullptr) {
     return false;
   }
@@ -944,7 +903,21 @@ bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK_EQ(fill_value->shape.size(), 0)
       << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << ".";
 
-  reporter->Assign(types[1], TensorType(param->shape, out_dtype));
+  const IntImmNode* shape_shape = fill_shape->shape[0].as<IntImmNode>();
+  CHECK(shape_shape) << "Parameter shape must have static shape";
+
+  std::vector<IndexExpr> oshape;
+  if (param->shape) {
+    const Array<Integer>& cshape_array = param->shape.value();
+    for (size_t i = 0; i < cshape_array.size(); ++i) {
+      oshape.push_back(cshape_array[i]);
+    }
+  } else {
+    for (int i = 0; i < shape_shape->value; ++i) {
+      oshape.push_back(Any::make());
+    }
+  }
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
   return true;
 }
 
@@ -954,12 +927,14 @@ Array<te::Tensor> FullCompute(const Attrs& attrs, const Array<te::Tensor>& input
   return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())};
 }
 
-Expr MakeFull(Expr fill_value, Array<IndexExpr> shape, DataType dtype) {
+Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) {
   auto attrs = make_object<InitOpAttrs>();
-  attrs->shape = std::move(shape);
+  if (const auto* cshape = shape.as<ConstantNode>()) {
+    attrs->shape = ToVector(cshape->data);
+  }
   attrs->dtype = std::move(dtype);
   static const Op& op = Op::Get("full");
-  return Call(op, {fill_value}, Attrs(attrs), {});
+  return Call(op, {fill_value, shape}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.full").set_body_typed(MakeFull);
@@ -969,8 +944,9 @@ RELAY_REGISTER_OP("full")
 
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<InitOpAttrs>()
-    .set_num_inputs(1)
+    .set_num_inputs(2)
     .add_argument("fill_value", "double", "The value to fill.")
+    .add_argument("shape", "Tensor", "Target shape.")
     .set_support_level(3)
     .add_type_rel("Full", FullRel)
     .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
@@ -978,19 +954,37 @@ RELAY_REGISTER_OP("full")
 
 bool InitOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 1);
+  CHECK_EQ(types.size(), 2);
   const InitOpAttrs* param = attrs.as<InitOpAttrs>();
+  const auto* fill_shape = types[0].as<TensorTypeNode>();
+  DataType out_dtype = param->dtype;
+
+  const IntImmNode* shape_shape = fill_shape->shape[0].as<IntImmNode>();
+  CHECK(shape_shape) << "Parameter shape must have static shape";
 
-  reporter->Assign(types[0], TensorType(param->shape, param->dtype));
+  std::vector<IndexExpr> oshape;
+  if (param->shape) {
+    const Array<Integer>& cshape_array = param->shape.value();
+    for (size_t i = 0; i < cshape_array.size(); ++i) {
+      oshape.push_back(cshape_array[i]);
+    }
+  } else {
+    for (int i = 0; i < shape_shape->value; ++i) {
+      oshape.push_back(Any::make());
+    }
+  }
+  reporter->Assign(types[1], TensorType(oshape, out_dtype));
   return true;
 }
 
-Expr MakeZeros(Array<IndexExpr> shape, DataType dtype) {
+Expr MakeZeros(Expr shape, DataType dtype) {
   auto attrs = make_object<InitOpAttrs>();
-  attrs->shape = std::move(shape);
+  if (const auto* cshape = shape.as<ConstantNode>()) {
+    attrs->shape = ToVector(cshape->data);
+  }
   attrs->dtype = std::move(dtype);
   static const Op& op = Op::Get("zeros");
-  return Call(op, {}, Attrs(attrs), {});
+  return Call(op, {shape}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.zeros").set_body_typed(MakeZeros);
@@ -1000,16 +994,19 @@ RELAY_REGISTER_OP("zeros")
 
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<InitOpAttrs>()
-    .set_num_inputs(0)
+    .set_num_inputs(1)
+    .add_argument("shape", "Tensor", "Target shape.")
     .set_support_level(3)
     .add_type_rel("InitOp", InitOpRel);
 
-Expr MakeOnes(Array<IndexExpr> shape, DataType dtype) {
+Expr MakeOnes(Expr shape, DataType dtype) {
   auto attrs = make_object<InitOpAttrs>();
-  attrs->shape = std::move(shape);
+  if (const auto* cshape = shape.as<ConstantNode>()) {
+    attrs->shape = ToVector(cshape->data);
+  }
   attrs->dtype = std::move(dtype);
   static const Op& op = Op::Get("ones");
-  return Call(op, {}, Attrs(attrs), {});
+  return Call(op, {shape}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.ones").set_body_typed(MakeOnes);
@@ -1019,7 +1016,8 @@ RELAY_REGISTER_OP("ones")
 
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<InitOpAttrs>()
-    .set_num_inputs(0)
+    .set_num_inputs(1)
+    .add_argument("shape", "Tensor", "Target shape.")
     .set_support_level(3)
     .add_type_rel("InitOp", InitOpRel);
 
@@ -1579,30 +1577,42 @@ RELAY_REGISTER_OP("collapse_sum_like")
 // BroadCastTo: <A, B> -> B where BroadCast(A, B) = B
 bool BroadCastToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
-  auto ioattrs = attrs.as<InitOpAttrs>();
-  CHECK(ioattrs);
-  auto intt = types[0].as<TensorTypeNode>();
-  if (intt == nullptr) {
-    return false;
+  CHECK_EQ(types.size(), 3);
+  const InitOpAttrs* param = attrs.as<InitOpAttrs>();
+  const auto* target_shape = types[1].as<TensorTypeNode>();
+  DataType out_dtype = types[0].as<TensorTypeNode>()->dtype;
+
+  const IntImmNode* shape_shape = target_shape->shape[0].as<IntImmNode>();
+  CHECK(shape_shape) << "Parameter shape must have static shape";
+
+  std::vector<IndexExpr> oshape;
+  if (param->shape) {
+    const Array<Integer>& cshape_array = param->shape.value();
+    for (size_t i = 0; i < cshape_array.size(); ++i) {
+      oshape.push_back(cshape_array[i]);
+    }
+  } else {
+    for (int i = 0; i < shape_shape->value; ++i) {
+      oshape.push_back(Any::make());
+    }
   }
-  auto type = TensorType(ioattrs->shape, intt->dtype);
-  reporter->Assign(types[1], type);
-  return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return BroadcastRel({types[0], types[2], types[2]}, 2, Attrs(), reporter);
 }
 
-Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape) {
+Expr MakeBroadCastTo(Expr data, Expr shape) {
   static const Op& op = Op::Get("broadcast_to");
   auto attrs = make_object<InitOpAttrs>();
-  attrs->shape = std::move(shape);
-  return Call(op, {data}, Attrs(attrs), {});
+  if (const auto* cshape = shape.as<ConstantNode>()) {
+    attrs->shape = ToVector(cshape->data);
+  }
+  return Call(op, {data, shape}, Attrs(attrs), {});
 }
 
 Array<te::Tensor> BroadCastToCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                      const Type& out_type) {
-  auto ioattrs = attrs.as<InitOpAttrs>();
-  CHECK(ioattrs != nullptr);
-  return {topi::broadcast_to(inputs[0], ioattrs->shape)};
+  const auto* out_ttype = out_type.as<TensorTypeNode>();
+  return {topi::broadcast_to(inputs[0], out_ttype->shape)};
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to").set_body_typed(MakeBroadCastTo);
@@ -1610,8 +1620,9 @@ TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to").set_body_typed(MakeBroadCastT
 RELAY_REGISTER_OP("broadcast_to")
     .describe(R"code(Broadcast the first input to match the shape argument.
 )code" TVM_ADD_FILELINE)
-    .set_num_inputs(1)
+    .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("shape", "Tensor", "Target shape.")
     .set_support_level(4)
     .add_type_rel("BroadCastTo", BroadCastToRel)
     .set_attr<FTVMCompute>("FTVMCompute", BroadCastToCompute)
index bc35ed6..1f30b68 100644 (file)
@@ -90,34 +90,33 @@ bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
     if (e_dtype != dtype) {
       throw Error("relay.concatenate requires all tensors have the same dtype");
     }
-    for (size_t j = 0; j < first->shape.size(); ++j) {
-      if (j == static_cast<size_t>(axis)) continue;
-      if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue;
-      throw Error(
-          "relay.concatenate requires all tensors have the same shape "
-          "on non-concatenating axes");
-    }
   }
 
   // Calculate shape
   std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
-  IndexExpr& concat_dim = oshape[axis];
-  bool has_any = false;
-  if (concat_dim.as<Any>()) {
-    has_any = true;
-  } else {
-    for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
-      const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
-      if (e->shape[axis].as<Any>()) {
-        has_any = true;
-        break;
+  int data_length = static_cast<int>(tensor_tuple->fields.size());
+  for (int i = 0; i < ndim; ++i) {
+    std::vector<IndexExpr> non_any;
+    for (int j = 0; j < data_length; ++j) {
+      const auto& e = Downcast<TensorType>(tensor_tuple->fields[j]);
+      if (!e->shape[i].as<Any>()) {
+        non_any.push_back(e->shape[i]);
+        // accumulate axis dimension
+        if (j > 0 && i == axis && !oshape[i].as<Any>()) {
+          oshape[i] += e->shape[i];
+        }
+      }
+    }
+    int non_any_size = static_cast<int>(non_any.size());
+    if (non_any_size != data_length) oshape[i] = Any::make();
+    if (i != axis) {
+      for (int k = 1; k < non_any_size; k++) {
+        if (reporter->AssertEQ(non_any[0], non_any[k])) continue;
+        throw Error(
+            "relay.concatenate requires all tensors have the same shape "
+            "on non-concatenating axes");
       }
-      concat_dim += e->shape[axis];
     }
-  }
-
-  if (has_any) {
-    concat_dim = Any::make();
   }
 
   auto rtype = TensorType(oshape, dtype);
index 7171ded..4daa5c9 100644 (file)
@@ -202,8 +202,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
     round_scalar = exp_pos_rounding_value_expr;
   } else if (rounding == "TONEAREST") {
     // To satisfy where op shape requirements, the rounding values are broadcasted.
-    auto pos_rounder = MakeBroadCastTo(exp_pos_rounding_value_expr, input_shape);
-    auto neg_rounder = MakeBroadCastTo(exp_neg_rounding_value_expr, input_shape);
+    auto pos_rounder = BroadCastTo(exp_pos_rounding_value_expr, input_shape);
+    auto neg_rounder = BroadCastTo(exp_neg_rounding_value_expr, input_shape);
 
     auto zero_t = Zeros(input_shape, hp_dtype);
     round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder, neg_rounder);
index 3cd29d6..f062466 100644 (file)
@@ -203,9 +203,9 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator {
       }
 
       if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) {
-        // fn() -> T, function returns result of the operation
-        Expr func =
-            Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {});
+        // ones and zeros need TensorType input
+        Expr result = CallPrimitiveOp(call_node);
+        Expr func = Function({}, result, {call_node->checked_type()}, Array<TypeVar>());
         // call appropriate GradCell constructor
         std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero";
         return Call(module_->GetConstructor("GradCell", constructor_name), {func}, Attrs(),
@@ -288,7 +288,7 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator {
       args.push_back(Call(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
     }
     // result of operation
-    return Call(call_node->op, args);
+    return Call(call_node->op, args, call_node->attrs);
   }
 };
 
index 8f37e7c..06b1e82 100644 (file)
@@ -37,6 +37,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/tir/data_layout.h>
 
+#include <limits>
 #include <string>
 #include <utility>
 #include <vector>
@@ -311,6 +312,25 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
 }
 
 /*!
+ * \brief Check whether a shape is static and create corresponding Constant.
+ *
+ * \param shape The Array of the shape values.
+ * \return A Constant.
+ */
+static inline Constant CheckConstantShape(const Array<IndexExpr>& shape) {
+  auto shape_array =
+      runtime::NDArray::Empty({int64_t(shape.size())}, DataType::Int(64), {kDLCPU, 0});
+  auto* shape_data = static_cast<int64_t*>(shape_array->data);
+  for (size_t i = 0; i < shape.size(); ++i) {
+    const auto& dim_val = shape[i].as<IntImmNode>();
+    CHECK(dim_val) << "Do not support symbolic shape for "
+                      "Array format. Pass shape as Expr instead.";
+    shape_data[i] = dim_val->value;
+  }
+  return Constant(shape_array);
+}
+
+/*!
  * \brief Check if two expressions are equal scalars.
  * \param a The expression to be checked.
  * \param b The expression to be checked
@@ -325,6 +345,67 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
   return tvm::StructuralEqual()(a, b);
 }
 
+/*!
+ * \brief Convert an element of a NDArray with type int or float to scalar.
+ * \param array Input NDArray
+ * \param i element index
+ * \return Converted scalar value.
+ */
+static inline double ToScalar(const runtime::NDArray& array, size_t i = 0) {
+  if (array->dtype.code == kDLInt) {
+    if (array->dtype.bits == 8) {
+      return reinterpret_cast<int8_t*>(array->data)[i];
+    } else if (array->dtype.bits == 16) {
+      return reinterpret_cast<int16_t*>(array->data)[i];
+    } else if (array->dtype.bits == 32) {
+      return reinterpret_cast<int32_t*>(array->data)[i];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<int64_t*>(array->data)[i];
+    }
+  } else if (array->dtype.code == kDLUInt) {
+    if (array->dtype.bits == 8) {
+      return reinterpret_cast<uint8_t*>(array->data)[i];
+    } else if (array->dtype.bits == 16) {
+      return reinterpret_cast<uint16_t*>(array->data)[i];
+    } else if (array->dtype.bits == 32) {
+      return reinterpret_cast<uint32_t*>(array->data)[i];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<uint64_t*>(array->data)[i];
+    }
+  } else if (array->dtype.code == kDLFloat) {
+#if (__ARM_FP16_FORMAT_IEEE == 1)
+    if (array->dtype.bits == 16) {
+      return reinterpret_cast<__fp16*>(array->data)[i];
+    }
+#endif
+    if (array->dtype.bits == 32) {
+      return reinterpret_cast<float*>(array->data)[i];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<double*>(array->data)[i];
+    }
+  }
+  LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
+  // make compiler happy
+  return -std::numeric_limits<double>::infinity();
+}
+
+/*!
+ * \brief Convert a NDArray with type int or float to Array<Integer>.
+ * \param array Input NDArray
+ * \return Converted Array.
+ */
+static inline Array<Integer> ToVector(const runtime::NDArray& array) {
+  size_t ndim = array.Shape().size();
+  CHECK_EQ(ndim, 1) << "This function should only used for shape tensor.";
+  size_t len = array.Shape().front();
+  Array<Integer> out;
+  for (size_t i = 0; i < len; ++i) {
+    double elem_val = ToScalar(array, i);
+    out.push_back(Integer(static_cast<int>(elem_val)));
+  }
+  return out;
+}
+
 inline Expr GetField(Expr t, size_t i) { return TupleGetItem(t, i); }
 
 inline Expr Pair(Expr l, Expr r) { return Tuple({l, r}); }
@@ -432,12 +513,10 @@ inline Expr ZerosLike(Expr e) {
   return Call(op, {e});
 }
 
+Expr MakeZeros(Expr shape, DataType dtype);
+
 inline Expr Zeros(Array<IndexExpr> shape, DataType dtype) {
-  auto attrs = make_object<InitOpAttrs>();
-  attrs->shape = std::move(shape);
-  attrs->dtype = std::move(dtype);
-  static const Op& op = Op::Get("zeros");
-  return Call(op, {}, Attrs(attrs), {});
+  return MakeZeros(CheckConstantShape(shape), dtype);
 }
 
 inline Expr OnesLike(Expr e) {
@@ -503,12 +582,10 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
   return Call(op, {lhs, rhs}, Attrs(), {});
 }
 
+Expr MakeFull(Expr fill_value, Expr shape, DataType dtype);
+
 static inline Expr Full(Expr fill_value, Array<IndexExpr> shape, DataType dtype) {
-  auto attrs = make_object<InitOpAttrs>();
-  attrs->shape = std::move(shape);
-  attrs->dtype = std::move(dtype);
-  static const Op& op = Op::Get("full");
-  return Call(op, {fill_value}, Attrs(attrs), {});
+  return MakeFull(fill_value, CheckConstantShape(shape), dtype);
 }
 
 static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides,
@@ -586,7 +663,11 @@ static inline Expr Tile(Expr data, Array<Integer> reps) {
   return Call(op, {data}, Attrs(attrs), {});
 }
 
-Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape);
+Expr MakeBroadCastTo(Expr data, Expr shape);
+
+static inline Expr BroadCastTo(Expr data, Array<IndexExpr> shape) {
+  return MakeBroadCastTo(data, CheckConstantShape(shape));
+}
 
 Expr MakeConcatenate(Expr data, int axis);
 
index 5e5542d..504c20a 100644 (file)
@@ -96,31 +96,48 @@ def test_any_broadcast_fail():
     check_fail((relay.Any(),), (3, 2), (2), (4, 2), relay.add, np.add)
 
 
-def verify_any_full(x_shape, x_np_shape, relay_op, np_op, dtype='float32'):
+def verify_any_full_like(x_shape, x_np_shape, relay_op, np_op, dtype='float32'):
     x = relay.var('x', shape=x_shape, dtype=dtype)
     mod = tvm.IRModule()
-    mod['main'] = relay.Function([x], relay.zeros_like(x))
+    mod['main'] = relay.Function([x], relay_op(x))
     x_np = np.random.uniform(size=x_np_shape).astype(dtype)
-    res_np = np.zeros_like(x_np)
+    res_np = np_op(x_np)
+    for kind in ['debug', 'vm']:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm')
+        result = ex.evaluate()(x_np).asnumpy()
+        tvm.testing.assert_allclose(result, res_np)
+
+def test_any_full_like():
+    # zeros_like, ones_like
+    verify_any_full_like(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like, "float32")
+    verify_any_full_like(any_dims(3), (225, 115, 15), relay.zeros_like, np.zeros_like, "float32")
+    verify_any_full_like(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like, np.zeros_like, "int32")
+    verify_any_full_like(any_dims(3), (2, 3, 5), relay.ones_like, np.ones_like, "float32")
+    verify_any_full_like(any_dims(3), (225, 115, 15), relay.ones_like, np.ones_like, "float32")
+    verify_any_full_like(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like, np.ones_like, "int32")
+
+def verify_any_full(x_np_shape, relay_op, np_op, dtype='float32', value=None):
+    x = relay.var('x', shape=(len(x_np_shape),), dtype="int32")
+    mod = tvm.IRModule()
+    out = relay_op(x, dtype) if value is None else relay_op(relay.expr.const(value), x, dtype)
+    mod['main'] = relay.Function([x], out)
+    res_np = np_op(x_np_shape) if value is None else np_op(x_np_shape, value)
+    x_np = np.array(x_np_shape).astype("int32")
     for kind in ['debug', 'vm']:
         ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm')
         result = ex.evaluate()(x_np).asnumpy()
         tvm.testing.assert_allclose(result, res_np)
 
 def test_any_full():
-    # zeros, zeros_like, ones, ones_like
-    verify_any_full(any_dims(3), (2, 3, 5), relay.zeros, np.zeros, "float32")
-    verify_any_full(any_dims(3), (225, 115, 15), relay.zeros, np.zeros, "float32")
-    verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros, np.zeros, "int32")
-    verify_any_full(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like, "float32")
-    verify_any_full(any_dims(3), (225, 115, 15), relay.zeros_like, np.zeros_like, "float32")
-    verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like, np.zeros_like, "int32")
-    verify_any_full(any_dims(3), (2, 3, 5), relay.ones, np.ones, "float32")
-    verify_any_full(any_dims(3), (225, 115, 15), relay.ones, np.ones, "float32")
-    verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones, np.ones, "int32")
-    verify_any_full(any_dims(3), (2, 3, 5), relay.ones_like, np.ones_like, "float32")
-    verify_any_full(any_dims(3), (225, 115, 15), relay.ones_like, np.ones_like, "float32")
-    verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like, np.ones_like, "int32")
+    # zeros, ones, full
+    verify_any_full((2, 3, 5), relay.zeros, np.zeros, "float32")
+    verify_any_full((225, 115, 15), relay.zeros, np.zeros, "float32")
+    verify_any_full((10, 11, 12, 13, 14), relay.zeros, np.zeros, "int32")
+    verify_any_full((2, 3, 5), relay.ones, np.ones, "float32")
+    verify_any_full((225, 115, 15), relay.ones, np.ones, "float32")
+    verify_any_full((10, 11, 12, 13, 14), relay.ones, np.ones, "int32")
+    verify_any_full((10, 11, 12, 13, 14), relay.full, np.full, "float32", 2.0)
+    verify_any_full((1, 2, 3, 4), relay.full, np.full, "int32", -2)
 
 def test_any_concat():
     x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
@@ -566,6 +583,37 @@ def test_any_softmax():
     verify_any_softmax(any_dims(3), -1, (1, 2, 3), (1, 2, 3))
     verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1))
 
+def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False):
+    mod = tvm.IRModule()
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    np_data = np.random.uniform(size=np_dshape).astype(dtype)
+    if const_k:
+        k = relay.const(kval)
+        args = [data]
+        in_vals = [np_data]
+    else:
+        k = relay.var('k', shape=(), dtype="int32")
+        args = [data, k]
+        in_vals = [np_data, kval]
+    out = relay.topk(data, k, ret_type="indices")
+    mod["main"] = relay.Function(args, out)
+
+    sorted = np.argsort(-np_data)
+    if len(np_dshape) == 2:
+        ref_out = sorted[:, 0:kval]
+    else:
+        ref_out = sorted[0:kval]
+
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(*in_vals)
+        tvm.testing.assert_allclose(result.asnumpy(), ref_out)
+
+def test_any_topk():
+    verify_any_topk(any_dims(1), 5, (10,), "float32")
+    verify_any_topk(any_dims(2), 2, (6, 3), "int32")
+    verify_any_topk(any_dims(2), 3, (6, 3), "float32", True)
+
 def test_fused_ops():
     x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32')
     y0 = x + relay.const(1.0, 'float32')
@@ -723,6 +771,7 @@ def test_mixed_input_type():
 
 if __name__ == "__main__":
     test_any_full()
+    test_any_full_like()
     test_any_broadcast()
     test_any_elemwise()
     test_any_broadcast_fail()
@@ -745,10 +794,10 @@ if __name__ == "__main__":
     test_any_dense()
     test_any_pad()
     test_any_softmax()
+    test_any_topk()
     test_fused_ops()
     test_arange_with_dynamic_shape()
     test_recursive_concat()
     test_recursive_concat_with_wrong_annotation()
     test_tuple_get_item()
     test_mixed_input_type()
-
index 744da62..e492d68 100644 (file)
@@ -107,7 +107,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
     data : tvm.te.Tensor
         The input tensor.
 
-    k : int, optional
+    k : int or tvm.te.Tensor, optional
         Number of top elements to select. Return all elements if k < 1.
 
     axis : int, optional
@@ -133,7 +133,10 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
     assert ret_type in ["both", "values", "indices"]
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
     out_shape = list(get_const_tuple(data.shape))
-    if k >= 1:
+    kvar = tvm.te.size_var("k")
+    if not isinstance(k, int):
+        out_shape[axis] = kvar
+    elif k >= 1:
         out_shape[axis] = k
     out_bufs = []
     if ret_type in ["both", "values"]:
@@ -142,10 +145,11 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
         out_bufs.append(tvm.tir.decl_buffer(out_shape, dtype, "indices_buf", data_alignment=8))
     out_shapes = [out_shape] * len(out_bufs)
 
+    kv = kvar if not isinstance(k, int) else k
     out = te.extern(out_shapes,
                     [data],
                     lambda ins, outs: tvm.tir.call_packed(
-                        "tvm.contrib.sort.topk", ins[0], *outs, k, axis, ret_type, is_ascend),
+                        "tvm.contrib.sort.topk", ins[0], *outs, kv, axis, ret_type, is_ascend),
                     in_buffers=[data_buf],
                     out_buffers=out_bufs,
                     name="topk_cpu",