Numpy compatible dtype inference for `tvm.convert` and `tvm.const` (#3861)
authorXingjian Shi <xshiab@connect.ust.hk>
Mon, 9 Sep 2019 17:26:34 +0000 (10:26 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Mon, 9 Sep 2019 17:26:34 +0000 (01:26 +0800)
* numpy compatible type inference

* update

* try to fix

* fix

* try to fix

* fix lint

* Update nn.h

* cast to int32

* try to fix

* fix again

* retrigger ci

python/tvm/_ffi/node_generic.py
python/tvm/api.py
tests/python/unittest/test_lang_basic.py
topi/include/topi/image/resize.h
topi/include/topi/nn.h
topi/include/topi/nn/dilate.h
topi/include/topi/nn/pooling.h
topi/include/topi/transform.h

index ba5ccd0..e898126 100644 (file)
@@ -30,6 +30,23 @@ def _set_class_node_base(cls):
     _CLASS_NODE_BASE = cls
 
 
+def _scalar_type_inference(value):
+    if hasattr(value, 'dtype'):
+        dtype = str(value.dtype)
+    elif isinstance(value, bool):
+        dtype = 'bool'
+    elif isinstance(value, float):
+        # We intentionally convert the float to float32 since it's more common in DL.
+        dtype = 'float32'
+    elif isinstance(value, int):
+        # We intentionally convert the python int to int32 since it's more common in DL.
+        dtype = 'int32'
+    else:
+        raise NotImplementedError('Cannot automatically inference the type.'
+                                  ' value={}'.format(value))
+    return dtype
+
+
 class NodeGeneric(object):
     """Base class for all classes that can be converted to node."""
     def asnode(self):
@@ -86,7 +103,7 @@ def const(value, dtype=None):
     value : int or float
         The input value
 
-    dtype : str
+    dtype : str or None, optional
         The data type.
 
     Returns
@@ -95,8 +112,5 @@ def const(value, dtype=None):
         Constant expression corresponds to the value.
     """
     if dtype is None:
-        if isinstance(value, Integral):
-            dtype = 'int32'
-        else:
-            dtype = 'float32'
+        dtype = _scalar_type_inference(value)
     return _api_internal._const(value, dtype)
index cbc3459..6900742 100644 (file)
@@ -23,6 +23,7 @@ from numbers import Integral as _Integral
 from ._ffi.base import string_types
 from ._ffi.node import register_node, NodeBase
 from ._ffi.node import convert_to_node as _convert_to_node
+from ._ffi.node_generic import _scalar_type_inference
 from ._ffi.function import Function
 from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
 from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
@@ -73,7 +74,7 @@ def max_value(dtype):
     return _api_internal._max_value(dtype)
 
 
-def const(value, dtype):
+def const(value, dtype=None):
     """construct a constant
 
     Parameters
@@ -81,7 +82,7 @@ def const(value, dtype):
     value : number
         The content of the constant number.
 
-    dtype : str
+    dtype : str or None, optional
         The data type.
 
     Returns
@@ -89,6 +90,8 @@ def const(value, dtype):
     const_val: tvm.Expr
         The result expression.
     """
+    if dtype is None:
+        dtype = _scalar_type_inference(value)
     return _api_internal._const(value, dtype)
 
 
index 0ace220..7df92ed 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
+import numpy as np
 
 def test_const():
     x = tvm.const(1, "int32")
@@ -22,6 +23,22 @@ def test_const():
     assert x.dtype == tvm.int32
     assert isinstance(x, tvm.expr.IntImm)
 
+
+def test_scalar_dtype_inference():
+    for data in [True, np.bool(1), np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
+                 np.int8(1), np.int16(1), np.int32(1), np.int64(1),
+                 np.float16(1), np.float32(1), np.float64(1)]:
+        assert tvm.const(data).dtype == str(np.array(data).dtype)
+    assert tvm.const(1).dtype == 'int32'
+    assert tvm.const(1.0).dtype == 'float32'
+
+    for data in [True, np.bool(1), np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
+                 np.int8(1), np.int16(1), np.int32(1), np.int64(1),
+                 np.float16(1), np.float32(1), np.float64(1)]:
+        assert tvm.convert(data).dtype == str(np.array(data).dtype)
+    assert tvm.convert(1).dtype == 'int32'
+    assert tvm.convert(1.0).dtype == 'float32'
+
 def test_make():
     x = tvm.const(1, "int32")
     y = tvm.var("x")
@@ -175,6 +192,7 @@ if __name__ == "__main__":
     test_cast()
     test_attr()
     test_const()
+    test_scalar_dtype_inference()
     test_make()
     test_ir()
     test_basic()
index e44f5a7..3a5efba 100644 (file)
@@ -97,8 +97,8 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input,
                                            std::string tag = kInjective) {
   Array<Expr> out_shape;
   out_shape.push_back(input->shape[0]);
-  out_shape.push_back(shape[0]);
-  out_shape.push_back(shape[1]);
+  out_shape.push_back(cast(Int(32), shape[0]));
+  out_shape.push_back(cast(Int(32), shape[1]));
   out_shape.push_back(input->shape[3]);
 
   return compute(
@@ -132,8 +132,8 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
   Array<Expr> out_shape;
   out_shape.push_back(input->shape[0]);
   out_shape.push_back(input->shape[1]);
-  out_shape.push_back(shape[0]);
-  out_shape.push_back(shape[1]);
+  out_shape.push_back(cast(Int(32), shape[0]));
+  out_shape.push_back(cast(Int(32), shape[1]));
 
   return compute(
     out_shape, [&](const Array<Var>& indices) {
@@ -166,8 +166,8 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
   Array<Expr> out_shape;
   out_shape.push_back(input->shape[0]);
   out_shape.push_back(input->shape[1]);
-  out_shape.push_back(shape[0]);
-  out_shape.push_back(shape[1]);
+  out_shape.push_back(cast(Int(32), shape[0]));
+  out_shape.push_back(cast(Int(32), shape[1]));
   out_shape.push_back(input->shape[4]);
 
   return compute(
@@ -233,8 +233,8 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
                                    std::string tag = kInjective) {
   Array<Expr> out_shape;
   out_shape.push_back(input->shape[0]);
-  out_shape.push_back(shape[0]);
-  out_shape.push_back(shape[1]);
+  out_shape.push_back(cast(Int(32), shape[0]));
+  out_shape.push_back(cast(Int(32), shape[1]));
   out_shape.push_back(input->shape[3]);
 
   Expr cone = make_const(Int(32), 1);
@@ -311,8 +311,8 @@ inline Tensor resize_bilinear_nchw(const Tensor& input,
   Array<Expr> out_shape;
   out_shape.push_back(input->shape[0]);
   out_shape.push_back(input->shape[1]);
-  out_shape.push_back(shape[0]);
-  out_shape.push_back(shape[1]);
+  out_shape.push_back(cast(Int(32), shape[0]));
+  out_shape.push_back(cast(Int(32), shape[1]));
 
   Expr cone = make_const(Int(32), 1);
 
index b91bfe0..463f232 100644 (file)
@@ -182,12 +182,20 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
   CHECK_GE(pad_before.size(), 1);
   CHECK_EQ(pad_before.size(), pad_after.size());
   tvm::Array<tvm::Expr> output_shape;
+  tvm::Array<tvm::Expr> pad_before_int32;
+  tvm::Array<tvm::Expr> pad_after_int32;
+  for (const auto &ele : pad_before) {
+    pad_before_int32.push_back(tvm::cast(tvm::Int(32), ele));
+  }
+  for (const auto &ele : pad_after) {
+    pad_after_int32.push_back(tvm::cast(tvm::Int(32), ele));
+  }
   for (size_t i = 0; i < t->shape.size(); ++i) {
     if (i >= pad_before.size()) {
       output_shape.push_back(t->shape[i]);
     } else {
       output_shape.push_back(
-          tvm::ir::Simplify(t->shape[i] + pad_before[i] + pad_after[i]));
+          tvm::ir::Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
     }
   }
 
@@ -199,18 +207,18 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
     tvm::Array<tvm::Expr> indices;
     tvm::Array<tvm::Expr> sel;
     for (size_t i = 0; i < t->shape.size(); ++i) {
-      if (i >= pad_before.size()) {
+      if (i >= pad_before_int32.size()) {
         indices.push_back(ovars[i]);
         continue;
       }
-      if (!topi::detail::EqualCheck(pad_before[i], 0)) {
-        sel.push_back(ovars[i] >= pad_before[i]);
-        indices.push_back(ovars[i] - pad_before[i]);
+      if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
+        sel.push_back(ovars[i] >= pad_before_int32[i]);
+        indices.push_back(ovars[i] - pad_before_int32[i]);
       } else {
         indices.push_back(ovars[i]);
       }
-      if (!topi::detail::EqualCheck(pad_after[i], 0)) {
-        sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before[i] + t->shape[i]));
+      if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
+        sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
       }
     }
     if (sel.size() != 0) {
index d9287cd..c1020b1 100644 (file)
@@ -77,7 +77,7 @@ inline Tensor dilate(const Tensor& x,
   Array<Expr> out_shape;
   for (size_t i = 0; i < n; ++i) {
     out_shape.push_back(tvm::ir::Simplify(
-      (x->shape[i] - 1) * strides[i] + 1));
+      (x->shape[i] - 1) * cast(Int(32), strides[i] + 1)));
   }
 
   return tvm::compute(
index 2eff244..27d4045 100644 (file)
@@ -73,18 +73,18 @@ inline Tensor pool_impl(const Tensor& x,
   CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
   CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
 
-  auto kernel_height = kernel_size[0];
-  auto kernel_width = kernel_size[1];
-  auto stride_height = stride_size[0];
-  auto stride_width = stride_size[1];
+  auto kernel_height = cast(Int(32), kernel_size[0]);
+  auto kernel_width = cast(Int(32), kernel_size[1]);
+  auto stride_height = cast(Int(32), stride_size[0]);
+  auto stride_width = cast(Int(32), stride_size[1]);
 
   auto height = x->shape[height_axis];
   auto width = x->shape[width_axis];
 
-  auto pad_top = padding_size[0];
-  auto pad_left = padding_size[1];
-  auto pad_bottom = padding_size[2];
-  auto pad_right = padding_size[3];
+  auto pad_top = cast(Int(32), padding_size[0]);
+  auto pad_left = cast(Int(32), padding_size[1]);
+  auto pad_bottom = cast(Int(32), padding_size[2]);
+  auto pad_right = cast(Int(32), padding_size[3]);
 
   if (ceil_mode) {
     // Additional padding to ensure we do ceil instead of floor when
@@ -179,18 +179,18 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
   CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
   CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
 
-  auto kernel_height = kernel_size[0];
-  auto kernel_width = kernel_size[1];
-  auto stride_height = stride_size[0];
-  auto stride_width = stride_size[1];
+  auto kernel_height = cast(Int(32), kernel_size[0]);
+  auto kernel_width = cast(Int(32), kernel_size[1]);
+  auto stride_height = cast(Int(32), stride_size[0]);
+  auto stride_width = cast(Int(32), stride_size[1]);
 
   auto height = x->shape[height_axis];
   auto width = x->shape[width_axis];
 
-  auto pad_top = padding_size[0];
-  auto pad_left = padding_size[1];
-  auto pad_bottom = padding_size[2];
-  auto pad_right = padding_size[3];
+  auto pad_top = cast(Int(32), padding_size[0]);
+  auto pad_left = cast(Int(32), padding_size[1]);
+  auto pad_bottom = cast(Int(32), padding_size[2]);
+  auto pad_right = cast(Int(32), padding_size[3]);
 
   if (ceil_mode) {
     // Additional padding to ensure we do ceil instead of floor when
@@ -471,8 +471,8 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
   auto height = x->shape[height_axis];
   auto width = x->shape[width_axis];
 
-  auto out_height = output_size[0];
-  auto out_width = output_size[1];
+  auto out_height = cast(Int(32), output_size[0]);
+  auto out_width = cast(Int(32), output_size[1]);
   Array<Expr> out_shape = x->shape;
   out_shape.Set(height_axis, out_height);
   out_shape.Set(width_axis, out_width);
index af2ed16..4180f5b 100644 (file)
@@ -208,9 +208,14 @@ inline Tensor reshape(const Tensor& x,
                       std::string name = "T_reshape",
                       std::string tag = kInjective) {
   auto x_shape = x->shape;
+  Array<Expr> newshape_int32;
+
+  for (const auto &ele : newshape) {
+    newshape_int32.push_back(cast(Int(32), ele));
+  }
   return compute(
-    newshape, [&](const Array<Var>& indices) {
-      return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape),
+    newshape_int32, [&](const Array<Var>& indices) {
+      return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape_int32),
                             x_shape));
     }, name, tag);
 }