[RELAY][ONNX]ReduceLogSumExp Operator support (#5453)
author: Samuel <siju.samuel@huawei.com>
Thu, 7 May 2020 20:50:12 +0000 (02:20 +0530)
committer: GitHub <noreply@github.com>
Thu, 7 May 2020 20:50:12 +0000 (05:50 +0900)
* [RELAY]LogSumExp Op Support

* [ONNX]LogSumExp Op Support

python/tvm/relay/frontend/onnx.py
python/tvm/relay/op/reduce.py
tests/python/frontend/onnx/test_forward.py
tests/python/relay/test_op_level4.py

index 52d87f6..4ae083c 100644 (file)
@@ -1076,6 +1076,11 @@ class ReduceProd(Reduce):
     """
     name = 'prod'
 
+class ReduceLogSumExp(Reduce):
+    """ Operator converter for ReduceLogSumExp.
+    """
+    name = 'logsumexp'
+
 class ArgMax(OnnxOpConverter):
     """ Operator converter for ArgMax.
     """
@@ -1640,8 +1645,7 @@ def _get_convert_map(opset):
         'ReduceSum': ReduceSum.get_converter(opset),
         'ReduceMean': ReduceMean.get_converter(opset),
         'ReduceProd': ReduceProd.get_converter(opset),
-        # 'ReduceProd'
-        # 'ReduceLogSumExp'
+        'ReduceLogSumExp': ReduceLogSumExp.get_converter(opset),
 
         #defs/sorting
         'ArgMax': ArgMax.get_converter(opset),
index d322601..988c949 100644 (file)
@@ -18,7 +18,7 @@
 # pylint: disable=redefined-builtin
 
 from . import _make
-from .tensor import sqrt
+from .tensor import sqrt, log, exp
 from .transform import squeeze
 from ..expr import Tuple, TupleWrapper
 
@@ -475,3 +475,40 @@ def prod(data, axis=None, keepdims=False, exclude=False):
     """
     axis = [axis] if isinstance(axis, int) else axis
     return _make.prod(data, axis, keepdims, exclude)
+
+
+def logsumexp(data, axis=None, keepdims=False):
+    """Compute the log of the sum of exponentials of input elements over given axes.
+
+       This function is more numerically stable than log(sum(exp(input))).
+       It avoids overflows caused by taking the exp of large inputs and underflows
+       caused by taking the log of small inputs.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    axis : None or int or tuple of int
+        Axis or axes along which a log-sum-exp operation is performed.
+        The default, axis=None, will compute the log of the sum of exponentials of all elements
+        in the input array. If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+
+    axis = [axis] if isinstance(axis, int) else axis
+    max_x = max(data, axis, True)
+    exp_x = exp(data - max_x)
+    sum_x = sum(exp_x, axis, True)
+    out_x = log(sum_x) + max_x
+    if not keepdims:
+        out_x = squeeze(out_x, axis)
+    return out_x
index dc832aa..78658e7 100644 (file)
@@ -1307,6 +1307,15 @@ def verify_reduce_x(name, indata, axis, keepdims):
         outdata = np.sum(indata, axis=axis, keepdims=keepdims == 1)
     elif name == 'ReduceMean':
         outdata = np.mean(indata, axis=axis, keepdims=keepdims == 1)
+    elif name == 'ReduceLogSumExp':
+        def _np_log_sum_exp(x, axis, keepdims=False):
+            max_x = np.max(x, axis=axis, keepdims=True)
+            x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True))
+            x = x + max_x
+            if not keepdims:
+                x = np.squeeze(x, axis=axis)
+            return x
+        outdata = _np_log_sum_exp(indata, axis=axis, keepdims=keepdims == 1)
     else:
         raise Exception('unsupport op: {}'.format(name))
     if len(np.asarray(outdata).shape) == 0:
@@ -1380,6 +1389,34 @@ def test_reduce_mean():
                     axis=(1,), keepdims=1)
 
 
+def test_reduce_logsumexp():
+
+    for keepdims in [True, False]:
+        verify_reduce_x("ReduceLogSumExp",
+                        np.random.randn(3, 2, 2).astype(np.float32),
+                        axis=None, keepdims=keepdims)
+
+        verify_reduce_x("ReduceLogSumExp",
+                        np.random.randn(3, 2, 3).astype(np.float32),
+                        axis=None, keepdims=keepdims)
+
+        verify_reduce_x("ReduceLogSumExp",
+                        np.random.randn(3, 3, 3).astype(np.float32),
+                        axis=(1,), keepdims=keepdims)
+
+        verify_reduce_x("ReduceLogSumExp",
+                        np.random.randn(3, 3, 3, 1).astype(np.float32),
+                        axis=(1, 2), keepdims=keepdims)
+
+        verify_reduce_x("ReduceLogSumExp",
+                        np.random.randn(3, 3, 3, 1).astype(np.float32),
+                        axis=(1,), keepdims=keepdims)
+
+        verify_reduce_x("ReduceLogSumExp",
+                        np.random.randn(1, 3, 4, 1).astype(np.float32),
+                        axis=(1,), keepdims=keepdims)
+
+
 def verify_split(indata, outdatas, split, axis=0):
     indata = np.array(indata).astype(np.float32)
     outdatas = [np.array(o).astype(np.float32) for o in outdatas]
@@ -2557,6 +2594,7 @@ if __name__ == '__main__':
     test_reduce_min()
     test_reduce_sum()
     test_reduce_mean()
+    test_reduce_logsumexp()
     test_pad()
     test_split()
     test_binary_ops()
index bbe2c69..947a4bf 100644 (file)
@@ -165,7 +165,10 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
     dtype = "bool" if ref_func in [np.all, np.any] else dtype
 
     x = relay.var("x", relay.TensorType(data, dtype))
-    z = test_func(x, axis, keepdims, exclude)
+    if test_func == relay.logsumexp:
+        z = test_func(x, axis, keepdims)
+    else:
+        z = test_func(x, axis, keepdims, exclude)
     zz = run_infer_type(z)
     if axis:
         assert "axis=" in z.astext()
@@ -215,6 +218,14 @@ def test_reduce_functions():
                 return func(data, axis=axis).reshape(out_shape)
         return _wrapper
 
+    def _np_log_sum_exp(x, axis, keepdims=False):
+        max_x = np.max(x, axis=axis, keepdims=True)
+        x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True))
+        x = x + max_x
+        if not keepdims:
+            x = np.squeeze(x, axis=axis)
+        return x
+
     d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4")
     for func in [[relay.sum, np.sum],
                  [relay.max, np.max],
@@ -225,6 +236,7 @@ def test_reduce_functions():
                  [relay.prod, np.prod],
                  [relay.all, np.all],
                  [relay.any, np.any],
+                 [relay.logsumexp, _np_log_sum_exp],
                  [relay.argmin, _with_keepdims(np.argmin)],
                  [relay.argmax, _with_keepdims(np.argmax)]]:
         verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())