Update precision in the ONNX strided_slice, update precision of ToScalar (#6272)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Fri, 14 Aug 2020 17:27:20 +0000 (10:27 -0700)
committerGitHub <noreply@github.com>
Fri, 14 Aug 2020 17:27:20 +0000 (10:27 -0700)
* Update precision in the ONNX strided_slice, update precision of ToScalar

* fix tests

python/tvm/relay/frontend/onnx.py
src/relay/transforms/pattern_util.h
tests/python/frontend/onnx/test_forward.py

index 74626d4..f54a145 100644 (file)
@@ -1045,8 +1045,8 @@ class Slice(OnnxOpConverter):
         end = list(attr['ends'])
 
         return _op.strided_slice(inputs[0],
-                                 begin=_expr.const(begin, dtype="int32"),
-                                 end=_expr.const(end, dtype="int32"))
+                                 begin=_expr.const(begin, dtype="int64"),
+                                 end=_expr.const(end, dtype="int64"))
 
     @classmethod
     def _impl_v10(cls, inputs, attr, params):
@@ -1063,8 +1063,8 @@ class Slice(OnnxOpConverter):
                 starts = new_starts
                 ends = new_ends
         return _op.strided_slice(inputs[0],
-                                 begin=_expr.const(starts, dtype="int32"),
-                                 end=_expr.const(ends, dtype="int32"))
+                                 begin=_expr.const(starts, dtype="int64"),
+                                 end=_expr.const(ends, dtype="int64"))
 
 
 class Gather(OnnxOpConverter):
index a7063f5..0b64846 100644 (file)
@@ -374,7 +374,7 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
  * \param i element index
  * \return Converted scalar value.
  */
-static inline double ToScalar(const runtime::NDArray& array, size_t i = 0) {
+static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
   if (array->dtype.code == kDLInt) {
     if (array->dtype.bits == 8) {
       return reinterpret_cast<int8_t*>(array->data)[i];
@@ -423,8 +423,8 @@ static inline Array<Integer> ToVector(const runtime::NDArray& array) {
   size_t len = array.Shape().front();
   Array<Integer> out;
   for (size_t i = 0; i < len; ++i) {
-    double elem_val = ToScalar(array, i);
-    out.push_back(Integer(static_cast<int>(elem_val)));
+    long double elem_val = ToScalar(array, i);
+    out.push_back(Integer(IntImm(DataType::Int(32), static_cast<int64_t>(elem_val))));
   }
   return out;
 }
index 14b827c..c376c9a 100644 (file)
@@ -478,15 +478,15 @@ def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None):
     inputs = [
         helper.make_tensor_value_info("data", TensorProto.FLOAT,
                                       list(indata.shape)),
-        helper.make_tensor_value_info("starts", TensorProto.INT32,
+        helper.make_tensor_value_info("starts", TensorProto.INT64,
                                       list(starts.shape)),
-        helper.make_tensor_value_info("ends", TensorProto.INT32,
+        helper.make_tensor_value_info("ends", TensorProto.INT64,
                                       list(ends.shape))
     ]
     initializer = [
-        helper.make_tensor("starts", TensorProto.INT32, list(starts.shape),
+        helper.make_tensor("starts", TensorProto.INT64, list(starts.shape),
                            starts),
-        helper.make_tensor("ends", TensorProto.INT32, list(ends.shape), ends)
+        helper.make_tensor("ends", TensorProto.INT64, list(ends.shape), ends)
     ]
 
     if axes:
@@ -534,7 +534,8 @@ def test_slice():
     _test_slice_iteration_v10(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
     _test_slice_iteration_v10(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
     _test_slice_iteration_v10(x, x[:, 1:1000], (1), (1000), (1))
-    _test_slice_iteration_v10(x, x[:, 0:-1], (0), (-1), (1))
+    x = np.random.randn(1, 1, 1, 128).astype(np.float32)
+    _test_slice_iteration_v10(x, x, (0, 0), (9223372036854775807, 9223372036854775807), (0, 3))
 
 
 def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):