[Relay] Fixes to sum (#2439)
authorJared Roesch <roeschinc@gmail.com>
Wed, 16 Jan 2019 06:04:44 +0000 (22:04 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Wed, 16 Jan 2019 06:04:44 +0000 (22:04 -0800)
python/tvm/relay/op/reduce.py
tests/python/relay/test_op_level4.py

index 71c7dea9c0dc40b07339b4a62610641fce45d195..a4d5f66c009d57f5842702404e0c3979dc4dd4d2 100644 (file)
@@ -12,8 +12,8 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
         The input data
 
     axis : None or int or tuple of int
-        Axis or axes along which a argmin operation is performed.
-        The default, axis=None, will find the indices of maximum element all of the elements of
+        Axis or axes along which a argmax operation is performed.
+        The default, axis=None, will find the indices of the maximum element of the elements of
         the input array. If axis is negative it counts from the last to the first axis.
 
     keepdims : bool
@@ -73,14 +73,14 @@ def sum(data, axis=None, keepdims=False, exclude=False):
         The input data
 
     axis : None or int or tuple of int
-        Axis or axes along which a argmin operation is performed.
-        The default, axis=None, will find the indices of minimum element all of the elements of
-        the input array. If axis is negative it counts from the last to the first axis.
+        Axis or axes along which a sum is performed. The default, axis=None,
+        will sum all of the elements of the input array. If axis is
+        negative it counts from the last to the first axis.
 
     keepdims : bool
-        If this is set to True, the axes which are reduced are left in the result as dimensions
-        with size one.
-        With this option, the result will broadcast correctly against the input array.
+        If this is set to True, the axes which are reduced are left in the result as
+        dimensions with size one. With this option, the result will broadcast
+        correctly against the input array.
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
@@ -91,7 +91,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):
     result : relay.Expr
         The computed result.
     """
-    axis = [axis] if isinstance(axis, int) else axis
+    axis = [axis] if axis and isinstance(axis, int) else axis
     return _make.sum(data, axis, keepdims, exclude)
 
 
@@ -104,9 +104,9 @@ def max(data, axis=None, keepdims=False, exclude=False):
         The input data
 
     axis : None or int or tuple of int
-        Axis or axes along which a argmin operation is performed.
-        The default, axis=None, will find the indices of minimum element all of the elements of
-        the input array. If axis is negative it counts from the last to the first axis.
+        Axis or axes along which the max operation is performed.
+        The default, axis=None, will find the max element from all of the elements of the input
+        array. If axis is negative it counts from the last to the first axis.
 
     keepdims : bool
         If this is set to True, the axes which are reduced are left in the result as dimensions
@@ -135,9 +135,10 @@ def min(data, axis=None, keepdims=False, exclude=False):
         The input data
 
     axis : None or int or tuple of int
-        Axis or axes along which a argmin operation is performed.
-        The default, axis=None, will find the indices of minimum element all of the elements of
-        the input array. If axis is negative it counts from the last to the first axis.
+        Axis or axes along which a minimum operation is performed.
+        The default, axis=None, will find the minimum element from all
+        of the elements of the input array. If axis is negative it counts from
+        the last to the first axis.
 
     keepdims : bool
         If this is set to True, the axes which are reduced are left in the result as dimensions
@@ -166,7 +167,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
         The input data
 
     axis : None or int or tuple of int
-        Axis or axes along which a argmin operation is performed.
+        Axis or axes along which a mean operation is performed.
         The default, axis=None, will find the indices of minimum element all of the elements of
         the input array. If axis is negative it counts from the last to the first axis.
 
@@ -197,7 +198,7 @@ def prod(data, axis=None, keepdims=False, exclude=False):
         The input data
 
     axis : None or int or tuple of int
-        Axis or axes along which a argmin operation is performed.
+        Axis or axes along which a product is performed.
         The default, axis=None, will find the indices of minimum element all of the elements of
         the input array. If axis is negative it counts from the last to the first axis.
 
index 45d6d36fdc20d93cdd32e73aefe045474956a43a..ae7fe320940aaa98fe71c76415b9c43f04b6e50e 100644 (file)
@@ -180,6 +180,7 @@ def test_reduce_functions():
                  [relay.prod, np.prod],
                  [relay.argmin, _with_keepdims(np.argmin)],
                  [relay.argmax, _with_keepdims(np.argmax)]]:
+        verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
         verify_reduce(func, (d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4))
         verify_reduce(func, (d1, d2, d3), 1, True, False, (d1, 1, d3))
         verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))