From cc79591f5e0f05955fd3180b1975e4344b532345 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Mon, 25 May 2020 18:09:44 -0700 Subject: [PATCH] [Relay][Op]Support symbolic TopK, Ones, Zeros and Full (#5459) * Support symbolic TopK, Ones, Zeros and Full * Fix pylint * Add docstring for topk shape func * Fix grad * Fix lazy_gradient_init * Fix parser * Fix print ir text * Fix lint * Improve pattern_util * Fix topk * Fix build * Use Optional for attribute * Fix clang-format * Minot fix * Fix pylint * Fix build warning * Fix parser * Move ToScalar * Fix lint * Fix lint * Make topk shape func as data independent when k is constant. * Fix lint * Minor fix --- include/tvm/relay/attrs/algorithm.h | 5 +- include/tvm/relay/attrs/transform.h | 2 +- include/tvm/runtime/ndarray.h | 6 +- python/tvm/relay/_parser.py | 2 + python/tvm/relay/op/_algorithm.py | 68 ++++++++++++ python/tvm/relay/op/_tensor.py | 41 ++++---- python/tvm/relay/op/_tensor_grad.py | 8 +- python/tvm/relay/op/_transform.py | 2 + python/tvm/relay/op/algorithm.py | 9 +- python/tvm/relay/op/strategy/generic.py | 4 +- python/tvm/relay/op/tensor.py | 10 +- python/tvm/relay/op/transform.py | 8 +- src/relay/analysis/util.cc | 8 ++ src/relay/op/algorithm/topk.cc | 32 ++++-- src/relay/op/image/resize.cc | 4 +- src/relay/op/tensor/transform.cc | 161 +++++++++++++++-------------- src/relay/op/tensor/transform.h | 43 ++++---- src/relay/qnn/util.cc | 4 +- src/relay/transforms/lazy_gradient_init.cc | 8 +- src/relay/transforms/pattern_util.h | 103 ++++++++++++++++-- tests/python/relay/test_any.py | 83 ++++++++++++--- topi/python/topi/sort.py | 10 +- 22 files changed, 435 insertions(+), 186 deletions(-) diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index a7d4708..83b4dda 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -26,6 +26,7 @@ #include #include +#include #include @@ -52,14 +53,14 @@ struct ArgsortAttrs : public tvm::AttrsNode { }; struct TopKAttrs : public tvm::AttrsNode { - int k; + Optional k; int axis; bool is_ascend; std::string ret_type; DataType dtype; TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") { - TVM_ATTR_FIELD(k).set_default(1).describe("Number of top elements to select"); + TVM_ATTR_FIELD(k).describe("Number of top elements to select"); TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort the input tensor."); TVM_ATTR_FIELD(ret_type).set_default("both").describe( "The return type [both, values, indices]." diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 7fb7f3a..ccf8e54 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -111,7 +111,7 @@ struct TakeAttrs : public tvm::AttrsNode { /*! \brief Attributes that specify a tensor */ struct InitOpAttrs : public tvm::AttrsNode { - Array shape; + Optional> shape; DataType dtype; TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") { diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 0171d8a..e69d802 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -462,7 +462,11 @@ inline bool NDArray::Load(dmlc::Stream* strm) { int64_t data_byte_size; CHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format"; CHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file format"; - CHECK(strm->Read(ret->data, data_byte_size)) << "Invalid DLTensor file format"; + auto read_ret = strm->Read(ret->data, data_byte_size); + // Only check non-empty data + if (ndim > 0 && shape[0] != 0) { + CHECK(read_ret) << "Invalid DLTensor file format"; + } if (!DMLC_IO_NO_ENDIAN_SWAP) { dmlc::ByteSwap(ret->data, elem_bytes, num_elems); } diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 1d97b55..49f2d4d 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -116,6 +116,8 @@ class FuncOp(OpWrapper): attrs = {} if self.operator is op.reshape: x = self.operator(*args) + elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to): + x = self.operator(*args, dtype=attrs["dtype"]) else: x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()}) if isinstance(x, expr.TupleWrapper): diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index e1e6fd3..5a20480 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -18,7 +18,11 @@ # pylint: disable=invalid-name,unused-argument from __future__ import absolute_import +from tvm.te.hybrid import script +from tvm.runtime import convert + from . import strategy +from . import op as _reg from .op import OpPattern, register_pattern from .op import register_strategy @@ -29,3 +33,67 @@ register_pattern("argsort", OpPattern.OPAQUE) # topk register_strategy("topk", strategy.topk_strategy) register_pattern("topk", OpPattern.OPAQUE) + +@script +def _topk_shape_func_input_data(data, k, axis): + ndim = len(data.shape) + val_out = output_tensor((ndim,), "int64") + indices_out = output_tensor((ndim,), "int64") + + for i in const_range(ndim): + if i != axis: + val_out[i] = int64(data.shape[i]) + indices_out[i] = int64(data.shape[i]) + else: + if k[0] < 1: + val_out[i] = int64(data.shape[i]) + indices_out[i] = int64(data.shape[i]) + else: + val_out[i] = int64(k[0]) + indices_out[i] = int64(k[0]) + return val_out, indices_out + +@script +def _topk_shape_func_input_shape(data_shape, k, axis): + ndim = data_shape.shape[0] + val_out = output_tensor((ndim,), "int64") + indices_out = output_tensor((ndim,), "int64") + + for i in const_range(ndim): + if i != axis: + val_out[i] = int64(data_shape[i]) + indices_out[i] = int64(data_shape[i]) + else: + if k < 1: + val_out[i] = int64(data_shape[i]) + indices_out[i] = int64(data_shape[i]) + else: + val_out[i] = int64(k) + indices_out[i] = int64(k) + return val_out, indices_out + +@_reg.register_shape_func("topk", True) +def topk_shape_func(attrs, inputs, _): + """ + Shape func for topk. + """ + axis = attrs.axis + if attrs.k is not None: + if axis < 0: + axis += inputs[0].shape[0] + val_out, indices_out = \ + _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis)) + else: + if axis < 0: + axis += len(inputs[0].shape) + val_out, indices_out = \ + _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis)) + ret_type = attrs.ret_type + if ret_type == "both": + ret = [val_out, indices_out] + elif ret_type == "values": + ret = [val_out] + else: + ret = [indices_out] + + return ret diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index e029e0c..cd9e4ed 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -17,10 +17,9 @@ #pylint: disable=invalid-name, unused-argument, len-as-condition """Backend compiler related feature registration""" -from tvm.runtime import convert from tvm.te.hybrid import script import topi -from topi.util import get_const_tuple + from .op import register_compute, register_shape_func from .op import register_broadcast_schedule, register_injective_schedule from .op import register_pattern, OpPattern @@ -93,7 +92,7 @@ register_broadcast_schedule("fast_erf") # zeros @register_compute("zeros") def zeros_compute(attrs, inputs, output_type): - assert not inputs + assert len(inputs) == 1 return [topi.full(output_type.shape, output_type.dtype, 0.0)] register_broadcast_schedule("zeros") @@ -110,7 +109,7 @@ register_broadcast_schedule("zeros_like") # ones @register_compute("ones") def ones_compute(attrs, inputs, output_type): - assert not inputs + assert len(inputs) == 1 return [topi.full(output_type.shape, output_type.dtype, 1.0)] register_broadcast_schedule("ones") @@ -132,20 +131,10 @@ def clip_compute(attrs, inputs, output_type): register_injective_schedule("clip") -@script -def _cast_shape_function(x): - out_ndim = len(x) - out = output_tensor((out_ndim,), "int64") - for i in const_range(out_ndim): - out[i] = x[i] - return out - -def cast_shape_func(attrs, inputs, out_ndims): - return [_cast_shape_function(*inputs)] - +# full @script def _full_shape_func(shape): - out_ndim = len(shape) + out_ndim = shape.shape[0] out = output_tensor((out_ndim,), "int64") for i in const_range(out_ndim): out[i] = int64(shape[i]) @@ -153,10 +142,15 @@ def _full_shape_func(shape): def full_shape_func(attrs, inputs, out_ndims): """ - Shape func for zeros, zeros_like, ones, ones_like. + Shape func for full. + """ + return [_full_shape_func(inputs[1])] + +def no_data_full_shape_func(attrs, inputs, out_ndims): + """ + Shape func for zeros and ones. """ - shape = get_const_tuple(attrs.shape) - return [_full_shape_func(convert(shape))] + return [_full_shape_func(inputs[0])] @script def _broadcast_shape_func(x, y, ndim): @@ -198,13 +192,14 @@ def elemwise_shape_func(attrs, inputs, _): """ return [topi.math.identity(inputs[0])] -register_shape_func("cast", False, cast_shape_func) -register_shape_func("zeros", False, full_shape_func) +register_shape_func("cast", False, elemwise_shape_func) +register_shape_func("zeros", True, no_data_full_shape_func) register_shape_func("zeros_like", False, elemwise_shape_func) -register_shape_func("ones", False, full_shape_func) +register_shape_func("ones", True, no_data_full_shape_func) register_shape_func("ones_like", False, elemwise_shape_func) -register_shape_func("full", False, full_shape_func) +register_shape_func("full", True, full_shape_func) register_shape_func("full_like", False, elemwise_shape_func) +register_shape_func("broadcast_to", True, full_shape_func) register_shape_func("add", False, broadcast_shape_func) register_shape_func("subtract", False, broadcast_shape_func) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 8be3358..8ba1020 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -232,14 +232,14 @@ def divide_grad(orig, grad): @register_gradient("zeros") def zeros_grad(orig, grad): - """Returns []""" - return [] + """Returns [shape]""" + return [orig.args[0]] @register_gradient("ones") def ones_grad(orig, grad): - """Returns []""" - return [] + """Returns [shape]""" + return [orig.args[0]] @register_gradient("zeros_like") diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 43d8d62..e1c2bd7 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -120,6 +120,8 @@ def _concatenate_shape_func(inputs, axis): @_reg.register_shape_func("concatenate", False) def concatenate_shape_func(attrs, inputs, _): axis = get_const_int(attrs.axis) + if axis < 0: + axis += inputs[0].shape[0] return [_concatenate_shape_func(inputs, convert(axis))] @script diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 17fab80..d31e89a 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -17,7 +17,7 @@ """Classic algorithm operation""" from __future__ import absolute_import as _abs from . import _make -from ..expr import TupleWrapper +from ..expr import TupleWrapper, const def argsort(data, axis=-1, is_ascend=1, dtype="int32"): """Performs sorting along the given axis and returns an array of indicies @@ -48,7 +48,8 @@ def argsort(data, axis=-1, is_ascend=1, dtype="int32"): return _make.argsort(data, axis, is_ascend, dtype) -def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): +def topk(data, k=1, axis=-1, ret_type="both", + is_ascend=False, dtype="int32"): """Get the top k elements in an input tensor along the given axis. ret_type specifies the return type, can be one of ("both", "values", "indices"). @@ -58,7 +59,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): data : relay.Expr The input data tensor. - k : int, optional + k : int or relay.Expr, optional Number of top elements to select. Return all elements if k < 1. axis : int, optional @@ -81,6 +82,8 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): out : relay.Expr or List[relay.Expr] The computed result. """ + if isinstance(k, int): + k = const(k, "int64") out = _make.topk(data, k, axis, ret_type, is_ascend, dtype) if ret_type == "both": return TupleWrapper(out, 2) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 6db5b14..99439af 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -598,7 +598,9 @@ def argsort_strategy(attrs, inputs, out_type, target): def wrap_compute_topk(topi_compute): """Wrap topk compute""" def _compute_topk(attrs, inputs, out_type): - k = get_const_int(attrs.k) + k = inputs[1] + if attrs.k is not None: + k = attrs.k axis = get_const_int(attrs.axis) ret_type = attrs.ret_type is_ascend = bool(get_const_int(attrs.is_ascend)) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index d5ae5cd..c60dbee 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -20,7 +20,7 @@ from tvm.runtime import ndarray as _nd from tvm.runtime import TVMContext as _TVMContext from . import _make -from ..expr import Tuple +from ..expr import Tuple, const # We create a wrapper function for each operator in the @@ -928,7 +928,7 @@ def zeros(shape, dtype): Parameters ---------- - shape : tuple of int + shape : tuple of int or relay.Expr The shape of the target. dtype : data type @@ -939,6 +939,8 @@ def zeros(shape, dtype): result : relay.Expr The resulting tensor. """ + if isinstance(shape, (list, tuple)): + shape = const(list(shape), "int32") return _make.zeros(shape, dtype) @@ -963,7 +965,7 @@ def ones(shape, dtype): Parameters ---------- - shape : tuple of int + shape : tuple of int or relay.Expr The shape of the target. dtype : data type @@ -974,6 +976,8 @@ def ones(shape, dtype): result : relay.Expr The resulting tensor. """ + if isinstance(shape, (list, tuple)): + shape = const(list(shape), "int32") return _make.ones(shape, dtype) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 2d9e4ba..1da58ae 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -299,7 +299,7 @@ def full(fill_value, shape=(), dtype=""): fill_value : relay.Expr The value to fill. Must be a scalar. - shape : tuple of int + shape : tuple of int or relay.Expr The shape of the target. dtype : data type, optional (defaults to data type of the fill value) @@ -310,6 +310,8 @@ def full(fill_value, shape=(), dtype=""): result : relay.Expr The resulting tensor. """ + if isinstance(shape, (list, tuple)): + shape = const(list(shape), "int32") return _make.full(fill_value, shape, dtype) @@ -527,7 +529,7 @@ def broadcast_to(data, shape): data : relay.Expr The input tensor. - shape : shape + shape : tuple of int or relay.Expr Provide the shape to broadcast to. Returns @@ -535,6 +537,8 @@ def broadcast_to(data, shape): result : relay.Expr The resulting tensor. """ + if isinstance(shape, (list, tuple)): + shape = const(list(shape), "int32") return _make.broadcast_to(data, shape) def broadcast_to_like(data, broadcast_type): diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index a05bb8f..2853165 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -25,6 +25,7 @@ */ #include #include +#include #include #include #include @@ -450,6 +451,13 @@ bool IsDataDependant(const CallNode* call) { return false; } } + } else if (op->name == "topk") { + if (const auto* attrs = call->attrs.as()) { + if (attrs->k) { + // If k attribute exists, it isn't data dependant. + return false; + } + } } return tshape_data_dependant[op]; diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc index 5ff5904..3db8eee 100644 --- a/src/relay/op/algorithm/topk.cc +++ b/src/relay/op/algorithm/topk.cc @@ -23,9 +23,11 @@ */ #include #include +#include namespace tvm { namespace relay { +using tir::make_const; TVM_REGISTER_NODE_TYPE(TopKAttrs); @@ -33,7 +35,7 @@ bool TopKRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] const TopKAttrs* param = attrs.as(); - CHECK_EQ(types.size(), 2); + CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); CHECK(data); int ndim = data->shape.size(); @@ -44,35 +46,44 @@ bool TopKRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK(axis >= 0 && axis < ndim); Array out_shape; for (int i = 0; i < ndim; ++i) { - if (i != axis || param->k < 1) { + if (i != axis) { out_shape.push_back(data->shape[i]); + } else if (param->k) { + const Integer& ck = param->k.value(); + if (ck->value < 1) { + out_shape.push_back(data->shape[i]); + } else { + out_shape.push_back(ck); + } } else { - out_shape.push_back(param->k); + out_shape.push_back(Any::make()); } } auto values_ty = TensorType(out_shape, data->dtype); auto indices_ty = TensorType(out_shape, param->dtype); if (param->ret_type == "both") { - reporter->Assign(types[1], TupleType({values_ty, indices_ty})); + reporter->Assign(types[2], TupleType({values_ty, indices_ty})); } else if (param->ret_type == "values") { - reporter->Assign(types[1], values_ty); + reporter->Assign(types[2], values_ty); } else if (param->ret_type == "indices") { - reporter->Assign(types[1], indices_ty); + reporter->Assign(types[2], indices_ty); } else { LOG(FATAL) << "Unsupported ret type: " << param->ret_type; } return true; } -Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype) { +Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend, DataType dtype) { auto attrs = make_object(); - attrs->k = k; + if (const auto& ck = k.as()) { + attrs->k = tvm::Integer(reinterpret_cast(ck->data->data)[0]); + } attrs->axis = axis; attrs->ret_type = ret_type; attrs->is_ascend = is_ascend; attrs->dtype = dtype; static const Op& op = Op::Get("topk"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(op, {data, k}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK); @@ -80,9 +91,10 @@ TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK); RELAY_REGISTER_OP("topk") .describe(R"doc(Get the top k elements in an input tensor along the given axis. )doc" TVM_ADD_FILELINE) - .set_num_inputs(1) + .set_num_inputs(2) .set_attrs_type() .add_argument("data", "Tensor", "Input data.") + .add_argument("k", "Tensor", "Number of top elements.") .set_support_level(6) .add_type_rel("TopK", TopKRel); diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index 7bddb29..b6d2c71 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -194,12 +194,12 @@ bool CropAndResizeRel(const Array& types, int num_inputs, const Attrs& att const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(0, box_indices->shape[0]); + oshape.Set(0, boxes->shape[0]); oshape.Set(2, crop_size[0]); oshape.Set(3, crop_size[1]); auto bshape = layout_converter.BackwardShape(oshape); // assign output type - reporter->Assign(types[3], TensorType(layout_converter.BackwardShape(oshape), out_dtype)); + reporter->Assign(types[3], TensorType(bshape, out_dtype)); return true; } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 6ccf585..7282ac7 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -447,44 +447,6 @@ RELAY_REGISTER_OP("transpose") /* relay.reshape */ TVM_REGISTER_NODE_TYPE(ReshapeAttrs); -double ToScalar(const runtime::NDArray& array, int i = 0) { - if (array->dtype.code == kDLInt) { - if (array->dtype.bits == 8) { - return reinterpret_cast(array->data)[i]; - } else if (array->dtype.bits == 16) { - return reinterpret_cast(array->data)[i]; - } else if (array->dtype.bits == 32) { - return reinterpret_cast(array->data)[i]; - } else if (array->dtype.bits == 64) { - return reinterpret_cast(array->data)[i]; - } - } else if (array->dtype.code == kDLUInt) { - if (array->dtype.bits == 8) { - return reinterpret_cast(array->data)[i]; - } else if (array->dtype.bits == 16) { - return reinterpret_cast(array->data)[i]; - } else if (array->dtype.bits == 32) { - return reinterpret_cast(array->data)[i]; - } else if (array->dtype.bits == 64) { - return reinterpret_cast(array->data)[i]; - } - } else if (array->dtype.code == kDLFloat) { -#if (__ARM_FP16_FORMAT_IEEE == 1) - if (array->dtype.bits == 16) { - return reinterpret_cast<__fp16*>(array->data)[i]; - } -#endif - if (array->dtype.bits == 32) { - return reinterpret_cast(array->data)[i]; - } else if (array->dtype.bits == 64) { - return reinterpret_cast(array->data)[i]; - } - } - LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype); - // make compiler happy - return -std::numeric_limits::infinity(); -} - bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { const auto* param = attrs.as(); @@ -663,11 +625,7 @@ Expr MakeReshape(Expr data, Expr newshape) { auto attrs = make_object(); if (const ConstantNode* c = newshape.as()) { CHECK_EQ(c->data->ndim, 1); - Array newshape; - for (int i = 0; i < c->data->shape[0]; i++) { - newshape.push_back(Integer(static_cast(ToScalar(c->data, i)))); - } - attrs->newshape = newshape; + attrs->newshape = ToVector(c->data); } attrs->reverse = false; static const Op& op = Op::Get("reshape"); @@ -929,9 +887,10 @@ TVM_REGISTER_NODE_TYPE(InitOpAttrs); bool FullRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); + CHECK_EQ(types.size(), 3); const InitOpAttrs* param = attrs.as(); const auto* fill_value = types[0].as(); + const auto* fill_shape = types[1].as(); if (fill_value == nullptr) { return false; } @@ -944,7 +903,21 @@ bool FullRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK_EQ(fill_value->shape.size(), 0) << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << "."; - reporter->Assign(types[1], TensorType(param->shape, out_dtype)); + const IntImmNode* shape_shape = fill_shape->shape[0].as(); + CHECK(shape_shape) << "Parameter shape must have static shape"; + + std::vector oshape; + if (param->shape) { + const Array& cshape_array = param->shape.value(); + for (size_t i = 0; i < cshape_array.size(); ++i) { + oshape.push_back(cshape_array[i]); + } + } else { + for (int i = 0; i < shape_shape->value; ++i) { + oshape.push_back(Any::make()); + } + } + reporter->Assign(types[2], TensorType(oshape, out_dtype)); return true; } @@ -954,12 +927,14 @@ Array FullCompute(const Attrs& attrs, const Array& input return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())}; } -Expr MakeFull(Expr fill_value, Array shape, DataType dtype) { +Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) { auto attrs = make_object(); - attrs->shape = std::move(shape); + if (const auto* cshape = shape.as()) { + attrs->shape = ToVector(cshape->data); + } attrs->dtype = std::move(dtype); static const Op& op = Op::Get("full"); - return Call(op, {fill_value}, Attrs(attrs), {}); + return Call(op, {fill_value, shape}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.full").set_body_typed(MakeFull); @@ -969,8 +944,9 @@ RELAY_REGISTER_OP("full") )code" TVM_ADD_FILELINE) .set_attrs_type() - .set_num_inputs(1) + .set_num_inputs(2) .add_argument("fill_value", "double", "The value to fill.") + .add_argument("shape", "Tensor", "Target shape.") .set_support_level(3) .add_type_rel("Full", FullRel) .set_attr("FTVMCompute", FullCompute) @@ -978,19 +954,37 @@ RELAY_REGISTER_OP("full") bool InitOpRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 1); + CHECK_EQ(types.size(), 2); const InitOpAttrs* param = attrs.as(); + const auto* fill_shape = types[0].as(); + DataType out_dtype = param->dtype; + + const IntImmNode* shape_shape = fill_shape->shape[0].as(); + CHECK(shape_shape) << "Parameter shape must have static shape"; - reporter->Assign(types[0], TensorType(param->shape, param->dtype)); + std::vector oshape; + if (param->shape) { + const Array& cshape_array = param->shape.value(); + for (size_t i = 0; i < cshape_array.size(); ++i) { + oshape.push_back(cshape_array[i]); + } + } else { + for (int i = 0; i < shape_shape->value; ++i) { + oshape.push_back(Any::make()); + } + } + reporter->Assign(types[1], TensorType(oshape, out_dtype)); return true; } -Expr MakeZeros(Array shape, DataType dtype) { +Expr MakeZeros(Expr shape, DataType dtype) { auto attrs = make_object(); - attrs->shape = std::move(shape); + if (const auto* cshape = shape.as()) { + attrs->shape = ToVector(cshape->data); + } attrs->dtype = std::move(dtype); static const Op& op = Op::Get("zeros"); - return Call(op, {}, Attrs(attrs), {}); + return Call(op, {shape}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.zeros").set_body_typed(MakeZeros); @@ -1000,16 +994,19 @@ RELAY_REGISTER_OP("zeros") )code" TVM_ADD_FILELINE) .set_attrs_type() - .set_num_inputs(0) + .set_num_inputs(1) + .add_argument("shape", "Tensor", "Target shape.") .set_support_level(3) .add_type_rel("InitOp", InitOpRel); -Expr MakeOnes(Array shape, DataType dtype) { +Expr MakeOnes(Expr shape, DataType dtype) { auto attrs = make_object(); - attrs->shape = std::move(shape); + if (const auto* cshape = shape.as()) { + attrs->shape = ToVector(cshape->data); + } attrs->dtype = std::move(dtype); static const Op& op = Op::Get("ones"); - return Call(op, {}, Attrs(attrs), {}); + return Call(op, {shape}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.ones").set_body_typed(MakeOnes); @@ -1019,7 +1016,8 @@ RELAY_REGISTER_OP("ones") )code" TVM_ADD_FILELINE) .set_attrs_type() - .set_num_inputs(0) + .set_num_inputs(1) + .add_argument("shape", "Tensor", "Target shape.") .set_support_level(3) .add_type_rel("InitOp", InitOpRel); @@ -1579,30 +1577,42 @@ RELAY_REGISTER_OP("collapse_sum_like") // BroadCastTo: -> B where BroadCast(A, B) = B bool BroadCastToRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - auto ioattrs = attrs.as(); - CHECK(ioattrs); - auto intt = types[0].as(); - if (intt == nullptr) { - return false; + CHECK_EQ(types.size(), 3); + const InitOpAttrs* param = attrs.as(); + const auto* target_shape = types[1].as(); + DataType out_dtype = types[0].as()->dtype; + + const IntImmNode* shape_shape = target_shape->shape[0].as(); + CHECK(shape_shape) << "Parameter shape must have static shape"; + + std::vector oshape; + if (param->shape) { + const Array& cshape_array = param->shape.value(); + for (size_t i = 0; i < cshape_array.size(); ++i) { + oshape.push_back(cshape_array[i]); + } + } else { + for (int i = 0; i < shape_shape->value; ++i) { + oshape.push_back(Any::make()); + } } - auto type = TensorType(ioattrs->shape, intt->dtype); - reporter->Assign(types[1], type); - return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return BroadcastRel({types[0], types[2], types[2]}, 2, Attrs(), reporter); } -Expr MakeBroadCastTo(Expr data, Array shape) { +Expr MakeBroadCastTo(Expr data, Expr shape) { static const Op& op = Op::Get("broadcast_to"); auto attrs = make_object(); - attrs->shape = std::move(shape); - return Call(op, {data}, Attrs(attrs), {}); + if (const auto* cshape = shape.as()) { + attrs->shape = ToVector(cshape->data); + } + return Call(op, {data, shape}, Attrs(attrs), {}); } Array BroadCastToCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - auto ioattrs = attrs.as(); - CHECK(ioattrs != nullptr); - return {topi::broadcast_to(inputs[0], ioattrs->shape)}; + const auto* out_ttype = out_type.as(); + return {topi::broadcast_to(inputs[0], out_ttype->shape)}; } TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to").set_body_typed(MakeBroadCastTo); @@ -1610,8 +1620,9 @@ TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to").set_body_typed(MakeBroadCastT RELAY_REGISTER_OP("broadcast_to") .describe(R"code(Broadcast the first input to match the shape argument. )code" TVM_ADD_FILELINE) - .set_num_inputs(1) + .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") + .add_argument("shape", "Tensor", "Target shape.") .set_support_level(4) .add_type_rel("BroadCastTo", BroadCastToRel) .set_attr("FTVMCompute", BroadCastToCompute) diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index bc35ed6..1f30b68 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -90,34 +90,33 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs if (e_dtype != dtype) { throw Error("relay.concatenate requires all tensors have the same dtype"); } - for (size_t j = 0; j < first->shape.size(); ++j) { - if (j == static_cast(axis)) continue; - if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue; - throw Error( - "relay.concatenate requires all tensors have the same shape " - "on non-concatenating axes"); - } } // Calculate shape std::vector oshape(first->shape.begin(), first->shape.end()); - IndexExpr& concat_dim = oshape[axis]; - bool has_any = false; - if (concat_dim.as()) { - has_any = true; - } else { - for (int i = 1; i < static_cast(tensor_tuple->fields.size()); ++i) { - const auto& e = Downcast(tensor_tuple->fields[i]); - if (e->shape[axis].as()) { - has_any = true; - break; + int data_length = static_cast(tensor_tuple->fields.size()); + for (int i = 0; i < ndim; ++i) { + std::vector non_any; + for (int j = 0; j < data_length; ++j) { + const auto& e = Downcast(tensor_tuple->fields[j]); + if (!e->shape[i].as()) { + non_any.push_back(e->shape[i]); + // accumulate axis dimension + if (j > 0 && i == axis && !oshape[i].as()) { + oshape[i] += e->shape[i]; + } + } + } + int non_any_size = static_cast(non_any.size()); + if (non_any_size != data_length) oshape[i] = Any::make(); + if (i != axis) { + for (int k = 1; k < non_any_size; k++) { + if (reporter->AssertEQ(non_any[0], non_any[k])) continue; + throw Error( + "relay.concatenate requires all tensors have the same shape " + "on non-concatenating axes"); } - concat_dim += e->shape[axis]; } - } - - if (has_any) { - concat_dim = Any::make(); } auto rtype = TensorType(oshape, dtype); diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc index 7171ded..4daa5c9 100644 --- a/src/relay/qnn/util.cc +++ b/src/relay/qnn/util.cc @@ -202,8 +202,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multipliers, round_scalar = exp_pos_rounding_value_expr; } else if (rounding == "TONEAREST") { // To satisfy where op shape requirements, the rounding values are broadcasted. - auto pos_rounder = MakeBroadCastTo(exp_pos_rounding_value_expr, input_shape); - auto neg_rounder = MakeBroadCastTo(exp_neg_rounding_value_expr, input_shape); + auto pos_rounder = BroadCastTo(exp_pos_rounding_value_expr, input_shape); + auto neg_rounder = BroadCastTo(exp_neg_rounding_value_expr, input_shape); auto zero_t = Zeros(input_shape, hp_dtype); round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder, neg_rounder); diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc index 3cd29d6..f062466 100644 --- a/src/relay/transforms/lazy_gradient_init.cc +++ b/src/relay/transforms/lazy_gradient_init.cc @@ -203,9 +203,9 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator { } if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) { - // fn() -> T, function returns result of the operation - Expr func = - Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}); + // ones and zeros need TensorType input + Expr result = CallPrimitiveOp(call_node); + Expr func = Function({}, result, {call_node->checked_type()}, Array()); // call appropriate GradCell constructor std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero"; return Call(module_->GetConstructor("GradCell", constructor_name), {func}, Attrs(), @@ -288,7 +288,7 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator { args.push_back(Call(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); } // result of operation - return Call(call_node->op, args); + return Call(call_node->op, args, call_node->attrs); } }; diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 8f37e7c..06b1e82 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -37,6 +37,7 @@ #include #include +#include #include #include #include @@ -311,6 +312,25 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector s } /*! + * \brief Check whether a shape is static and create corresponding Constant. + * + * \param shape The Array of the shape values. + * \return A Constant. + */ +static inline Constant CheckConstantShape(const Array& shape) { + auto shape_array = + runtime::NDArray::Empty({int64_t(shape.size())}, DataType::Int(64), {kDLCPU, 0}); + auto* shape_data = static_cast(shape_array->data); + for (size_t i = 0; i < shape.size(); ++i) { + const auto& dim_val = shape[i].as(); + CHECK(dim_val) << "Do not support symbolic shape for " + "Array format. Pass shape as Expr instead."; + shape_data[i] = dim_val->value; + } + return Constant(shape_array); +} + +/*! * \brief Check if two expressions are equal scalars. * \param a The expression to be checked. * \param b The expression to be checked @@ -325,6 +345,67 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) { return tvm::StructuralEqual()(a, b); } +/*! + * \brief Convert an element of a NDArray with type int or float to scalar. + * \param array Input NDArray + * \param i element index + * \return Converted scalar value. + */ +static inline double ToScalar(const runtime::NDArray& array, size_t i = 0) { + if (array->dtype.code == kDLInt) { + if (array->dtype.bits == 8) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 16) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 32) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 64) { + return reinterpret_cast(array->data)[i]; + } + } else if (array->dtype.code == kDLUInt) { + if (array->dtype.bits == 8) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 16) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 32) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 64) { + return reinterpret_cast(array->data)[i]; + } + } else if (array->dtype.code == kDLFloat) { +#if (__ARM_FP16_FORMAT_IEEE == 1) + if (array->dtype.bits == 16) { + return reinterpret_cast<__fp16*>(array->data)[i]; + } +#endif + if (array->dtype.bits == 32) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 64) { + return reinterpret_cast(array->data)[i]; + } + } + LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype); + // make compiler happy + return -std::numeric_limits::infinity(); +} + +/*! + * \brief Convert a NDArray with type int or float to Array. + * \param array Input NDArray + * \return Converted Array. + */ +static inline Array ToVector(const runtime::NDArray& array) { + size_t ndim = array.Shape().size(); + CHECK_EQ(ndim, 1) << "This function should only used for shape tensor."; + size_t len = array.Shape().front(); + Array out; + for (size_t i = 0; i < len; ++i) { + double elem_val = ToScalar(array, i); + out.push_back(Integer(static_cast(elem_val))); + } + return out; +} + inline Expr GetField(Expr t, size_t i) { return TupleGetItem(t, i); } inline Expr Pair(Expr l, Expr r) { return Tuple({l, r}); } @@ -432,12 +513,10 @@ inline Expr ZerosLike(Expr e) { return Call(op, {e}); } +Expr MakeZeros(Expr shape, DataType dtype); + inline Expr Zeros(Array shape, DataType dtype) { - auto attrs = make_object(); - attrs->shape = std::move(shape); - attrs->dtype = std::move(dtype); - static const Op& op = Op::Get("zeros"); - return Call(op, {}, Attrs(attrs), {}); + return MakeZeros(CheckConstantShape(shape), dtype); } inline Expr OnesLike(Expr e) { @@ -503,12 +582,10 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { return Call(op, {lhs, rhs}, Attrs(), {}); } +Expr MakeFull(Expr fill_value, Expr shape, DataType dtype); + static inline Expr Full(Expr fill_value, Array shape, DataType dtype) { - auto attrs = make_object(); - attrs->shape = std::move(shape); - attrs->dtype = std::move(dtype); - static const Op& op = Op::Get("full"); - return Call(op, {fill_value}, Attrs(attrs), {}); + return MakeFull(fill_value, CheckConstantShape(shape), dtype); } static inline Expr Conv2D(Expr data, Expr weight, Array strides, @@ -586,7 +663,11 @@ static inline Expr Tile(Expr data, Array reps) { return Call(op, {data}, Attrs(attrs), {}); } -Expr MakeBroadCastTo(Expr data, Array shape); +Expr MakeBroadCastTo(Expr data, Expr shape); + +static inline Expr BroadCastTo(Expr data, Array shape) { + return MakeBroadCastTo(data, CheckConstantShape(shape)); +} Expr MakeConcatenate(Expr data, int axis); diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 5e5542d..504c20a 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -96,31 +96,48 @@ def test_any_broadcast_fail(): check_fail((relay.Any(),), (3, 2), (2), (4, 2), relay.add, np.add) -def verify_any_full(x_shape, x_np_shape, relay_op, np_op, dtype='float32'): +def verify_any_full_like(x_shape, x_np_shape, relay_op, np_op, dtype='float32'): x = relay.var('x', shape=x_shape, dtype=dtype) mod = tvm.IRModule() - mod['main'] = relay.Function([x], relay.zeros_like(x)) + mod['main'] = relay.Function([x], relay_op(x)) x_np = np.random.uniform(size=x_np_shape).astype(dtype) - res_np = np.zeros_like(x_np) + res_np = np_op(x_np) + for kind in ['debug', 'vm']: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm') + result = ex.evaluate()(x_np).asnumpy() + tvm.testing.assert_allclose(result, res_np) + +def test_any_full_like(): + # zeros_like, ones_like + verify_any_full_like(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like, "float32") + verify_any_full_like(any_dims(3), (225, 115, 15), relay.zeros_like, np.zeros_like, "float32") + verify_any_full_like(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like, np.zeros_like, "int32") + verify_any_full_like(any_dims(3), (2, 3, 5), relay.ones_like, np.ones_like, "float32") + verify_any_full_like(any_dims(3), (225, 115, 15), relay.ones_like, np.ones_like, "float32") + verify_any_full_like(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like, np.ones_like, "int32") + +def verify_any_full(x_np_shape, relay_op, np_op, dtype='float32', value=None): + x = relay.var('x', shape=(len(x_np_shape),), dtype="int32") + mod = tvm.IRModule() + out = relay_op(x, dtype) if value is None else relay_op(relay.expr.const(value), x, dtype) + mod['main'] = relay.Function([x], out) + res_np = np_op(x_np_shape) if value is None else np_op(x_np_shape, value) + x_np = np.array(x_np_shape).astype("int32") for kind in ['debug', 'vm']: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm') result = ex.evaluate()(x_np).asnumpy() tvm.testing.assert_allclose(result, res_np) def test_any_full(): - # zeros, zeros_like, ones, ones_like - verify_any_full(any_dims(3), (2, 3, 5), relay.zeros, np.zeros, "float32") - verify_any_full(any_dims(3), (225, 115, 15), relay.zeros, np.zeros, "float32") - verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros, np.zeros, "int32") - verify_any_full(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like, "float32") - verify_any_full(any_dims(3), (225, 115, 15), relay.zeros_like, np.zeros_like, "float32") - verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like, np.zeros_like, "int32") - verify_any_full(any_dims(3), (2, 3, 5), relay.ones, np.ones, "float32") - verify_any_full(any_dims(3), (225, 115, 15), relay.ones, np.ones, "float32") - verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones, np.ones, "int32") - verify_any_full(any_dims(3), (2, 3, 5), relay.ones_like, np.ones_like, "float32") - verify_any_full(any_dims(3), (225, 115, 15), relay.ones_like, np.ones_like, "float32") - verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like, np.ones_like, "int32") + # zeros, ones, full + verify_any_full((2, 3, 5), relay.zeros, np.zeros, "float32") + verify_any_full((225, 115, 15), relay.zeros, np.zeros, "float32") + verify_any_full((10, 11, 12, 13, 14), relay.zeros, np.zeros, "int32") + verify_any_full((2, 3, 5), relay.ones, np.ones, "float32") + verify_any_full((225, 115, 15), relay.ones, np.ones, "float32") + verify_any_full((10, 11, 12, 13, 14), relay.ones, np.ones, "int32") + verify_any_full((10, 11, 12, 13, 14), relay.full, np.full, "float32", 2.0) + verify_any_full((1, 2, 3, 4), relay.full, np.full, "int32", -2) def test_any_concat(): x = relay.var('x', shape=(relay.Any(), 2), dtype="float32") @@ -566,6 +583,37 @@ def test_any_softmax(): verify_any_softmax(any_dims(3), -1, (1, 2, 3), (1, 2, 3)) verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1)) +def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False): + mod = tvm.IRModule() + data = relay.var('data', shape=data_shape, dtype=dtype) + np_data = np.random.uniform(size=np_dshape).astype(dtype) + if const_k: + k = relay.const(kval) + args = [data] + in_vals = [np_data] + else: + k = relay.var('k', shape=(), dtype="int32") + args = [data, k] + in_vals = [np_data, kval] + out = relay.topk(data, k, ret_type="indices") + mod["main"] = relay.Function(args, out) + + sorted = np.argsort(-np_data) + if len(np_dshape) == 2: + ref_out = sorted[:, 0:kval] + else: + ref_out = sorted[0:kval] + + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(*in_vals) + tvm.testing.assert_allclose(result.asnumpy(), ref_out) + +def test_any_topk(): + verify_any_topk(any_dims(1), 5, (10,), "float32") + verify_any_topk(any_dims(2), 2, (6, 3), "int32") + verify_any_topk(any_dims(2), 3, (6, 3), "float32", True) + def test_fused_ops(): x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32') y0 = x + relay.const(1.0, 'float32') @@ -723,6 +771,7 @@ def test_mixed_input_type(): if __name__ == "__main__": test_any_full() + test_any_full_like() test_any_broadcast() test_any_elemwise() test_any_broadcast_fail() @@ -745,10 +794,10 @@ if __name__ == "__main__": test_any_dense() test_any_pad() test_any_softmax() + test_any_topk() test_fused_ops() test_arange_with_dynamic_shape() test_recursive_concat() test_recursive_concat_with_wrong_annotation() test_tuple_get_item() test_mixed_input_type() - diff --git a/topi/python/topi/sort.py b/topi/python/topi/sort.py index 744da62..e492d68 100644 --- a/topi/python/topi/sort.py +++ b/topi/python/topi/sort.py @@ -107,7 +107,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): data : tvm.te.Tensor The input tensor. - k : int, optional + k : int or tvm.te.Tensor, optional Number of top elements to select. Return all elements if k < 1. axis : int, optional @@ -133,7 +133,10 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): assert ret_type in ["both", "values", "indices"] data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) out_shape = list(get_const_tuple(data.shape)) - if k >= 1: + kvar = tvm.te.size_var("k") + if not isinstance(k, int): + out_shape[axis] = kvar + elif k >= 1: out_shape[axis] = k out_bufs = [] if ret_type in ["both", "values"]: @@ -142,10 +145,11 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): out_bufs.append(tvm.tir.decl_buffer(out_shape, dtype, "indices_buf", data_alignment=8)) out_shapes = [out_shape] * len(out_bufs) + kv = kvar if not isinstance(k, int) else k out = te.extern(out_shapes, [data], lambda ins, outs: tvm.tir.call_packed( - "tvm.contrib.sort.topk", ins[0], *outs, k, axis, ret_type, is_ascend), + "tvm.contrib.sort.topk", ins[0], *outs, kv, axis, ret_type, is_ascend), in_buffers=[data_buf], out_buffers=out_bufs, name="topk_cpu", -- 2.7.4