[Relay] Fix reduce axis bug (#3422)
authorAltan Haan <altancpp@gmail.com>
Thu, 27 Jun 2019 17:03:29 +0000 (10:03 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 27 Jun 2019 17:03:29 +0000 (10:03 -0700)
* fix relay reduce axis bug

* add tests for reduce bug

python/tvm/relay/op/reduce.py
tests/python/relay/test_op_level4.py

index 0f25946..41e1fc0 100644 (file)
@@ -107,7 +107,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):
     result : relay.Expr
         The computed result.
     """
-    axis = [axis] if axis and isinstance(axis, int) else axis
+    axis = [axis] if isinstance(axis, int) else axis
     return _make.sum(data, axis, keepdims, exclude)
 
 
@@ -159,7 +159,7 @@ def all(data, axis=None, keepdims=False, exclude=False):
     # [False,  True, False]]
 
     """
-    axis = [axis] if axis and isinstance(axis, int) else axis
+    axis = [axis] if isinstance(axis, int) else axis
     return _make.all(data, axis, keepdims, exclude)
 
 
index aac4a6d..da0fe01 100644 (file)
@@ -202,7 +202,9 @@ def test_reduce_functions():
                  [relay.argmax, _with_keepdims(np.argmax)]]:
         verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
         verify_reduce(func, (d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4))
+        verify_reduce(func, (d1, d2, d3, d4), 0, True, False, (1, d2, d3, d4))
         verify_reduce(func, (d1, d2, d3), 1, True, False, (d1, 1, d3))
+        verify_reduce(func, (d1, d2, d3), 0, True, False, (1, d2, d3))
         verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))
         verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
         verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))