[Bugfix] Fix reshape (#5739)
authorCody Yu <comaniac0422@gmail.com>
Wed, 10 Jun 2020 05:03:09 +0000 (22:03 -0700)
committerGitHub <noreply@github.com>
Wed, 10 Jun 2020 05:03:09 +0000 (22:03 -0700)
* Fix reshape

* fix doc warning

* fix ci

* address comments

python/tvm/relay/op/transform.py

index e1b5627..0458b9a 100644 (file)
@@ -20,6 +20,7 @@
 
 from . import _make
 from ..expr import TupleWrapper, const
+from ...tir import expr as _expr
 
 
 def cast(data, dtype):
@@ -212,7 +213,16 @@ def reshape(data, newshape):
     if isinstance(newshape, int):
         newshape = const([newshape])
     if isinstance(newshape, (tuple, list)):
-        newshape = const(list(newshape))
+        tempshape = []
+        for shape in newshape:
+            if isinstance(shape, _expr.IntImm):
+                tempshape.append(shape.value)
+            else:
+                try:
+                    tempshape.append(int(shape))
+                except ValueError as err:
+                    raise RuntimeError('Unrecognized shape type: %s' % err)
+        newshape = const(tempshape)
     return _make.reshape(data, newshape)
 
 def argwhere(condition):