workload = get_const_tuple(x.shape) + (x.dtype, )
elif isinstance(x, (tuple, list, container.Array)):
workload = tuple([args_to_workload(a) for a in x])
- elif isinstance(x, (str, int, float, np.int, np.float)):
+ elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
workload = x
elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
workload = x.value
elif x is None:
workload = 0
else:
- raise RuntimeError('Do not support type "%s" in argument. Consider to use '
- 'primitive types only' % type(x))
+        raise RuntimeError('Do not support type "%s" in argument. Consider using '
+                           'primitive types or tvm.expr.Var only' % type(x))
return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload
def template(func):
def get_const_tuple(in_tuple):
- """Verifies input tuple is IntImm, returns tuple of int.
+ """Verifies input tuple is IntImm or Var, returns tuple of int or Var.
    Parameters
    ----------
    in_tuple : tuple of Expr
        The input.

    Returns
    -------
    out_tuple : tuple of int
        The output.
    """
- return tuple(get_const_int(x) for x in in_tuple)
+    ret = []
+    for elem in in_tuple:
+        if isinstance(elem, expr.Var):
+            ret.append(elem)
+        elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)):
+            elem = ir_pass.Simplify(elem)
+            if not isinstance(elem, (expr.IntImm, expr.UIntImm)):
+                ret.append(elem)
+            else:
+                # Simplification produced a constant; keep it as a plain int.
+                ret.append(get_const_int(elem))
+        else:
+            ret.append(get_const_int(elem))
+    return tuple(ret)
from __future__ import absolute_import
import topi
+
+from topi.util import get_const_int, get_const_tuple
from . import op as _reg
+from ...api import convert
+from ...hybrid import script
def _schedule_reduce(_, outs, target):
_reg.register_schedule("variance", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)
+
+
+def _create_axis_record(attrs, inputs):
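+    """Map each output axis of a reduce op to its input axis.
+
+    Returns a list with one entry per output axis: the input axis index kept at
+    that position, or -1 if the axis is reduced away. For example, reducing a
+    rank-3 input over axis=1 with keepdims=True gives [0, -1, 2].
+    """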
+ axes = attrs.axis if attrs.axis is None else list(get_const_tuple(attrs.axis))
+ exclude = get_const_int(attrs.exclude) > 0
+ keepdims = get_const_int(attrs.keepdims) > 0
+ data_shape = inputs[0]
+ shape_size = data_shape.shape[0].value
+ axis_record = [-1] * shape_size
+ if axes is None:
+ axes = list(range(shape_size))
+
+ for i, axis in enumerate(axes):
+ if axis < 0:
+ axes[i] = shape_size + axis
+
+ if exclude:
+ ex_axes = []
+ for i in range(shape_size):
+ if i not in axes:
+ ex_axes.append(i)
+ axes = ex_axes
+
+ for i in range(shape_size):
+ if i not in axes:
+ axis_record[i] = i
+
+ if not keepdims:
+ tmp = []
+ for i in axis_record:
+ if i >= 0:
+ tmp.append(i)
+ axis_record = tmp
+
+ return axis_record
+
+
+@script
+def _reduce_shape_func(data_shape, axis_record):
+ out = output_tensor((len(axis_record),), "int64")
+ for i in const_range(len(axis_record)):
+ if axis_record[i] >= 0:
+ out[i] = data_shape[axis_record[i]]
+ else:
+ out[i] = int64(1)
+
+ return out
+
+def reduce_shape_func(attrs, inputs, _):
+ """
+ Shape function for reduce op.
+ """
+ axis_record = _create_axis_record(attrs, inputs)
+ return [_reduce_shape_func(inputs[0], convert(axis_record))]
+
+_reg.register_shape_func("argmax", False, reduce_shape_func)
+_reg.register_shape_func("argmin", False, reduce_shape_func)
+_reg.register_shape_func("all", False, reduce_shape_func)
+_reg.register_shape_func("sum", False, reduce_shape_func)
+_reg.register_shape_func("max", False, reduce_shape_func)
+_reg.register_shape_func("min", False, reduce_shape_func)
+_reg.register_shape_func("prod", False, reduce_shape_func)
+_reg.register_shape_func("mean", False, reduce_shape_func)
+_reg.register_shape_func("variance", False, reduce_shape_func)
def cast_shape_func(attrs, inputs, out_ndims):
return [_cast_shape_function(*inputs)]
-@script
-def _expand_dims_shape_func(x):
- ndim = len(x.shape)
- out = output_tensor((ndim+1,), "int64")
- out[0] = int64(1)
- for i in const_range(0, ndim):
- out[i+1] = int64(x.shape[i])
- return out
-
-def expand_dims_shape_func(attrs, inputs, out_ndims):
- return [_expand_dims_shape_func(*inputs)]
-
# shape func
@script
def _broadcast_shape_func(x, y, ndim):
return out
def broadcast_shape_func(attrs, inputs, out_ndims):
+ """
+ Shape function for broadcast op.
+ """
return [_broadcast_shape_func(*inputs, out_ndims[0])]
-register_shape_func("expand_dims", False, expand_dims_shape_func)
+def elemwise_shape_func(attrs, inputs, _):
+ """
+ Shape function for elemwise op.
+ """
+ return [topi.math.identity(inputs[0])]
+
register_shape_func("cast", False, cast_shape_func)
register_shape_func("add", False, broadcast_shape_func)
register_shape_func("less_equal", False, broadcast_shape_func)
register_shape_func("greater", False, broadcast_shape_func)
register_shape_func("greater_equal", False, broadcast_shape_func)
+
+register_shape_func("sqrt", False, elemwise_shape_func)
+register_shape_func("negative", False, elemwise_shape_func)
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
-# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks
+# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-locals, too-many-arguments
from __future__ import absolute_import
import tvm
import topi
output_shape.append(tvm.var("any_dim", "int32"))
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]
+
+@script
+def _layout_transform_shape_func(data_shape,
+ out_layout_len,
+ dst_equal_list,
+ dst_mul_list,
+ dst_div_list,
+ dst_mix_list):
+ out = output_tensor((out_layout_len,), "int64")
+ for i in const_range(len(dst_equal_list)):
+ out[dst_equal_list[i][0]] = data_shape[dst_equal_list[i][1]]
+ for i in const_range(len(dst_mul_list)):
+ out[dst_mul_list[i][0]] = data_shape[dst_mul_list[i][1]] * \
+ data_shape[dst_mul_list[i][2]]
+ for i in const_range(len(dst_div_list)):
+ out[dst_div_list[i][0]] = data_shape[dst_div_list[i][1]] \
+ // dst_div_list[i][3]
+ out[dst_div_list[i][2]] = int64(dst_div_list[i][3])
+ for i in const_range(len(dst_mix_list)):
+ out[dst_mix_list[i][0]] = data_shape[dst_mix_list[i][1]] * \
+ dst_mix_list[i][2] // dst_mix_list[i][4]
+ out[dst_mix_list[i][3]] = int64(dst_mix_list[i][4])
+
+ return out
+
+@_reg.register_shape_func("layout_transform", False)
+def layout_transform_shape_func(attrs, inputs, _):
+ """
+ Shape function for layout_transform op.
+ """
+ def _fetch_axis(layout):
+ major_axes = []
+ minor_axes = {}
+ num_start = -1
+ for i, item in enumerate(layout):
+ if "A" <= item <= "Z":
+ major_axes.append(item)
+ elif "a" <= item <= "z":
+ last_num = int(layout[num_start:i])
+ minor_axes[item] = last_num
+ num_start = -1
+ elif num_start < 0:
+ num_start = i
+ return major_axes, minor_axes
+
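+    # e.g. _fetch_axis("NCHW4c") returns (['N', 'C', 'H', 'W'], {'c': 4}).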
+ _, src_minor_axes = _fetch_axis(attrs.src_layout)
+ dst_major_axes, dst_minor_axes = _fetch_axis(attrs.dst_layout)
+ src_letter_list = []
+ dst_letter_list = []
+ for item in attrs.src_layout:
+ if "A" <= item <= "Z" or "a" <= item <= "z":
+ src_letter_list.append(item)
+ for item in attrs.dst_layout:
+ if "A" <= item <= "Z" or "a" <= item <= "z":
+ dst_letter_list.append(item)
+ out_layout_len = len(dst_major_axes) + len(dst_minor_axes)
+ dst_equal_list = []
+ dst_mul_list = []
+ dst_div_list = []
+ dst_mix_list = []
+
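+    # Classify each destination primal axis by how its extent is derived:
+    #   dst_equal_list: copied directly from the matching source axis
+    #   dst_mul_list:   source primal and sub-axis extents are multiplied (merge)
+    #   dst_div_list:   source primal extent is split by the destination sub-axis factor
+    #   dst_mix_list:   both layouts define a sub-axis, with different split factors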
+ for key in dst_major_axes:
+ if key.lower() not in dst_minor_axes:
+ if key.lower() not in src_minor_axes:
+ dst_equal_list.append((dst_letter_list.index(key),
+ src_letter_list.index(key)))
+ else:
+ dst_mul_list.append((dst_letter_list.index(key),
+ src_letter_list.index(key),
+ src_letter_list.index(key.lower())))
+ else:
+ if key.lower() not in src_minor_axes:
+ dst_div_list.append((dst_letter_list.index(key),
+ src_letter_list.index(key),
+ dst_letter_list.index(key.lower()),
+ dst_minor_axes[key.lower()]))
+ else:
+ dst_mix_list.append((dst_letter_list.index(key),
+ src_letter_list.index(key),
+ src_minor_axes[key.lower()],
+ dst_letter_list.index(key.lower()),
+ dst_minor_axes[key.lower()]))
+
+ return [_layout_transform_shape_func(inputs[0],
+ convert(out_layout_len),
+ convert(dst_equal_list),
+ convert(dst_mul_list),
+ convert(dst_div_list),
+ convert(dst_mix_list))]
+
+@script
+def _expand_dim_shape_func(data_shape, ndim, axis, num_newaxis):
+ out = output_tensor((ndim + num_newaxis,), "int64")
+ for i in const_range(out.shape[0]):
+ if i < axis:
+ out[i] = data_shape[i]
+ elif i < axis + num_newaxis:
+ out[i] = int64(1)
+ else:
+ out[i] = data_shape[i - num_newaxis]
+
+ return out
+
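+# e.g. data_shape=(2, 3) with axis=1 and num_newaxis=2 yields the output shape (2, 1, 1, 3).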
+@_reg.register_shape_func("expand_dims", False)
+def expand_dim_shape_func(attrs, inputs, _):
+ """
+    Shape function for expand_dims op.
+ """
+ axis = get_const_int(attrs.axis)
+ num_newaxis = get_const_int(attrs.num_newaxis)
+ if axis < 0:
+ axis = inputs[0].shape[0] + axis + 1
+ ndim = inputs[0].shape[0] if inputs[0].shape else 0
+ return [_expand_dim_shape_func(inputs[0],
+ convert(ndim),
+ convert(axis),
+ convert(num_newaxis))]
+
+@script
+def _transpose_shape_func(data_shape, axes):
+ out = output_tensor((data_shape.shape[0],), "int64")
+ for i in const_range(len(axes)):
+ out[i] = data_shape[axes[i]]
+
+ return out
+
+@_reg.register_shape_func("transpose", False)
+def transpose_shape_func(attrs, inputs, _):
+ """
+ Shape function for transpose op.
+ """
+    axes = attrs.axes if attrs.axes is None else list(get_const_tuple(attrs.axes))
+    if axes is None:
+        axes = list(range(inputs[0].shape[0].value))
+        axes.reverse()
+    for i, axis in enumerate(axes):
+        if axis < 0:
+            axes[i] = inputs[0].shape[0] + axis
+ return [_transpose_shape_func(inputs[0], convert(axes))]
+
+@script
+def _squeeze_shape_func(data_shape, keep_axes):
+ out = output_tensor((len(keep_axes),), "int64")
+ if len(keep_axes) == 0:
+ out_size = 0
+ for i in const_range(data_shape.shape[0]):
+ if data_shape[i] != 1:
+ out_size += 1
+
+ if out_size == 0:
+ out_size = 1
+ out = output_tensor((out_size,), "int64")
+ out[0] = int64(1)
+ pos = 0
+ for i in const_range(data_shape.shape[0]):
+ if data_shape[i] != 1:
+ out[pos] = data_shape[i]
+ pos += 1
+ else:
+ for i in const_range(len(keep_axes)):
+ out[i] = data_shape[keep_axes[i]]
+
+ return out
+
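+# Note: when keep_axes is empty (axis=None), the output rank depends on the input
+# data itself, so the script recomputes the output size from the runtime shape.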
+@_reg.register_shape_func("squeeze", False)
+def squeeze_shape_func(attrs, inputs, _):
+ """
+ Shape function for squeeze op.
+ """
+ axis = attrs.axis if attrs.axis is None else get_const_tuple(attrs.axis)
+ keep_axes = []
+ if axis is not None:
+ for i in range(inputs[0].shape[0].value):
+ if i not in axis:
+ keep_axes.append(i)
+
+ return [_squeeze_shape_func(inputs[0], convert(keep_axes))]
+
+@script
+def _reshape_like_shape_func(target_shape):
+ out = output_tensor((target_shape.shape[0],), "int64")
+ for i in const_range(target_shape.shape[0]):
+ out[i] = target_shape[i]
+
+ return out
+
+@_reg.register_shape_func("reshape_like", False)
+def reshape_like_shape_func(attrs, inputs, _):
+ """
+ Shape function for reshape_like op.
+ """
+ return [_reshape_like_shape_func(inputs[1])]
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument, too-many-arguments
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from topi.util import get_const_tuple
from .. import op as reg
from ..op import OpPattern, schedule_injective
+from .._tensor import elemwise_shape_func
+from ....api import convert
+from ....hybrid import script
# relu
reg.register_schedule("nn.relu", schedule_injective)
reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
-
@reg.register_compute("nn.cross_entropy")
def compute_cross_entropy(attrs, inputs, out_dtype, target):
x, y = inputs
reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)
-
@reg.register_compute("nn.cross_entropy_with_logits")
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target):
x, y = inputs
return [-topi.sum(x * y) / x.shape[0]]
+
+# shape func
+@script
+def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn):
+ out = output_tensor((dshape.shape[0],), "int64")
+ ic_chunk = dshape[1]
+ height = dshape[2]
+ width = dshape[3]
+ ic_bn = dshape[4]
+ kheight = kshape[2]
+ kwidth = kshape[3]
+ dilated_kh = (kheight - 1) * dilation[0] + 1
+ dilated_kw = (kwidth - 1) * dilation[1] + 1
+ kflatten = int64(1)
+ for i in const_range(kshape.shape[0]):
+ kflatten *= kshape[i]
+
+ oc = kflatten // (kheight * kwidth * ic_chunk * ic_bn)
+ oc_chunk = oc // oc_bn
+
+ out_height = (height + 2 * padding[0] - dilated_kh) // strides[0] + 1
+ out_width = (width + 2 * padding[1] - dilated_kw) // strides[1] + 1
+
+ out[0] = dshape[0]
+ out[1] = oc_chunk
+ out[2] = out_height
+ out[3] = out_width
+ out[4] = int64(oc_bn)
+ return out
+
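+# For example, NCHW8c data (1, 8, 224, 224, 8) with an OIHW8i8o kernel (8, 8, 3, 3, 8, 8),
+# strides (1, 1), padding (1, 1) and dilation (1, 1) gives the output shape (1, 8, 224, 224, 8).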
+@reg.register_shape_func("nn.contrib_conv2d_NCHWc", False)
+def conv2d_NCHWc_shape_func(attrs, inputs, _):
+ """
+ Shape function for contrib_conv2d_NCHWc op.
+ """
+ strides = get_const_tuple(attrs.strides)
+ padding = get_const_tuple(attrs.padding)
+ dilation = get_const_tuple(attrs.dilation)
+ out_layout = attrs.out_layout
+ oc_bn = int(out_layout[4:-1])
+
+ return [_conv2d_NCHWc_shape_func(inputs[0], inputs[1],
+ convert(strides), convert(padding),
+ convert(dilation), convert(oc_bn))]
+
+@script
+def _pool2d_shape_func(data_shape, pool_size, strides,
+ padding, height_axis, width_axis):
+ out = output_tensor((data_shape.shape[0],), "int64")
+ for i in const_range(data_shape.shape[0]):
+ if i == height_axis:
+ out[i] = (data_shape[i] + padding[0] + padding[2] - pool_size[0]) // strides[0] + 1
+ elif i == width_axis:
+ out[i] = (data_shape[i] + padding[1] + padding[3] - pool_size[1]) // strides[1] + 1
+ else:
+ out[i] = data_shape[i]
+
+ return out
+
+def pool2d_shape_func(attrs, inputs, _):
+ """
+ Shape function for pool2d op.
+ """
+ pool_size = get_const_tuple(attrs.pool_size)
+ strides = get_const_tuple(attrs.strides)
+ padding = get_const_tuple(attrs.padding)
+ layout = attrs.layout
+ height_axis = layout.index("H")
+ width_axis = layout.index("W")
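+    # Normalize padding to four values: (top, left, bottom, right).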
+ if len(padding) == 1:
+ padding = [padding[0]] * 4
+ elif len(padding) == 2:
+ padding = [padding[0], padding[1], padding[0], padding[1]]
+
+ return [_pool2d_shape_func(inputs[0], convert(pool_size),
+ convert(strides), convert(padding),
+ convert(height_axis), convert(width_axis))]
+
+reg.register_shape_func("nn.max_pool2d", False, pool2d_shape_func)
+reg.register_shape_func("nn.avg_pool2d", False, pool2d_shape_func)
+
+@script
+def _global_pool2d_shape_func(data_shape, height_axis, width_axis):
+ out = output_tensor((data_shape.shape[0],), "int64")
+ for i in const_range(out.shape[0]):
+ if i == height_axis or i == width_axis:
+ out[i] = int64(1)
+ else:
+ out[i] = data_shape[i]
+
+ return out
+
+def global_pool2d_shape_func(attrs, inputs, _):
+ """
+ Shape function for global pool2d op.
+ """
+ layout = attrs.layout
+ height_axis = width_axis = 1
+ for i, letter in enumerate(layout):
+ if letter == "H":
+ height_axis = i
+ if letter == "W":
+ width_axis = i
+ return [_global_pool2d_shape_func(inputs[0], convert(height_axis), convert(width_axis))]
+
+reg.register_shape_func("nn.global_max_pool2d", False, global_pool2d_shape_func)
+reg.register_shape_func("nn.global_avg_pool2d", False, global_pool2d_shape_func)
+
+@script
+def _batch_flatten_shape_func(data_shape):
+ out = output_tensor((2,), "int64")
+ out[0] = data_shape[0]
+ out[1] = int64(1)
+ for i in const_range(data_shape.shape[0] - 1):
+ out[1] *= data_shape[i + 1]
+
+ return out
+
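+# e.g. a (3, 3, 10) input is flattened to (3, 30).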
+@reg.register_shape_func("nn.batch_flatten", False)
+def batch_flatten_shape_func(attrs, inputs, _):
+ """
+ Shape function for batch_flatten op.
+ """
+ return [_batch_flatten_shape_func(inputs[0])]
+
+@script
+def _dense_shape_func(data_shape, weight_shape):
+ out = output_tensor((data_shape.shape[0],), "int64")
+ for i in const_range(out.shape[0] - 1):
+ out[i] = data_shape[i]
+ out[out.shape[0] - 1] = weight_shape[0]
+
+ return out
+
+@reg.register_shape_func("nn.dense", False)
+def dense_shape_func(attrs, inputs, _):
+ """
+ Shape function for dense op.
+ """
+ ret = [_dense_shape_func(inputs[0], inputs[1])]
+ return ret
+
+@script
+def _pad_shape_func(data_shape, pad_width):
+ out = output_tensor((data_shape.shape[0],), "int64")
+ for i in const_range(out.shape[0]):
+ out[i] = data_shape[i] + pad_width[i][0] + pad_width[i][1]
+
+ return out
+
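+# e.g. pad_width ((0, 0), (1, 1), (2, 2)) grows a (1, 2, 3) input to (1, 4, 7).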
+@reg.register_shape_func("nn.pad", False)
+def pad_shape_func(attrs, inputs, _):
+ """
+ Shape function for pad op.
+ """
+ pad_width = []
+ for pair in attrs.pad_width:
+ pad_width.append(get_const_tuple(pair))
+ return [_pad_shape_func(inputs[0], convert(pad_width))]
+
+reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
+reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
+reg.register_shape_func("nn.relu", False, elemwise_shape_func)
// for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
// e.g., (C * 16 + c) / 32
std::unordered_map<const Variable*, Expr> bind_map;
+ std::unordered_set<size_t> symbolic_var_set;
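+  // Track source axes whose extent is symbolic (Any); their transformed
+  // extents cannot be simplified to constants below.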
for (size_t i = 0; i < src_shape.size(); ++i) {
Expr orig_shape = src_shape[i];
IterVar orig_axis = src_axis[i];
+ if (orig_shape.as<ir::Any>()) {
+ symbolic_var_set.insert(i);
+ }
if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
if (orig_shape.defined()) {
const auto* orig_shape_const = orig_shape.as<IntImm>();
const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImm>();
- CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
- << "Input shape mismatch at index " << i << ". Expected "
- << orig_axis->dom->extent << ", get " << orig_shape;
+ if (orig_shape_const) {
+ CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
+ << "Input shape mismatch at index " << i << ". Expected "
+ << orig_axis->dom->extent << ", get " << orig_shape;
+ }
}
bind_map[orig_axis->var.get()] = Expr(0);
} else {
if (!LayoutAxis::Get(axis).IsPrimal()) {
result.push_back(axis->dom->extent);
} else {
- result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
+ if (symbolic_var_set.count(i)) {
+ result.push_back(ir::Any::make());
+ } else {
+ result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
+ }
}
}
return result;
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
- oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
- oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
+ if (!dshape_nchw[2].as<ir::Any>()) {
+ oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2
+ - dilated_ksize_y) / param->strides[0] + 1);
+ } else {
+ oshape.Set(2, dshape_nchw[2]);
+ }
+ if (!dshape_nchw[3].as<ir::Any>()) {
+ oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2
+ - dilated_ksize_x) / param->strides[1] + 1);
+ } else {
+ oshape.Set(3, dshape_nchw[3]);
+ }
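+  // Note: an Any extent is forwarded unchanged here; the concrete output
+  // extent is computed at runtime by the op's registered shape function.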
+
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
- oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y,
- param->strides[0]) + 1);
- oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x,
- param->strides[1]) + 1);
+ if (!dshape_nchw[2].as<ir::Any>()) {
+ oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y,
+ param->strides[0]) + 1);
+ } else {
+ oshape.Set(2, dshape_nchw[2]);
+ }
+
+ if (!dshape_nchw[3].as<ir::Any>()) {
+ oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x,
+ param->strides[1]) + 1);
+ } else {
+ oshape.Set(3, dshape_nchw[3]);
+ }
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
auto target_dim = make_const(Int(32), 1);
for (uint32_t i = 1; i < data->shape.size(); ++i) {
- target_dim = target_dim * data->shape[i];
+ if (!data->shape[i].as<ir::Any>()) {
+ target_dim = target_dim * data->shape[i];
+ } else {
+ target_dim = data->shape[i];
+ break;
+ }
}
std::vector<IndexExpr> oshape({data->shape[0], target_dim});
<< "Param width elements should be positive but first pad width at "
<< "index " << i << " is " << *width2 << ".";
- auto padding = make_const(data->shape[i].type(), *width1 + *width2);
- oshape.push_back(data->shape[i] + padding);
+ if (!data->shape[i].as<ir::Any>()) {
+ auto padding = make_const(data->shape[i].type(), *width1 + *width2);
+ oshape.push_back(data->shape[i] + padding);
+ } else {
+ oshape.push_back(data->shape[i]);
+ }
}
reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
oshape.push_back(e);
}
- if (param->ceil_mode) {
- oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] +
- param->strides[0] - 1) / param->strides[0]) + 1;
- oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] +
- param->strides[1] - 1) / param->strides[1]) + 1;
+ if (dshape[hidx].as<ir::Any>()) {
+ oshape[hidx] = dshape[hidx];
} else {
- oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1;
- oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1;
+ if (param->ceil_mode) {
+ oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] +
+ param->strides[0] - 1) / param->strides[0]) + 1;
+ } else {
+ oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1;
+ }
+ }
+ if (dshape[widx].as<ir::Any>()) {
+ oshape[widx] = dshape[widx];
+ } else {
+ if (param->ceil_mode) {
+ oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] +
+ param->strides[1] - 1) / param->strides[1]) + 1;
+ } else {
+ oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1;
+ }
}
// assign output type
}
auto max_shape = make_const(Int(64), 1);
+ bool is_dynamic_input = false;
for (int64_t axis : r_axes) {
- max_shape *= in_shape[axis];
+ if (in_shape[axis].as<IntImm>()) {
+ max_shape *= in_shape[axis];
+ } else {
+ is_dynamic_input = true;
+ break;
+ }
+ }
+
+  // The int32-bound check is only meaningful when every reduced extent is static.
+  if (!is_dynamic_input) {
+ CHECK(reporter->Assert(max_shape < make_const(Int(64), std::numeric_limits<int32_t>::max())))
+ << "The maximum possible index of reduced shape cannot be more than int32 max.";
}
- CHECK(reporter->Assert(max_shape < make_const(Int(64), std::numeric_limits<int32_t>::max())))
- << "The maximum possible index of reduced shape cannot be more than int32 max.";
if (param->keepdims) {
std::vector<IndexExpr> oshape(in_shape);
if (reshape_like == nullptr) {
return false;
}
- CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
- << "Reshape inputs size should be compatible.";
+  // Only check when the input data has a static shape.
+ bool is_static_shape = true;
+ for (size_t i = 0; i < data->shape.size(); ++i) {
+ if (!data->shape[i].as<IntImm>()) {
+ is_static_shape = false;
+ break;
+ }
+ }
+ if (is_static_shape) {
+ CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
+ << "Reshape inputs size should be compatible.";
+ }
reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype));
return true;
}
.set_attr<TOpPattern>("TOpPattern", kInjective);
// relay.layout_transform
+TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
+
Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
result = ex.evaluate()(data)
tvm.testing.assert_allclose(result.asnumpy(), np.array(3).astype("int64"))
+def verify_any_reduce(reduce_op, data_shape, axis, exclude, keepdims,
+ static_data_shape, ref_out_shape):
+ mod = relay.Module()
+ dtype = "bool" if reduce_op == relay.all else "float32"
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ y = reduce_op(data, axis, keepdims, exclude)
+ mod["main"] = relay.Function([data], y)
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np)
+ assert result.asnumpy().shape == ref_out_shape, \
+ "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_reduce():
+ verify_any_reduce(relay.argmax, any_dims(3), None, False, False, (3, 4, 5), ())
+ verify_any_reduce(relay.argmin, any_dims(4), 1, False, True, (3, 4, 5, 6), (3, 1, 5, 6))
+ verify_any_reduce(relay.all, any_dims(3), (1, 2), True, False, (3, 4, 5), (4, 5))
+ verify_any_reduce(relay.max, any_dims(4), -1, True, True, (3, 4, 5, 6), (1, 1, 1, 6))
+ verify_any_reduce(relay.min, any_dims(3), (0, 1), False, False, (4, 5, 6), (6,))
+ verify_any_reduce(relay.prod, any_dims(4), 2, True, True, (3, 4, 5, 6), (1, 1, 5, 1))
+ verify_any_reduce(relay.mean, any_dims(2), 0, False, False, (1, 2), (2,))
+ verify_any_reduce(relay.variance, any_dims(5), (2, 4), False, False, (3, 4, 5, 6, 7), (3, 4, 6))
+
+def verify_any_layout_transform(data_shape, src_layout, dst_layout, static_data_shape, ref_out_shape):
+ mod = relay.Module()
+ dtype = "float32"
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ y = relay.layout_transform(data, src_layout, dst_layout)
+ mod["main"] = relay.Function([data], y)
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np)
+ assert result.asnumpy().shape == ref_out_shape, \
+ "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_layout_transform():
+ verify_any_layout_transform(any_dims(4), "NCHW", "NHWC", (3, 4, 5, 6), (3, 5, 6, 4))
+ verify_any_layout_transform(any_dims(5), "NCHW16c", "NCHW2c", (1, 2, 8, 8, 16), (1, 16, 8, 8, 2))
+ verify_any_layout_transform(any_dims(5), "NCHW6n", "NHWC", (3, 4, 5, 6, 6), (18, 5, 6, 4))
+ verify_any_layout_transform(any_dims(4), "NCHW", "NCHW4c", (3, 4, 5, 6), (3, 1, 5, 6, 4))
+ verify_any_layout_transform((16, 1), "CH", "C4cH", (16, 1), (4, 4, 1))
+
+def verify_any_expand_dims(data_shape, axis, num_newaxis, static_data_shape, ref_out_shape):
+ mod = relay.Module()
+ dtype = "float32"
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ y = relay.expand_dims(data, axis=axis, num_newaxis=num_newaxis)
+ mod["main"] = relay.Function([data], y)
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np)
+ assert result.asnumpy().shape == ref_out_shape, \
+ "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_expand_dims():
+ verify_any_expand_dims(any_dims(3), 1, 2, (1, 2, 3), (1, 1, 1, 2, 3))
+ verify_any_expand_dims(any_dims(3), -1, 2, (1, 2, 3), (1, 2, 3, 1, 1))
+
+def verify_any_transpose(data_shape, axes, static_data_shape):
+ mod = relay.Module()
+ dtype = "float32"
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ y = relay.transpose(data, axes=axes)
+ mod["main"] = relay.Function([data], y)
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ ref_out = np.transpose(data_np, axes)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np)
+ tvm.testing.assert_allclose(result.asnumpy(), ref_out)
+
+def test_any_transpose():
+ verify_any_transpose(any_dims(3), (1, 0, 2), (10, 3, 2))
+ verify_any_transpose(any_dims(3), None, (2, 3, 4))
+ verify_any_transpose(any_dims(6), (0, 1, 3, 2, 5, 4), (11, 12, 2, 1, 9, 17))
+
+def verify_any_squeeze(data_shape, axis, static_data_shape):
+ mod = relay.Module()
+ dtype = "float32"
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ y = relay.squeeze(data, axis=axis)
+ mod["main"] = relay.Function([data], y)
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ ref_out = np.squeeze(data_np, axis)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np)
+ tvm.testing.assert_allclose(result.asnumpy(), ref_out)
+
+def test_any_squeeze():
+ verify_any_squeeze((1, relay.Any(), relay.Any()), (0,), (1, 9, 8))
+ verify_any_squeeze((1, relay.Any(), relay.Any(), 1, relay.Any(), relay.Any()), (0, 3), (1, 12, 2, 1, 9, 17))
+
+def test_any_reshape_like():
+ mod = relay.Module()
+ dtype = "float32"
+ data = relay.var('data', shape=(relay.Any(), 3, 10), dtype=dtype)
+ shape_like = relay.var('data', shape=(relay.Any(), 5, 6), dtype=dtype)
+ y = relay.reshape_like(data, shape_like)
+ mod["main"] = relay.Function([data, shape_like], y)
+ data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype)
+ shape_like_np = np.random.uniform(size=(3, 5, 6)).astype(dtype)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np, shape_like_np)
+ assert result.asnumpy().shape == shape_like_np.shape, \
+ "Shape mismatch: expect %s but got %s." % (str(shape_like_np.shape), str(result.asnumpy().shape))
+
+def verify_any_conv2d_NCHWc(data_shape, kernel_shape, strides, padding, dilation,
+ data_layout, kernel_layout, out_layout,
+ static_data_shape, ref_out_shape):
+ mod = relay.Module()
+ dtype = "float32"
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ kernel = relay.var('kernel', shape=kernel_shape, dtype=dtype)
+ y = relay.nn.contrib_conv2d_nchwc(data, kernel, strides, padding, dilation,
+ kernel_size=kernel_shape[2:4],
+ channels=kernel_shape[0]*kernel_shape[-1],
+ data_layout=data_layout, kernel_layout=kernel_layout,
+ out_layout=out_layout)
+ mod["main"] = relay.Function([data, kernel], y)
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np, kernel_np)
+ assert result.asnumpy().shape == ref_out_shape, \
+ "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_conv2d_NCHWc():
+ verify_any_conv2d_NCHWc((relay.Any(), 8, relay.Any(), relay.Any(), 8), (8, 8, 3, 3, 8, 8), (1, 1), (1, 1), (1, 1),
+ "NCHW8c", "OIHW8i8o", "NCHW8c", (1, 8, 224, 224, 8), (1, 8, 224, 224, 8))
+ verify_any_conv2d_NCHWc((relay.Any(), 8, relay.Any(), relay.Any(), 8), (8, 8, 3, 3, 8, 8), (1, 1), (1, 1), (2, 2),
+ "NCHW8c", "OIHW8i8o", "NCHW8c", (1, 8, 224, 224, 8), (1, 8, 222, 222, 8))
+
+def verify_any_pool2d(pool_type, data_shape, pool_size, strides, padding,
+ layout, static_data_shape, ref_out_shape):
+ mod = relay.Module()
+ dtype = "float32"
+ pool_func = relay.nn.max_pool2d if pool_type == "max" else relay.nn.avg_pool2d
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ y = pool_func(data, pool_size, strides, padding, layout)
+ mod["main"] = relay.Function([data], y)
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np)
+ assert result.asnumpy().shape == ref_out_shape, \
+ "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_pool2d():
+ verify_any_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()),
+ (3, 3), (1, 1), (1, 1), "NCHW", (2, 3, 220, 220), (2, 3, 220, 220))
+ verify_any_pool2d("avg", (relay.Any(), relay.Any(), relay.Any(), 4),
+ (1, 1), (2, 2), (0, 0), "NHWC", (3, 220, 220, 4), (3, 110, 110, 4))
+ verify_any_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4),
+ (3, 3), (2, 2), (1, 1), "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 110, 110, 4))
+
+def verify_any_global_pool2d(pool_type, data_shape, layout, static_data_shape, ref_out_shape):
+ mod = relay.Module()
+ dtype = "float32"
+ pool_func = relay.nn.global_max_pool2d if pool_type == "max" else relay.nn.global_avg_pool2d
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ y = pool_func(data, layout)
+ mod["main"] = relay.Function([data], y)
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np)
+ assert result.asnumpy().shape == ref_out_shape, \
+ "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_global_pool2d():
+ verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()),
+ "NCHW", (2, 3, 220, 220), (2, 3, 1, 1))
+ verify_any_global_pool2d("avg", (relay.Any(), relay.Any(), relay.Any(), 4),
+ "NHWC", (3, 220, 220, 4), (3, 1, 1, 4))
+ verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4),
+ "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 1, 1, 4))
+
+def test_any_batch_flatten():
+ mod = relay.Module()
+ dtype = "float32"
+ data = relay.var('data', shape=any_dims(3), dtype=dtype)
+ y = relay.nn.batch_flatten(data)
+ mod["main"] = relay.Function([data], y)
+ data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype)
+ ref_out_shape = (3, 30)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np)
+ assert result.asnumpy().shape == ref_out_shape, \
+ "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def verify_any_dense(data_shape, weight_shape, units, static_data_shape,
+ static_weight_shape, ref_out_shape):
+ mod = relay.Module()
+ dtype = "float32"
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ weight = relay.var('weight', shape=weight_shape, dtype=dtype)
+ y = relay.nn.dense(data, weight, units)
+ mod["main"] = relay.Function([data, weight], y)
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ weight_np = np.random.uniform(size=static_weight_shape).astype(dtype)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np, weight_np)
+ assert result.asnumpy().shape == ref_out_shape, \
+ "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_dense():
+ verify_any_dense(any_dims(2), any_dims(2), None, (4, 16), (8, 16), (4, 8))
+ verify_any_dense(any_dims(2), (50, relay.Any()), 50, (4, 40), (50, 40), (4, 50))
+
+def verify_any_pad(data_shape, pad_width, static_data_shape):
+ mod = relay.Module()
+ dtype = "float32"
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ y = relay.nn.pad(data, pad_width)
+ mod["main"] = relay.Function([data], y)
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np)
+        ref_out = np.pad(data_np, pad_width, mode='constant')
+ tvm.testing.assert_allclose(result.asnumpy(), ref_out)
+
+def test_any_pad():
+ verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3))
+ verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1))
+
+def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape):
+ mod = relay.Module()
+ dtype = "float32"
+ data = relay.var('data', shape=data_shape, dtype=dtype)
+ y = relay.nn.softmax(data, axis)
+ mod["main"] = relay.Function([data], y)
+ data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np)
+ assert result.asnumpy().shape == ref_out_shape, \
+ "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+
+def test_any_softmax():
+ verify_any_softmax(any_dims(3), -1, (1, 2, 3), (1, 2, 3))
+ verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1))
+
def test_fused_ops():
x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32')
y0 = x + relay.const(1.0, 'float32')
test_any_reshape()
test_any_take()
test_any_shape_of()
+ test_any_reduce()
+ test_any_layout_transform()
+ test_any_expand_dims()
+ test_any_transpose()
+ test_any_squeeze()
+ test_any_reshape_like()
+ test_any_conv2d_NCHWc()
+ test_any_pool2d()
+ test_any_global_pool2d()
+ test_any_batch_flatten()
+ test_any_dense()
+ test_any_pad()
+ test_any_softmax()
test_fused_ops()
test_arange_with_dynamic_shape()
test_recursive_concat()
std::string name = "tensor",
std::string tag = kInjective) {
auto ishape = x->shape;
- int dim = 1;
+ Expr dim = 1;
for (size_t i = 1; i < ishape.size(); ++i) {
- dim = dim * static_cast<int>(topi::detail::GetConstInt(ishape[i]));
+ dim = dim * ishape[i];
}
Array<Expr> oshape({ ishape[0], dim });
def get_const_tuple(in_tuple):
- """Verifies input tuple is IntImm, returns tuple of int.
+ """Verifies input tuple is IntImm or Var, returns tuple of int or Var.
    Parameters
    ----------
    in_tuple : tuple of Expr
        The input.

    Returns
    -------
    out_tuple : tuple of int
        The output.
    """
- return tuple(get_const_int(elem) for elem in in_tuple)
+    ret = []
+    for elem in in_tuple:
+        if isinstance(elem, tvm.expr.Var):
+            ret.append(elem)
+        elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)):
+            elem = tvm.ir_pass.Simplify(elem)
+            if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+                ret.append(elem)
+            else:
+                # Simplification produced a constant; keep it as a plain int.
+                ret.append(get_const_int(elem))
+        else:
+            ret.append(get_const_int(elem))
+    return tuple(ret)
def get_float_tuple(in_tuple):
"""
Get default schedule config for the workload
"""
+ static_data_shape = []
+ for dim in get_const_tuple(data.shape):
+ if isinstance(dim, tvm.expr.Var):
+ static_data_shape.append(1)
+ else:
+ static_data_shape.append(dim)
+ data = tvm.placeholder(static_data_shape, dtype=data.dtype)
if is_depthwise:
wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
from .depthwise_conv2d import _fallback_schedule
return C
M, _ = get_const_tuple(data.shape)
+ # Always use dense_nopack for dynamic input.
+    # This is a temporary workaround for CV models.
+ # TODO(kevinthesun): use kernel dispatcher instead.
+ if isinstance(M, tvm.expr.Var):
+ return _declaration_dense_nopack(cfg, data, weight, bias, out_dtype)
+
# For small batch sizes, don't pack weight into cache-friendly layout
# because of overhead in packing and limited reuse from batch dimension
# TODO(icemelon9): use a more systematic way to determine which schedule to use
M, K = get_const_tuple(data.shape) # batch, in_dim
N, _ = get_const_tuple(weight.shape) # out_dim
# create tuning space
- cfg.define_split("tile_y", M, num_outputs=3)
- cfg.define_split("tile_x", N, num_outputs=3)
- cfg.define_split("tile_k", K, num_outputs=2)
+    cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=3)
+    cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=3)
+ cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2)
if cfg.is_fallback:
_default_dense_pack_config(cfg, M, N, K)
M, K = get_const_tuple(data.shape)
N, _ = get_const_tuple(weight.shape)
# create tuning space
- cfg.define_split("tile_y", M, num_outputs=2)
- cfg.define_split("tile_x", N, num_outputs=2)
- cfg.define_split("tile_k", K, num_outputs=2)
+ cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=2)
+ cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=2)
+ cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2)
if cfg.is_fallback:
_default_dense_nopack_config(cfg, M, N, K)
def _default_dense_pack_config(cfg, M, N, K):
- vec_width = get_fp32_len()
+ # Generate default schedule for dynamic shape.
+ if isinstance(M, tvm.expr.Var):
+ M = 16
+ if isinstance(N, tvm.expr.Var):
+ N = 16
+ if isinstance(K, tvm.expr.Var):
+ K = 16
+ vec_width = get_fp32_len()
tilex_ii = 1
for bn in range(vec_width*2, 0, -1):
if N % bn == 0:
def _default_dense_nopack_config(cfg, M, N, K):
+ # Generate default schedule for dynamic shape.
+ if isinstance(M, tvm.expr.Var):
+ M = 16
+ if isinstance(N, tvm.expr.Var):
+ N = 16
+ if isinstance(K, tvm.expr.Var):
+ K = 16
+
vec_width = get_fp32_len()
tilek_bn = 1
for bn in range(vec_width*2, 0, -1):