Add More Shape Functions (#4179)
author Yao Wang <kevinthesunwy@gmail.com>
Mon, 11 Nov 2019 23:46:29 +0000 (15:46 -0800)
committer Haichen Shen <shenhaichen@gmail.com>
Mon, 11 Nov 2019 23:46:29 +0000 (15:46 -0800)
* Add shape functions

* Fix get_const_tuple

* Fix cpplint

* Fix pylint

* Fix pylint

* rebase and fix

* Check Any for infer type

* Fix expand_dim shape func for zero rank input

* Fix pooling infer type

* Address comment

* Register layout transform attr

19 files changed:
python/tvm/autotvm/task/task.py
python/tvm/autotvm/util.py
python/tvm/relay/op/_reduce.py
python/tvm/relay/op/_tensor.py
python/tvm/relay/op/_transform.py
python/tvm/relay/op/nn/_nn.py
src/lang/data_layout.cc
src/relay/op/nn/convolution.cc
src/relay/op/nn/convolution.h
src/relay/op/nn/nn.cc
src/relay/op/nn/pad.cc
src/relay/op/nn/pooling.cc
src/relay/op/tensor/reduce.cc
src/relay/op/tensor/transform.cc
tests/python/relay/test_any.py
topi/include/topi/nn/flatten.h
topi/python/topi/util.py
topi/python/topi/x86/conv2d.py
topi/python/topi/x86/dense.py

python/tvm/autotvm/task/task.py
index 4f3cc90..7f36914 100644 (file)
@@ -219,15 +219,15 @@ def args_to_workload(x, topi_compute_func=None):
         workload = get_const_tuple(x.shape) + (x.dtype, )
     elif isinstance(x, (tuple, list, container.Array)):
         workload = tuple([args_to_workload(a) for a in x])
-    elif isinstance(x, (str, int, float, np.int, np.float)):
+    elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
         workload = x
     elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
         workload = x.value
     elif x is None:
         workload = 0
     else:
-        raise RuntimeError('Do not support type "%s" in argument. Consider to use '
-                           'primitive types only' % type(x))
+        raise RuntimeError('Do not support type "%s" in argument. Consider to use '
+                           'primitive types or tvm.expr.Var only' % type(x))
     return (get_func_name(topi_compute_func), ) + workload  if topi_compute_func else workload
 
 def template(func):
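As a quick illustration of the widened type check, here is a minimal pure-Python mimic (no TVM required; Var and to_workload are illustrative stand-ins, not TVM APIs) showing a symbolic dimension flowing into a workload key instead of raising:

# Stand-in for tvm.expr.Var: only used to mark a symbolic dimension.
class Var:
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return "Var(%s)" % self.name

def to_workload(x):
    # Symbolic vars are now treated like primitive values.
    if isinstance(x, (str, int, float, Var)):
        return x
    if isinstance(x, (tuple, list)):
        return tuple(to_workload(a) for a in x)
    raise RuntimeError('Do not support type "%s" in argument.' % type(x))

print(to_workload((Var("n"), 3, 224, 224)))   # (Var(n), 3, 224, 224)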
python/tvm/autotvm/util.py
index 7c98acd..3026914 100644 (file)
@@ -163,7 +163,7 @@ def get_const_int(exp):
 
 
 def get_const_tuple(in_tuple):
-    """Verifies input tuple is IntImm, returns tuple of int.
+    """Verifies input tuple is IntImm or Var, returns tuple of int or Var.
 
     Parameters
     ----------
@@ -175,4 +175,14 @@ def get_const_tuple(in_tuple):
     out_tuple : tuple of int
         The output.
     """
-    return tuple(get_const_int(x) for x in in_tuple)
+    ret = []
+    for elem in in_tuple:
+        if isinstance(elem, expr.Var):
+            ret.append(elem)
+        elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)):
+            elem = ir_pass.Simplify(elem)
+            if not isinstance(elem, (expr.IntImm, expr.UIntImm)):
+                ret.append(elem)
+        else:
+            ret.append(get_const_int(elem))
+    return tuple(ret)
python/tvm/relay/op/_reduce.py
index 06d0d66..43f71c0 100644 (file)
 from __future__ import absolute_import
 
 import topi
+
+from topi.util import get_const_int, get_const_tuple
 from . import op as _reg
+from ...api import convert
+from ...hybrid import script
 
 
 def _schedule_reduce(_, outs, target):
@@ -39,3 +43,67 @@ _reg.register_schedule("mean", _schedule_reduce)
 _reg.register_schedule("variance", _schedule_reduce)
 _reg.register_schedule("nn.cross_entropy", _schedule_reduce)
 _reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)
+
+
+def _create_axis_record(attrs, inputs):
+    axes = attrs.axis if attrs.axis is None else list(get_const_tuple(attrs.axis))
+    exclude = get_const_int(attrs.exclude) > 0
+    keepdims = get_const_int(attrs.keepdims) > 0
+    data_shape = inputs[0]
+    shape_size = data_shape.shape[0].value
+    axis_record = [-1] * shape_size
+    if axes is None:
+        axes = list(range(shape_size))
+
+    for i, axis in enumerate(axes):
+        if axis < 0:
+            axes[i] = shape_size + axis
+
+    if exclude:
+        ex_axes = []
+        for i in range(shape_size):
+            if i not in axes:
+                ex_axes.append(i)
+        axes = ex_axes
+
+    for i in range(shape_size):
+        if i not in axes:
+            axis_record[i] = i
+
+    if not keepdims:
+        tmp = []
+        for i in axis_record:
+            if i >= 0:
+                tmp.append(i)
+        axis_record = tmp
+
+    return axis_record
+
+
+@script
+def _reduce_shape_func(data_shape, axis_record):
+    out = output_tensor((len(axis_record),), "int64")
+    for i in const_range(len(axis_record)):
+        if axis_record[i] >= 0:
+            out[i] = data_shape[axis_record[i]]
+        else:
+            out[i] = int64(1)
+
+    return out
+
+def reduce_shape_func(attrs, inputs, _):
+    """
+    Shape function for reduce op.
+    """
+    axis_record = _create_axis_record(attrs, inputs)
+    return [_reduce_shape_func(inputs[0], convert(axis_record))]
+
+_reg.register_shape_func("argmax", False, reduce_shape_func)
+_reg.register_shape_func("argmin", False, reduce_shape_func)
+_reg.register_shape_func("all", False, reduce_shape_func)
+_reg.register_shape_func("sum", False, reduce_shape_func)
+_reg.register_shape_func("max", False, reduce_shape_func)
+_reg.register_shape_func("min", False, reduce_shape_func)
+_reg.register_shape_func("prod", False, reduce_shape_func)
+_reg.register_shape_func("mean", False, reduce_shape_func)
+_reg.register_shape_func("variance", False, reduce_shape_func)
python/tvm/relay/op/_tensor.py
index 188b3bb..dcff084 100644 (file)
@@ -119,18 +119,6 @@ def _cast_shape_function(x):
 def cast_shape_func(attrs, inputs, out_ndims):
     return [_cast_shape_function(*inputs)]
 
-@script
-def _expand_dims_shape_func(x):
-    ndim = len(x.shape)
-    out = output_tensor((ndim+1,), "int64")
-    out[0] = int64(1)
-    for i in const_range(0, ndim):
-        out[i+1] = int64(x.shape[i])
-    return out
-
-def expand_dims_shape_func(attrs, inputs, out_ndims):
-    return [_expand_dims_shape_func(*inputs)]
-
 # shape func
 @script
 def _broadcast_shape_func(x, y, ndim):
@@ -161,9 +149,17 @@ def _broadcast_shape_func(x, y, ndim):
     return out
 
 def broadcast_shape_func(attrs, inputs, out_ndims):
+    """
+    Shape function for broadcast op.
+    """
     return [_broadcast_shape_func(*inputs, out_ndims[0])]
 
-register_shape_func("expand_dims", False, expand_dims_shape_func)
+def elemwise_shape_func(attrs, inputs, _):
+    """
+    Shape function for elemwise op.
+    """
+    return [topi.math.identity(inputs[0])]
+
 register_shape_func("cast", False, cast_shape_func)
 
 register_shape_func("add", False, broadcast_shape_func)
@@ -179,3 +175,6 @@ register_shape_func("less", False, broadcast_shape_func)
 register_shape_func("less_equal", False, broadcast_shape_func)
 register_shape_func("greater", False, broadcast_shape_func)
 register_shape_func("greater_equal", False, broadcast_shape_func)
+
+register_shape_func("sqrt", False, elemwise_shape_func)
+register_shape_func("negative", False, elemwise_shape_func)
python/tvm/relay/op/_transform.py
index 687d5b4..13f41fc 100644 (file)
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Backend compiler related feature registration"""
-# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks
+# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
 from __future__ import absolute_import
 import tvm
 import topi
@@ -303,3 +303,195 @@ def compute_argwhere(attrs, inputs, output_type, _):
             output_shape.append(tvm.var("any_dim", "int32"))
     new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
     return [topi.argwhere(new_output_type, inputs[0])]
+
+@script
+def _layout_transform_shape_func(data_shape,
+                                 out_layout_len,
+                                 dst_equal_list,
+                                 dst_mul_list,
+                                 dst_div_list,
+                                 dst_mix_list):
+    out = output_tensor((out_layout_len,), "int64")
+    for i in const_range(len(dst_equal_list)):
+        out[dst_equal_list[i][0]] = data_shape[dst_equal_list[i][1]]
+    for i in const_range(len(dst_mul_list)):
+        out[dst_mul_list[i][0]] = data_shape[dst_mul_list[i][1]] * \
+                                  data_shape[dst_mul_list[i][2]]
+    for i in const_range(len(dst_div_list)):
+        out[dst_div_list[i][0]] = data_shape[dst_div_list[i][1]] \
+                                  // dst_div_list[i][3]
+        out[dst_div_list[i][2]] = int64(dst_div_list[i][3])
+    for i in const_range(len(dst_mix_list)):
+        out[dst_mix_list[i][0]] = data_shape[dst_mix_list[i][1]] * \
+                                  dst_mix_list[i][2] // dst_mix_list[i][4]
+        out[dst_mix_list[i][3]] = int64(dst_mix_list[i][4])
+
+    return out
+
+@_reg.register_shape_func("layout_transform", False)
+def layout_transform_shape_func(attrs, inputs, _):
+    """
+    Shape function for layout_transform op.
+    """
+    def _fetch_axis(layout):
+        major_axes = []
+        minor_axes = {}
+        num_start = -1
+        for i, item in enumerate(layout):
+            if "A" <= item <= "Z":
+                major_axes.append(item)
+            elif "a" <= item <= "z":
+                last_num = int(layout[num_start:i])
+                minor_axes[item] = last_num
+                num_start = -1
+            elif num_start < 0:
+                num_start = i
+        return major_axes, minor_axes
+
+    _, src_minor_axes = _fetch_axis(attrs.src_layout)
+    dst_major_axes, dst_minor_axes = _fetch_axis(attrs.dst_layout)
+    src_letter_list = []
+    dst_letter_list = []
+    for item in attrs.src_layout:
+        if "A" <= item <= "Z" or "a" <= item <= "z":
+            src_letter_list.append(item)
+    for item in attrs.dst_layout:
+        if "A" <= item <= "Z" or "a" <= item <= "z":
+            dst_letter_list.append(item)
+    out_layout_len = len(dst_major_axes) + len(dst_minor_axes)
+    dst_equal_list = []
+    dst_mul_list = []
+    dst_div_list = []
+    dst_mix_list = []
+
+    for key in dst_major_axes:
+        if key.lower() not in dst_minor_axes:
+            if key.lower() not in src_minor_axes:
+                dst_equal_list.append((dst_letter_list.index(key),
+                                       src_letter_list.index(key)))
+            else:
+                dst_mul_list.append((dst_letter_list.index(key),
+                                     src_letter_list.index(key),
+                                     src_letter_list.index(key.lower())))
+        else:
+            if key.lower() not in src_minor_axes:
+                dst_div_list.append((dst_letter_list.index(key),
+                                     src_letter_list.index(key),
+                                     dst_letter_list.index(key.lower()),
+                                     dst_minor_axes[key.lower()]))
+            else:
+                dst_mix_list.append((dst_letter_list.index(key),
+                                     src_letter_list.index(key),
+                                     src_minor_axes[key.lower()],
+                                     dst_letter_list.index(key.lower()),
+                                     dst_minor_axes[key.lower()]))
+
+    return [_layout_transform_shape_func(inputs[0],
+                                         convert(out_layout_len),
+                                         convert(dst_equal_list),
+                                         convert(dst_mul_list),
+                                         convert(dst_div_list),
+                                         convert(dst_mix_list))]
+
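A standalone copy of the _fetch_axis parsing logic above shows how a layout string such as "NCHW16c" is split into major axes and minor (sub) axes with their block sizes (pure Python, illustration only):

def fetch_axis(layout):
    major, minor, num_start = [], {}, -1
    for i, ch in enumerate(layout):
        if "A" <= ch <= "Z":
            major.append(ch)                       # primal axis
        elif "a" <= ch <= "z":
            minor[ch] = int(layout[num_start:i])   # sub-axis with its factor
            num_start = -1
        elif num_start < 0:
            num_start = i                          # start of the numeric factor
    return major, minor

print(fetch_axis("NCHW16c"))   # (['N', 'C', 'H', 'W'], {'c': 16})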
+@script
+def _expand_dim_shape_func(data_shape, ndim, axis, num_newaxis):
+    out = output_tensor((ndim + num_newaxis,), "int64")
+    for i in const_range(out.shape[0]):
+        if i < axis:
+            out[i] = data_shape[i]
+        elif i < axis + num_newaxis:
+            out[i] = int64(1)
+        else:
+            out[i] = data_shape[i - num_newaxis]
+
+    return out
+
+@_reg.register_shape_func("expand_dims", False)
+def expand_dim_shape_func(attrs, inputs, _):
+    """
+    Shape function for expand_dim op.
+    """
+    axis = get_const_int(attrs.axis)
+    num_newaxis = get_const_int(attrs.num_newaxis)
+    if axis < 0:
+        axis = inputs[0].shape[0] + axis + 1
+    ndim = inputs[0].shape[0] if inputs[0].shape else 0
+    return [_expand_dim_shape_func(inputs[0],
+                                   convert(ndim),
+                                   convert(axis),
+                                   convert(num_newaxis))]
+
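The expand_dims rule above amounts to inserting num_newaxis ones at position axis, with a negative axis counted from the end (axis=-1 on a rank-3 input becomes position 3). A pure-Python sketch, matching the test cases below:

def expand_dims_shape(data_shape, axis, num_newaxis):
    ndim = len(data_shape)
    if axis < 0:
        axis = ndim + axis + 1
    return tuple(list(data_shape[:axis]) + [1] * num_newaxis + list(data_shape[axis:]))

print(expand_dims_shape(("n", 2, 3), axis=1, num_newaxis=2))   # ('n', 1, 1, 2, 3)
print(expand_dims_shape(("n", 2, 3), axis=-1, num_newaxis=2))  # ('n', 2, 3, 1, 1)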
+@script
+def _transpose_shape_func(data_shape, axes):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    for i in const_range(len(axes)):
+        out[i] = data_shape[axes[i]]
+
+    return out
+
+@_reg.register_shape_func("transpose", False)
+def transpose_shape_func(attrs, inputs, _):
+    """
+    Shape function for transpose op.
+    """
+    axes = attrs.axes if attrs.axes is None else list(get_const_tuple(attrs.axes))
+    if axes is None:
+        axes = list(range(inputs[0].shape[0].value))
+        axes.reverse()
+    for i, axis in enumerate(axes):
+        if axis < 0:
+            axes[i] = inputs[0].shape[0] + axis
+    return [_transpose_shape_func(inputs[0], convert(axes))]
+
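A standalone sketch of the normalization used above (pure Python; transpose_shape is an illustrative helper): a negative axis k on a rank-n input maps to n + k, and axes=None reverses the dimension order.

def transpose_shape(data_shape, axes):
    ndim = len(data_shape)
    if axes is None:
        axes = list(reversed(range(ndim)))
    axes = [a + ndim if a < 0 else a for a in axes]
    return tuple(data_shape[a] for a in axes)

print(transpose_shape((10, 3, 2), (1, 0, 2)))   # (3, 10, 2)
print(transpose_shape((2, 3, 4), None))         # (4, 3, 2)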
+@script
+def _squeeze_shape_func(data_shape, keep_axes):
+    out = output_tensor((len(keep_axes),), "int64")
+    if len(keep_axes) == 0:
+        out_size = 0
+        for i in const_range(data_shape.shape[0]):
+            if data_shape[i] != 1:
+                out_size += 1
+
+        if out_size == 0:
+            out_size = 1
+        out = output_tensor((out_size,), "int64")
+        out[0] = int64(1)
+        pos = 0
+        for i in const_range(data_shape.shape[0]):
+            if data_shape[i] != 1:
+                out[pos] = data_shape[i]
+                pos += 1
+    else:
+        for i in const_range(len(keep_axes)):
+            out[i] = data_shape[keep_axes[i]]
+
+    return out
+
+@_reg.register_shape_func("squeeze", False)
+def squeeze_shape_func(attrs, inputs, _):
+    """
+    Shape function for squeeze op.
+    """
+    axis = attrs.axis if attrs.axis is None else get_const_tuple(attrs.axis)
+    keep_axes = []
+    if axis is not None:
+        for i in range(inputs[0].shape[0].value):
+            if i not in axis:
+                keep_axes.append(i)
+
+    return [_squeeze_shape_func(inputs[0], convert(keep_axes))]
+
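The squeeze rule: with an explicit axis list the listed dimensions are dropped in Python; with axis=None every dimension equal to 1 must be dropped at runtime, which is why that case is handled inside the hybrid script. A pure-Python sketch of the end result:

def squeeze_shape(data_shape, axis):
    if axis is None:
        kept = [d for d in data_shape if d != 1]
        return tuple(kept) if kept else (1,)   # all-ones input collapses to (1,)
    return tuple(d for i, d in enumerate(data_shape) if i not in axis)

print(squeeze_shape((1, 9, 8), axis=(0,)))      # (9, 8)
print(squeeze_shape((1, 9, 1, 8), axis=None))   # (9, 8)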
+@script
+def _reshape_like_shape_func(target_shape):
+    out = output_tensor((target_shape.shape[0],), "int64")
+    for i in const_range(target_shape.shape[0]):
+        out[i] = target_shape[i]
+
+    return out
+
+@_reg.register_shape_func("reshape_like", False)
+def reshape_like_shape_func(attrs, inputs, _):
+    """
+    Shape function for reshape_like op.
+    """
+    return [_reshape_like_shape_func(inputs[1])]
python/tvm/relay/op/nn/_nn.py
index 8915480..54f13c6 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument, too-many-arguments
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 
@@ -22,6 +22,9 @@ import topi
 from topi.util import get_const_tuple
 from .. import op as reg
 from ..op import OpPattern, schedule_injective
+from .._tensor import elemwise_shape_func
+from ....api import convert
+from ....hybrid import script
 
 # relu
 reg.register_schedule("nn.relu", schedule_injective)
@@ -766,7 +769,6 @@ reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
 
-
 @reg.register_compute("nn.cross_entropy")
 def compute_cross_entropy(attrs, inputs, out_dtype, target):
     x, y = inputs
@@ -775,8 +777,170 @@ def compute_cross_entropy(attrs, inputs, out_dtype, target):
 
 reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)
 
-
 @reg.register_compute("nn.cross_entropy_with_logits")
 def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target):
     x, y = inputs
     return [-topi.sum(x * y) / x.shape[0]]
+
+# shape func
+@script
+def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn):
+    out = output_tensor((dshape.shape[0],), "int64")
+    ic_chunk = dshape[1]
+    height = dshape[2]
+    width = dshape[3]
+    ic_bn = dshape[4]
+    kheight = kshape[2]
+    kwidth = kshape[3]
+    dilated_kh = (kheight - 1) * dilation[0] + 1
+    dilated_kw = (kwidth - 1) * dilation[1] + 1
+    kflatten = int64(1)
+    for i in const_range(kshape.shape[0]):
+        kflatten *= kshape[i]
+
+    oc = kflatten // (kheight * kwidth * ic_chunk * ic_bn)
+    oc_chunk = oc // oc_bn
+
+    out_height = (height + 2 * padding[0] - dilated_kh) // strides[0] + 1
+    out_width = (width + 2 * padding[1] - dilated_kw) // strides[1] + 1
+
+    out[0] = dshape[0]
+    out[1] = oc_chunk
+    out[2] = out_height
+    out[3] = out_width
+    out[4] = int64(oc_bn)
+    return out
+
+@reg.register_shape_func("nn.contrib_conv2d_NCHWc", False)
+def conv2d_NCHWc_shape_func(attrs, inputs, _):
+    """
+    Shape function for contrib_conv2d_NCHWc op.
+    """
+    strides = get_const_tuple(attrs.strides)
+    padding = get_const_tuple(attrs.padding)
+    dilation = get_const_tuple(attrs.dilation)
+    out_layout = attrs.out_layout
+    oc_bn = int(out_layout[4:-1])
+
+    return [_conv2d_NCHWc_shape_func(inputs[0], inputs[1],
+                                     convert(strides), convert(padding),
+                                     convert(dilation), convert(oc_bn))]
+
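To see how the hybrid script recovers the output-channel count, here is a standalone arithmetic check using the kernel layout assumed in the tests below (OIHW8i8o; concrete numbers are illustrative only):

kshape = (8, 8, 3, 3, 8, 8)            # (oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn)
kflatten = 1
for d in kshape:
    kflatten *= d                      # 36864 total kernel elements
ic_chunk, ic_bn, oc_bn = 8, 8, 8
kh, kw = kshape[2], kshape[3]
oc = kflatten // (kh * kw * ic_chunk * ic_bn)
print(oc, oc // oc_bn)                 # 64 8 -> 64 output channels, oc_chunk = 8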
+@script
+def _pool2d_shape_func(data_shape, pool_size, strides,
+                       padding, height_axis, width_axis):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    for i in const_range(data_shape.shape[0]):
+        if i == height_axis:
+            out[i] = (data_shape[i] + padding[0] + padding[2] - pool_size[0]) // strides[0] + 1
+        elif i == width_axis:
+            out[i] = (data_shape[i] + padding[1] + padding[3] - pool_size[1]) // strides[1] + 1
+        else:
+            out[i] = data_shape[i]
+
+    return out
+
+def pool2d_shape_func(attrs, inputs, _):
+    """
+    Shape function for pool2d op.
+    """
+    pool_size = get_const_tuple(attrs.pool_size)
+    strides = get_const_tuple(attrs.strides)
+    padding = get_const_tuple(attrs.padding)
+    layout = attrs.layout
+    height_axis = layout.index("H")
+    width_axis = layout.index("W")
+    if len(padding) == 1:
+        padding = [padding[0]] * 4
+    elif len(padding) == 2:
+        padding = [padding[0], padding[1], padding[0], padding[1]]
+
+    return [_pool2d_shape_func(inputs[0], convert(pool_size),
+                               convert(strides), convert(padding),
+                               convert(height_axis), convert(width_axis))]
+
+reg.register_shape_func("nn.max_pool2d", False, pool2d_shape_func)
+reg.register_shape_func("nn.avg_pool2d", False, pool2d_shape_func)
+
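The pooling arithmetic above is the usual floor-mode formula, out = (in + pad_before + pad_after - pool) // stride + 1; a quick standalone check against the expected shapes in the tests below:

def pool_out_dim(in_dim, pool, stride, pad_before, pad_after):
    return (in_dim + pad_before + pad_after - pool) // stride + 1

print(pool_out_dim(220, pool=3, stride=1, pad_before=1, pad_after=1))  # 220
print(pool_out_dim(220, pool=1, stride=2, pad_before=0, pad_after=0))  # 110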
+@script
+def _global_pool2d_shape_func(data_shape, height_axis, width_axis):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    for i in const_range(out.shape[0]):
+        if i == height_axis or i == width_axis:
+            out[i] = int64(1)
+        else:
+            out[i] = data_shape[i]
+
+    return out
+
+def global_pool2d_shape_func(attrs, inputs, _):
+    """
+    Shape function for global pool2d op.
+    """
+    layout = attrs.layout
+    height_axis = width_axis = 1
+    for i, letter in enumerate(layout):
+        if letter == "H":
+            height_axis = i
+        if letter == "W":
+            width_axis = i
+    return [_global_pool2d_shape_func(inputs[0], convert(height_axis), convert(width_axis))]
+
+reg.register_shape_func("nn.global_max_pool2d", False, global_pool2d_shape_func)
+reg.register_shape_func("nn.global_avg_pool2d", False, global_pool2d_shape_func)
+
+@script
+def _batch_flatten_shape_func(data_shape):
+    out = output_tensor((2,), "int64")
+    out[0] = data_shape[0]
+    out[1] = int64(1)
+    for i in const_range(data_shape.shape[0] - 1):
+        out[1] *= data_shape[i + 1]
+
+    return out
+
+@reg.register_shape_func("nn.batch_flatten", False)
+def batch_flatten_shape_func(attrs, inputs, _):
+    """
+    Shape function for batch_flatten op.
+    """
+    return [_batch_flatten_shape_func(inputs[0])]
+
+@script
+def _dense_shape_func(data_shape, weight_shape):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    for i in const_range(out.shape[0] - 1):
+        out[i] = data_shape[i]
+    out[out.shape[0] - 1] = weight_shape[0]
+
+    return out
+
+@reg.register_shape_func("nn.dense", False)
+def dense_shape_func(attrs, inputs, _):
+    """
+    Shape function for dense op.
+    """
+    ret = [_dense_shape_func(inputs[0], inputs[1])]
+    return ret
+
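The dense rule is simply that the last data dimension is replaced by the number of weight rows (output units); a one-line standalone sketch:

def dense_out_shape(data_shape, weight_shape):
    return tuple(list(data_shape[:-1]) + [weight_shape[0]])

print(dense_out_shape(("n", 16), (8, 16)))   # ('n', 8)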
+@script
+def _pad_shape_func(data_shape, pad_width):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    for i in const_range(out.shape[0]):
+        out[i] = data_shape[i] + pad_width[i][0] + pad_width[i][1]
+
+    return out
+
+@reg.register_shape_func("nn.pad", False)
+def pad_shape_func(attrs, inputs, _):
+    """
+    Shape function for pad op.
+    """
+    pad_width = []
+    for pair in attrs.pad_width:
+        pad_width.append(get_const_tuple(pair))
+    return [_pad_shape_func(inputs[0], convert(pad_width))]
+
+reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
+reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
+reg.register_shape_func("nn.relu", False, elemwise_shape_func)
src/lang/data_layout.cc
index 7c76e40..35139bb 100644 (file)
@@ -289,16 +289,22 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
   // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
   // e.g., (C * 16 + c) / 32
   std::unordered_map<const Variable*, Expr> bind_map;
+  std::unordered_set<size_t> symbolic_var_set;
   for (size_t i = 0; i < src_shape.size(); ++i) {
     Expr orig_shape = src_shape[i];
     IterVar orig_axis = src_axis[i];
+    if (orig_shape.as<ir::Any>()) {
+      symbolic_var_set.insert(i);
+    }
     if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
       if (orig_shape.defined()) {
         const auto* orig_shape_const = orig_shape.as<IntImm>();
         const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImm>();
-        CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
-          << "Input shape mismatch at index " << i << ". Expected "
-          << orig_axis->dom->extent << ", get " << orig_shape;
+        if (orig_shape_const) {
+          CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
+            << "Input shape mismatch at index " << i << ". Expected "
+            << orig_axis->dom->extent << ", get " << orig_shape;
+        }
       }
       bind_map[orig_axis->var.get()] = Expr(0);
     } else {
@@ -316,7 +322,11 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
     if (!LayoutAxis::Get(axis).IsPrimal()) {
       result.push_back(axis->dom->extent);
     } else {
-      result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
+      if (symbolic_var_set.count(i)) {
+        result.push_back(ir::Any::make());
+      } else {
+        result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
+      }
     }
   }
   return result;
src/relay/op/nn/convolution.cc
index bf4e54b..002a246 100644 (file)
@@ -330,8 +330,19 @@ bool Conv2DWinogradRel(const Array<Type>& types,
   // dilation
   Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
 
-  oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
-  oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
+  if (!dshape_nchw[2].as<ir::Any>()) {
+    oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2
+                   - dilated_ksize_y) / param->strides[0] + 1);
+  } else {
+    oshape.Set(2, dshape_nchw[2]);
+  }
+  if (!dshape_nchw[3].as<ir::Any>()) {
+    oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2
+                   - dilated_ksize_x) / param->strides[1] + 1);
+  } else {
+    oshape.Set(3, dshape_nchw[3]);
+  }
+
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
     out_dtype = data->dtype;
src/relay/op/nn/convolution.h
index 602a091..19b84dd 100644 (file)
@@ -116,10 +116,19 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   // dilation
   Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
 
-  oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y,
-                         param->strides[0]) + 1);
-  oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x,
-                         param->strides[1]) + 1);
+  if (!dshape_nchw[2].as<ir::Any>()) {
+    oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y,
+                           param->strides[0]) + 1);
+  } else {
+    oshape.Set(2, dshape_nchw[2]);
+  }
+
+  if (!dshape_nchw[3].as<ir::Any>()) {
+    oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x,
+                           param->strides[1]) + 1);
+  } else {
+    oshape.Set(3, dshape_nchw[3]);
+  }
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
     out_dtype = data->dtype;
src/relay/op/nn/nn.cc
index 416a0d7..d3a7178 100644 (file)
@@ -408,7 +408,12 @@ bool BatchFlattenRel(const Array<Type>& types,
   auto target_dim = make_const(Int(32), 1);
 
   for (uint32_t i = 1; i < data->shape.size(); ++i) {
-    target_dim = target_dim * data->shape[i];
+    if (!data->shape[i].as<ir::Any>()) {
+      target_dim = target_dim * data->shape[i];
+    } else {
+      target_dim = data->shape[i];
+      break;
+    }
   }
 
   std::vector<IndexExpr> oshape({data->shape[0], target_dim});
src/relay/op/nn/pad.cc
index 2342880..d625f19 100644 (file)
@@ -148,8 +148,12 @@ bool PadRel(const Array<Type>& types,
       << "Param width elements should be positive but first pad width at "
       << "index " << i << " is " << *width2 << ".";
 
-    auto padding = make_const(data->shape[i].type(), *width1 + *width2);
-    oshape.push_back(data->shape[i] + padding);
+    if (!data->shape[i].as<ir::Any>()) {
+      auto padding = make_const(data->shape[i].type(), *width1 + *width2);
+      oshape.push_back(data->shape[i] + padding);
+    } else {
+      oshape.push_back(data->shape[i]);
+    }
   }
 
   reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
src/relay/op/nn/pooling.cc
index 94f8a54..99d184f 100644 (file)
@@ -102,14 +102,25 @@ bool Pool2DRel(const Array<Type>& types,
     oshape.push_back(e);
   }
 
-  if (param->ceil_mode) {
-    oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] +
-                    param->strides[0] - 1) / param->strides[0]) + 1;
-    oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] +
-                    param->strides[1] - 1) / param->strides[1]) + 1;
+  if (dshape[hidx].as<ir::Any>()) {
+    oshape[hidx] = dshape[hidx];
   } else {
-    oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1;
-    oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1;
+    if (param->ceil_mode) {
+      oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] +
+                       param->strides[0] - 1) / param->strides[0]) + 1;
+    } else {
+      oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1;
+    }
+  }
+  if (dshape[widx].as<ir::Any>()) {
+    oshape[widx] = dshape[widx];
+  } else {
+    if (param->ceil_mode) {
+      oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] +
+                       param->strides[1] - 1) / param->strides[1]) + 1;
+    } else {
+      oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1;
+    }
   }
 
   // assign output type
src/relay/op/tensor/reduce.cc
index 63524bc..3a2d469 100644 (file)
@@ -211,11 +211,20 @@ inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr> &in_s
   }
 
   auto max_shape = make_const(Int(64), 1);
+  bool is_dynamic_input = false;
   for (int64_t axis : r_axes) {
-    max_shape *= in_shape[axis];
+    if (in_shape[axis].as<IntImm>()) {
+      max_shape *= in_shape[axis];
+    } else {
+      is_dynamic_input = true;
+      break;
+    }
+  }
+
+  if (is_dynamic_input) {
+    CHECK(reporter->Assert(max_shape < make_const(Int(64), std::numeric_limits<int32_t>::max())))
+      << "The maximum possible index of reduced shape cannot be more than int32 max.";
   }
-  CHECK(reporter->Assert(max_shape < make_const(Int(64), std::numeric_limits<int32_t>::max())))
-    << "The maximum possible index of reduced shape cannot be more than int32 max.";
 
   if (param->keepdims) {
     std::vector<IndexExpr> oshape(in_shape);
src/relay/op/tensor/transform.cc
index e1239ae..203a041 100644 (file)
@@ -797,8 +797,18 @@ bool ReshapeLikeRel(const Array<Type>& types,
   if (reshape_like == nullptr) {
     return false;
   }
-  CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
-    << "Reshape inputs size should be compatible.";
+  // Only check When input data has static shape.
+  bool is_static_shape = true;
+  for (size_t i = 0; i < data->shape.size(); ++i) {
+    if (!data->shape[i].as<IntImm>()) {
+      is_static_shape = false;
+      break;
+    }
+  }
+  if (is_static_shape) {
+    CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
+      << "Reshape inputs size should be compatible.";
+  }
   reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype));
   return true;
 }
@@ -2292,6 +2302,8 @@ RELAY_REGISTER_OP("slice_like")
 .set_attr<TOpPattern>("TOpPattern", kInjective);
 
 // relay.layout_transform
+TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
+
 Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
                                      const Array<Tensor>& inputs,
                                      const Type& out_type,
tests/python/relay/test_any.py
index 891cfad..75be88c 100644 (file)
@@ -188,6 +188,257 @@ def test_any_shape_of():
         result = ex.evaluate()(data)
         tvm.testing.assert_allclose(result.asnumpy(), np.array(3).astype("int64"))
 
+def verify_any_reduce(reduce_op, data_shape, axis, exclude, keepdims,
+                      static_data_shape, ref_out_shape):
+    mod = relay.Module()
+    dtype = "bool" if reduce_op == relay.all else "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = reduce_op(data, axis, keepdims, exclude)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_reduce():
+    verify_any_reduce(relay.argmax, any_dims(3), None, False, False, (3, 4, 5), ())
+    verify_any_reduce(relay.argmin, any_dims(4), 1, False, True, (3, 4, 5, 6), (3, 1, 5, 6))
+    verify_any_reduce(relay.all, any_dims(3), (1, 2), True, False, (3, 4, 5), (4, 5))
+    verify_any_reduce(relay.max, any_dims(4), -1, True, True, (3, 4, 5, 6), (1, 1, 1, 6))
+    verify_any_reduce(relay.min, any_dims(3), (0, 1), False, False, (4, 5, 6), (6,))
+    verify_any_reduce(relay.prod, any_dims(4), 2, True, True, (3, 4, 5, 6), (1, 1, 5, 1))
+    verify_any_reduce(relay.mean, any_dims(2), 0, False, False, (1, 2), (2,))
+    verify_any_reduce(relay.variance, any_dims(5), (2, 4), False, False, (3, 4, 5, 6, 7), (3, 4, 6))
+
+def verify_any_layout_transform(data_shape, src_layout, dst_layout, static_data_shape, ref_out_shape):
+    mod = relay.Module()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.layout_transform(data, src_layout, dst_layout)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_layout_transform():
+    verify_any_layout_transform(any_dims(4), "NCHW", "NHWC", (3, 4, 5, 6), (3, 5, 6, 4))
+    verify_any_layout_transform(any_dims(5), "NCHW16c", "NCHW2c", (1, 2, 8, 8, 16), (1, 16, 8, 8, 2))
+    verify_any_layout_transform(any_dims(5), "NCHW6n", "NHWC", (3, 4, 5, 6, 6), (18, 5, 6, 4))
+    verify_any_layout_transform(any_dims(4), "NCHW", "NCHW4c", (3, 4, 5, 6), (3, 1, 5, 6, 4))
+    verify_any_layout_transform((16, 1), "CH", "C4cH", (16, 1), (4, 4, 1))
+
+def verify_any_expand_dims(data_shape, axis, num_newaxis, static_data_shape, ref_out_shape):
+    mod = relay.Module()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.expand_dims(data, axis=axis, num_newaxis=num_newaxis)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_expand_dims():
+    verify_any_expand_dims(any_dims(3), 1, 2, (1, 2, 3), (1, 1, 1, 2, 3))
+    verify_any_expand_dims(any_dims(3), -1, 2, (1, 2, 3), (1, 2, 3, 1, 1))
+
+def verify_any_transpose(data_shape, axes, static_data_shape):
+    mod = relay.Module()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.transpose(data, axes=axes)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    ref_out = np.transpose(data_np, axes)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        tvm.testing.assert_allclose(result.asnumpy(), ref_out)
+
+def test_any_transpose():
+    verify_any_transpose(any_dims(3), (1, 0, 2), (10, 3, 2))
+    verify_any_transpose(any_dims(3), None, (2, 3, 4))
+    verify_any_transpose(any_dims(6), (0, 1, 3, 2, 5, 4), (11, 12, 2, 1, 9, 17))
+
+def verify_any_squeeze(data_shape, axis, static_data_shape):
+    mod = relay.Module()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.squeeze(data, axis=axis)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    ref_out = np.squeeze(data_np, axis)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        tvm.testing.assert_allclose(result.asnumpy(), ref_out)
+
+def test_any_squeeze():
+    verify_any_squeeze((1, relay.Any(), relay.Any()), (0,), (1, 9, 8))
+    verify_any_squeeze((1, relay.Any(), relay.Any(), 1, relay.Any(), relay.Any()), (0, 3), (1, 12, 2, 1, 9, 17))
+
+def test_any_reshape_like():
+    mod = relay.Module()
+    dtype = "float32"
+    data = relay.var('data', shape=(relay.Any(), 3, 10), dtype=dtype)
+    shape_like = relay.var('data', shape=(relay.Any(), 5, 6), dtype=dtype)
+    y = relay.reshape_like(data, shape_like)
+    mod["main"] = relay.Function([data, shape_like], y)
+    data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype)
+    shape_like_np = np.random.uniform(size=(3, 5, 6)).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np, shape_like_np)
+        assert result.asnumpy().shape == shape_like_np.shape, \
+            "Shape mismatch: expect %s but got %s." % (str(shape_like_np.shape), str(result.asnumpy().shape))
+
+def verify_any_conv2d_NCHWc(data_shape, kernel_shape, strides, padding, dilation,
+                            data_layout, kernel_layout, out_layout,
+                            static_data_shape, ref_out_shape):
+    mod = relay.Module()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    kernel = relay.var('kernel', shape=kernel_shape, dtype=dtype)
+    y = relay.nn.contrib_conv2d_nchwc(data, kernel, strides, padding, dilation,
+                                      kernel_size=kernel_shape[2:4],
+                                      channels=kernel_shape[0]*kernel_shape[-1],
+                                      data_layout=data_layout, kernel_layout=kernel_layout,
+                                      out_layout=out_layout)
+    mod["main"] = relay.Function([data, kernel], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np, kernel_np)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_conv2d_NCHWc():
+    verify_any_conv2d_NCHWc((relay.Any(), 8, relay.Any(), relay.Any(), 8), (8, 8, 3, 3, 8, 8), (1, 1), (1, 1), (1, 1),
+                            "NCHW8c", "OIHW8i8o", "NCHW8c", (1, 8, 224, 224, 8), (1, 8, 224, 224, 8))
+    verify_any_conv2d_NCHWc((relay.Any(), 8, relay.Any(), relay.Any(), 8), (8, 8, 3, 3, 8, 8), (1, 1), (1, 1), (2, 2),
+                            "NCHW8c", "OIHW8i8o", "NCHW8c", (1, 8, 224, 224, 8), (1, 8, 222, 222, 8))
+
+def verify_any_pool2d(pool_type, data_shape, pool_size, strides, padding,
+                      layout, static_data_shape, ref_out_shape):
+    mod = relay.Module()
+    dtype = "float32"
+    pool_func = relay.nn.max_pool2d if pool_type == "max" else relay.nn.avg_pool2d
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = pool_func(data, pool_size, strides, padding, layout)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_pool2d():
+    verify_any_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()),
+                      (3, 3), (1, 1), (1, 1), "NCHW", (2, 3, 220, 220), (2, 3, 220, 220))
+    verify_any_pool2d("avg", (relay.Any(), relay.Any(), relay.Any(), 4),
+                      (1, 1), (2, 2), (0, 0), "NHWC", (3, 220, 220, 4), (3, 110, 110, 4))
+    verify_any_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4),
+                      (3, 3), (2, 2), (1, 1), "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 110, 110, 4))
+
+def verify_any_global_pool2d(pool_type, data_shape, layout, static_data_shape, ref_out_shape):
+    mod = relay.Module()
+    dtype = "float32"
+    pool_func = relay.nn.global_max_pool2d if pool_type == "max" else relay.nn.global_avg_pool2d
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = pool_func(data, layout)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_global_pool2d():
+    verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()),
+                      "NCHW", (2, 3, 220, 220), (2, 3, 1, 1))
+    verify_any_global_pool2d("avg", (relay.Any(), relay.Any(), relay.Any(), 4),
+                      "NHWC", (3, 220, 220, 4), (3, 1, 1, 4))
+    verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4),
+                      "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 1, 1, 4))
+
+def test_any_batch_flatten():
+    mod = relay.Module()
+    dtype = "float32"
+    data = relay.var('data', shape=any_dims(3), dtype=dtype)
+    y = relay.nn.batch_flatten(data)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype)
+    ref_out_shape = (3, 30)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def verify_any_dense(data_shape, weight_shape, units, static_data_shape,
+                     static_weight_shape, ref_out_shape):
+    mod = relay.Module()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    weight = relay.var('weight', shape=weight_shape, dtype=dtype)
+    y = relay.nn.dense(data, weight, units)
+    mod["main"] = relay.Function([data, weight], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    weight_np = np.random.uniform(size=static_weight_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np, weight_np)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_dense():
+    verify_any_dense(any_dims(2), any_dims(2), None, (4, 16), (8, 16), (4, 8))
+    verify_any_dense(any_dims(2), (50, relay.Any()), 50, (4, 40), (50, 40), (4, 50))
+
+def verify_any_pad(data_shape, pad_width, static_data_shape):
+    mod = relay.Module()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.nn.pad(data, pad_width)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        ref_out = np.pad(data_np, pad_width)
+        tvm.testing.assert_allclose(result.asnumpy(), ref_out)
+
+def test_any_pad():
+    verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3))
+    verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1))
+
+def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape):
+    mod = relay.Module()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.nn.softmax(data, axis)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_softmax():
+    verify_any_softmax(any_dims(3), -1, (1, 2, 3), (1, 2, 3))
+    verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1))
+
 def test_fused_ops():
     x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32')
     y0 = x + relay.const(1.0, 'float32')
@@ -308,6 +559,19 @@ if __name__ == "__main__":
     test_any_reshape()
     test_any_take()
     test_any_shape_of()
+    test_any_reduce()
+    test_any_layout_transform()
+    test_any_expand_dims()
+    test_any_transpose()
+    test_any_squeeze()
+    test_any_reshape_like()
+    test_any_conv2d_NCHWc()
+    test_any_pool2d()
+    test_any_global_pool2d()
+    test_any_batch_flatten()
+    test_any_dense()
+    test_any_pad()
+    test_any_softmax()
     test_fused_ops()
     test_arange_with_dynamic_shape()
     test_recursive_concat()
topi/include/topi/nn/flatten.h
index ea4f62e..d044547 100644 (file)
@@ -52,9 +52,9 @@ inline Tensor flatten(const Tensor& x,
                       std::string name = "tensor",
                       std::string tag = kInjective) {
   auto ishape = x->shape;
-  int dim = 1;
+  Expr dim = 1;
   for (size_t i = 1; i < ishape.size(); ++i) {
-    dim = dim * static_cast<int>(topi::detail::GetConstInt(ishape[i]));
+    dim = dim * ishape[i];
   }
 
   Array<Expr> oshape({ ishape[0], dim });
topi/python/topi/util.py
index 1bf3a10..623d06a 100644 (file)
@@ -144,7 +144,7 @@ def equal_const_int(expr, value):
 
 
 def get_const_tuple(in_tuple):
-    """Verifies input tuple is IntImm, returns tuple of int.
+    """Verifies input tuple is IntImm or Var, returns tuple of int or Var.
 
     Parameters
     ----------
@@ -156,7 +156,17 @@ def get_const_tuple(in_tuple):
     out_tuple : tuple of int
         The output.
     """
-    return tuple(get_const_int(elem) for elem in in_tuple)
+    ret = []
+    for elem in in_tuple:
+        if isinstance(elem, tvm.expr.Var):
+            ret.append(elem)
+        elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)):
+            elem = tvm.ir_pass.Simplify(elem)
+            if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+                ret.append(elem)
+        else:
+            ret.append(get_const_int(elem))
+    return tuple(ret)
 
 
 def get_float_tuple(in_tuple):
topi/python/topi/x86/conv2d.py
index 9ea93cd..0e284da 100644 (file)
@@ -41,6 +41,13 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
     """
     Get default schedule config for the workload
     """
+    static_data_shape = []
+    for dim in get_const_tuple(data.shape):
+        if isinstance(dim, tvm.expr.Var):
+            static_data_shape.append(1)
+        else:
+            static_data_shape.append(dim)
+    data = tvm.placeholder(static_data_shape, dtype=data.dtype)
     if is_depthwise:
         wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
         from .depthwise_conv2d import _fallback_schedule
topi/python/topi/x86/dense.py
index ec401bf..2a739d5 100644 (file)
@@ -37,6 +37,12 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
         return C
 
     M, _ = get_const_tuple(data.shape)
+    # Always use dense_nopack for dynamic input.
+    # This is a temporary for CV models.
+    # TODO(kevinthesun): use kernel dispatcher instead.
+    if isinstance(M, tvm.expr.Var):
+        return _declaration_dense_nopack(cfg, data, weight, bias, out_dtype)
+
     # For small batch sizes, don't pack weight into cache-friendly layout
     # because of overhead in packing and limited reuse from batch dimension
     # TODO(icemelon9): use a more systematic way to determine which schedule to use
@@ -53,9 +59,9 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
     M, K = get_const_tuple(data.shape) # batch, in_dim
     N, _ = get_const_tuple(weight.shape) # out_dim
     # create tuning space
-    cfg.define_split("tile_y", M, num_outputs=3)
-    cfg.define_split("tile_x", N, num_outputs=3)
-    cfg.define_split("tile_k", K, num_outputs=2)
+    cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=2)
+    cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=2)
+    cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2)
     if cfg.is_fallback:
         _default_dense_pack_config(cfg, M, N, K)
 
@@ -87,9 +93,9 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
     M, K = get_const_tuple(data.shape)
     N, _ = get_const_tuple(weight.shape)
     # create tuning space
-    cfg.define_split("tile_y", M, num_outputs=2)
-    cfg.define_split("tile_x", N, num_outputs=2)
-    cfg.define_split("tile_k", K, num_outputs=2)
+    cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=2)
+    cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=2)
+    cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2)
     if cfg.is_fallback:
         _default_dense_nopack_config(cfg, M, N, K)
 
@@ -211,8 +217,15 @@ def _schedule_dense_nopack_template(cfg, s, C):
 
 
 def _default_dense_pack_config(cfg, M, N, K):
-    vec_width = get_fp32_len()
+    # Generate default schedule for dynamic shape.
+    if isinstance(M, tvm.expr.Var):
+        M = 16
+    if isinstance(N, tvm.expr.Var):
+        N = 16
+    if isinstance(K, tvm.expr.Var):
+        K = 16
 
+    vec_width = get_fp32_len()
     tilex_ii = 1
     for bn in range(vec_width*2, 0, -1):
         if N % bn == 0:
@@ -241,6 +254,14 @@ def _default_dense_pack_config(cfg, M, N, K):
 
 
 def _default_dense_nopack_config(cfg, M, N, K):
+    # Generate default schedule for dynamic shape.
+    if isinstance(M, tvm.expr.Var):
+        M = 16
+    if isinstance(N, tvm.expr.Var):
+        N = 16
+    if isinstance(K, tvm.expr.Var):
+        K = 16
+
     vec_width = get_fp32_len()
     tilek_bn = 1
     for bn in range(vec_width*2, 0, -1):