_CLASS_NODE_BASE = cls
+def _scalar_type_inference(value):
+    if hasattr(value, 'dtype'):
+        dtype = str(value.dtype)
+    elif isinstance(value, bool):
+        dtype = 'bool'
+    elif isinstance(value, float):
+        # We intentionally convert the float to float32 since it's more common in DL.
+        dtype = 'float32'
+    elif isinstance(value, int):
+        # We intentionally convert the Python int to int32 since it's more common in DL.
+        dtype = 'int32'
+    else:
+        raise NotImplementedError('Cannot automatically infer the type.'
+                                  ' value={}'.format(value))
+    return dtype
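
For illustration, a minimal sketch of what these rules infer (calling the private helper directly; numpy scalars carry a dtype attribute and win the first branch, and bool is tested before int because isinstance(True, int) holds in Python):

    import numpy as np
    from tvm._ffi.node_generic import _scalar_type_inference

    assert _scalar_type_inference(np.float16(1)) == 'float16'
    assert _scalar_type_inference(True) == 'bool'
    assert _scalar_type_inference(1.5) == 'float32'  # not float64, by design
    assert _scalar_type_inference(2) == 'int32'      # not int64, by design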
+
+
class NodeGeneric(object):
    """Base class for all classes that can be converted to node."""
    def asnode(self):
        """Convert value to node"""
        raise NotImplementedError()


def const(value, dtype=None):
    """Construct a constant.
    Parameters
    value : int or float
        The input value
-    dtype : str
+    dtype : str or None, optional
        The data type.
    Returns
        Constant expression corresponds to the value.
    """
    if dtype is None:
-        if isinstance(value, Integral):
-            dtype = 'int32'
-        else:
-            dtype = 'float32'
+        dtype = _scalar_type_inference(value)
    return _api_internal._const(value, dtype)
from ._ffi.base import string_types
from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node
+from ._ffi.node_generic import _scalar_type_inference
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
return _api_internal._max_value(dtype)
-def const(value, dtype):
+def const(value, dtype=None):
    """construct a constant
    Parameters
    value : number
        The content of the constant number.
-    dtype : str
+    dtype : str or None, optional
        The data type.
    Returns
    const_val: tvm.Expr
        The result expression.
    """
+    if dtype is None:
+        dtype = _scalar_type_inference(value)
    return _api_internal._const(value, dtype)
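
With the new default in place, a quick usage sketch (assuming a build of tvm that includes this change):

    import tvm

    x = tvm.const(1)           # dtype inferred as 'int32'
    y = tvm.const(1.0)         # dtype inferred as 'float32'
    z = tvm.const(1, 'int64')  # an explicit dtype still takes precedence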
# specific language governing permissions and limitations
# under the License.
import tvm
+import numpy as np
def test_const():
    x = tvm.const(1, "int32")
    assert x.dtype == tvm.int32
    assert isinstance(x, tvm.expr.IntImm)
+
+def test_scalar_dtype_inference():
+    for data in [True, np.bool(1), np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
+                 np.int8(1), np.int16(1), np.int32(1), np.int64(1),
+                 np.float16(1), np.float32(1), np.float64(1)]:
+        assert tvm.const(data).dtype == str(np.array(data).dtype)
+    assert tvm.const(1).dtype == 'int32'
+    assert tvm.const(1.0).dtype == 'float32'
+
+    for data in [True, np.bool(1), np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
+                 np.int8(1), np.int16(1), np.int32(1), np.int64(1),
+                 np.float16(1), np.float32(1), np.float64(1)]:
+        assert tvm.convert(data).dtype == str(np.array(data).dtype)
+    assert tvm.convert(1).dtype == 'int32'
+    assert tvm.convert(1.0).dtype == 'float32'
+
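
Both tvm.const and tvm.convert are exercised above because convert_to_node presumably routes plain numbers through const in node_generic.py, so the two entry points share a single inference path.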
def test_make():
    x = tvm.const(1, "int32")
    y = tvm.var("x")
test_cast()
test_attr()
test_const()
+ test_scalar_dtype_inference()
test_make()
test_ir()
test_basic()
std::string tag = kInjective) {
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
- out_shape.push_back(shape[0]);
- out_shape.push_back(shape[1]);
+ out_shape.push_back(cast(Int(32), shape[0]));
+ out_shape.push_back(cast(Int(32), shape[1]));
out_shape.push_back(input->shape[3]);
return compute(
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
- out_shape.push_back(shape[0]);
- out_shape.push_back(shape[1]);
+ out_shape.push_back(cast(Int(32), shape[0]));
+ out_shape.push_back(cast(Int(32), shape[1]));
return compute(
out_shape, [&](const Array<Var>& indices) {
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
- out_shape.push_back(shape[0]);
- out_shape.push_back(shape[1]);
+ out_shape.push_back(cast(Int(32), shape[0]));
+ out_shape.push_back(cast(Int(32), shape[1]));
out_shape.push_back(input->shape[4]);
return compute(
std::string tag = kInjective) {
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
- out_shape.push_back(shape[0]);
- out_shape.push_back(shape[1]);
+ out_shape.push_back(cast(Int(32), shape[0]));
+ out_shape.push_back(cast(Int(32), shape[1]));
out_shape.push_back(input->shape[3]);
Expr cone = make_const(Int(32), 1);
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
- out_shape.push_back(shape[0]);
- out_shape.push_back(shape[1]);
+ out_shape.push_back(cast(Int(32), shape[0]));
+ out_shape.push_back(cast(Int(32), shape[1]));
Expr cone = make_const(Int(32), 1);
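
A likely motivation for the casts in these shape computations: with scalar dtype inference in place, target shapes reaching topi can be built from int64 immediates (e.g. numpy int64 scalars), and mixing int64 extents with int32 index arithmetic trips up downstream simplification. Normalizing every extent to Int(32) keeps the output shape uniform.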
CHECK_GE(pad_before.size(), 1);
CHECK_EQ(pad_before.size(), pad_after.size());
tvm::Array<tvm::Expr> output_shape;
+ tvm::Array<tvm::Expr> pad_before_int32;
+ tvm::Array<tvm::Expr> pad_after_int32;
+ for (const auto &ele : pad_before) {
+ pad_before_int32.push_back(tvm::cast(tvm::Int(32), ele));
+ }
+ for (const auto &ele : pad_after) {
+ pad_after_int32.push_back(tvm::cast(tvm::Int(32), ele));
+ }
for (size_t i = 0; i < t->shape.size(); ++i) {
if (i >= pad_before.size()) {
output_shape.push_back(t->shape[i]);
} else {
output_shape.push_back(
- tvm::ir::Simplify(t->shape[i] + pad_before[i] + pad_after[i]));
+ tvm::ir::Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
}
}
tvm::Array<tvm::Expr> indices;
tvm::Array<tvm::Expr> sel;
for (size_t i = 0; i < t->shape.size(); ++i) {
- if (i >= pad_before.size()) {
+ if (i >= pad_before_int32.size()) {
indices.push_back(ovars[i]);
continue;
}
- if (!topi::detail::EqualCheck(pad_before[i], 0)) {
- sel.push_back(ovars[i] >= pad_before[i]);
- indices.push_back(ovars[i] - pad_before[i]);
+ if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
+ sel.push_back(ovars[i] >= pad_before_int32[i]);
+ indices.push_back(ovars[i] - pad_before_int32[i]);
} else {
indices.push_back(ovars[i]);
}
- if (!topi::detail::EqualCheck(pad_after[i], 0)) {
- sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before[i] + t->shape[i]));
+ if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
+ sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
}
}
if (sel.size() != 0) {
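
A sketch of the case the pad change guards against (hypothetical example; assumes a build of topi with this patch):

    import numpy as np
    import tvm
    import topi

    A = tvm.placeholder((4, 4), name='A')
    # With dtype inference, np.int64 pads convert to int64 Exprs;
    # pad() now casts them to Int(32) before any shape arithmetic.
    B = topi.nn.pad(A, pad_before=[np.int64(1), np.int64(1)],
                    pad_after=[np.int64(1), np.int64(1)])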
Array<Expr> out_shape;
for (size_t i = 0; i < n; ++i) {
out_shape.push_back(tvm::ir::Simplify(
- (x->shape[i] - 1) * strides[i] + 1));
+ (x->shape[i] - 1) * cast(Int(32), strides[i]) + 1));
}
return tvm::compute(
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
- auto kernel_height = kernel_size[0];
- auto kernel_width = kernel_size[1];
- auto stride_height = stride_size[0];
- auto stride_width = stride_size[1];
+ auto kernel_height = cast(Int(32), kernel_size[0]);
+ auto kernel_width = cast(Int(32), kernel_size[1]);
+ auto stride_height = cast(Int(32), stride_size[0]);
+ auto stride_width = cast(Int(32), stride_size[1]);
auto height = x->shape[height_axis];
auto width = x->shape[width_axis];
- auto pad_top = padding_size[0];
- auto pad_left = padding_size[1];
- auto pad_bottom = padding_size[2];
- auto pad_right = padding_size[3];
+ auto pad_top = cast(Int(32), padding_size[0]);
+ auto pad_left = cast(Int(32), padding_size[1]);
+ auto pad_bottom = cast(Int(32), padding_size[2]);
+ auto pad_right = cast(Int(32), padding_size[3]);
if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
- auto kernel_height = kernel_size[0];
- auto kernel_width = kernel_size[1];
- auto stride_height = stride_size[0];
- auto stride_width = stride_size[1];
+ auto kernel_height = cast(Int(32), kernel_size[0]);
+ auto kernel_width = cast(Int(32), kernel_size[1]);
+ auto stride_height = cast(Int(32), stride_size[0]);
+ auto stride_width = cast(Int(32), stride_size[1]);
auto height = x->shape[height_axis];
auto width = x->shape[width_axis];
- auto pad_top = padding_size[0];
- auto pad_left = padding_size[1];
- auto pad_bottom = padding_size[2];
- auto pad_right = padding_size[3];
+ auto pad_top = cast(Int(32), padding_size[0]);
+ auto pad_left = cast(Int(32), padding_size[1]);
+ auto pad_bottom = cast(Int(32), padding_size[2]);
+ auto pad_right = cast(Int(32), padding_size[3]);
if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when
auto height = x->shape[height_axis];
auto width = x->shape[width_axis];
- auto out_height = output_size[0];
- auto out_width = output_size[1];
+ auto out_height = cast(Int(32), output_size[0]);
+ auto out_width = cast(Int(32), output_size[1]);
Array<Expr> out_shape = x->shape;
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);
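
For instance (a hypothetical call; assumes a build of topi with this patch), pooling parameters supplied as numpy integers are now normalized inside the kernel rather than leaking int64 into the shape math:

    import numpy as np
    import tvm
    import topi

    A = tvm.placeholder((1, 3, 32, 32), name='A')
    P = topi.nn.pool(A, kernel=(np.int64(2), np.int64(2)),
                     stride=(np.int64(2), np.int64(2)),
                     padding=(0, 0, 0, 0), pool_type='max')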
std::string name = "T_reshape",
std::string tag = kInjective) {
auto x_shape = x->shape;
+ Array<Expr> newshape_int32;
+
+ for (const auto &ele : newshape) {
+ newshape_int32.push_back(cast(Int(32), ele));
+ }
return compute(
- newshape, [&](const Array<Var>& indices) {
- return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape),
+ newshape_int32, [&](const Array<Var>& indices) {
+ return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape_int32),
x_shape));
}, name, tag);
}
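
Likewise for reshape (hypothetical usage; assumes a build of topi with this patch): a target shape assembled from int64 scalars is normalized up front, so RavelIndex/UnravelIndex do their index arithmetic entirely in int32:

    import numpy as np
    import tvm
    import topi

    A = tvm.placeholder((2, 8), name='A')
    B = topi.reshape(A, tuple(np.array([4, 4], dtype='int64')))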