add support for mxnet smooth_l1 (#2905)
authorHao Jin <hjjn.amzn@gmail.com>
Fri, 29 Mar 2019 03:59:05 +0000 (20:59 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Fri, 29 Mar 2019 03:59:05 +0000 (11:59 +0800)
python/tvm/relay/frontend/mxnet.py
tests/python/frontend/mxnet/test_forward.py

index 39daaf9..69d7792 100644 (file)
@@ -594,6 +594,15 @@ def _mx_embedding(inputs, _):
     return _op.take(weight, indices.astype('int32'), axis=0)
 
 
+def _mx_smooth_l1(inputs, attrs):
+    scalar = attrs.get_float("scalar", 1.0)
+    scalar_sq = scalar * scalar
+    mask = _op.less(inputs[0], _expr.const(1.0 / scalar_sq, dtype='float32'))
+    return _op.where(mask,
+                     _expr.const(scalar_sq / 2.0, dtype='float32') * inputs[0] * inputs[0],
+                     _op.abs(inputs[0]) - _expr.const(0.5 / scalar_sq))
+
+
 # Note: due to attribute conversion constraint
 # ops in the identity set must be attribute free
 _identity_list = [
@@ -729,6 +738,7 @@ _convert_map = {
     "Embedding"     : _mx_embedding,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
+    "smooth_l1"     : _mx_smooth_l1,
     # vision
     "_contrib_BilinearResize2D" : _mx_upsampling,
     "_contrib_MultiBoxPrior" : _mx_multibox_prior,
index aad666c..faccfbf 100644 (file)
@@ -464,6 +464,14 @@ def test_forward_embedding():
     verify((2, 2), (4, 5))
     verify((2, 3, 4), (4, 5))
 
+
+def test_forward_smooth_l1():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.smooth_l1(data)
+    verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
+    mx_sym = mx.sym.smooth_l1(data, scalar=1.0)
+    verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -498,3 +506,4 @@ if __name__ == '__main__':
     test_forward_broadcast_axis()
     test_forward_full()
     test_forward_embedding()
+    test_forward_smooth_l1()