[Relay][Frontend] Add a few mxnet ops in relay frontend (#2704)
author     Haichen Shen <shenhaichen@gmail.com>
           Sun, 3 Mar 2019 18:24:20 +0000 (10:24 -0800)
committer  Tianqi Chen <tqchen@users.noreply.github.com>
           Sun, 3 Mar 2019 18:24:20 +0000 (10:24 -0800)
python/tvm/relay/frontend/mxnet.py
tests/python/frontend/mxnet/test_forward.py

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 3d3bb8e4fd84eb5a98935eb895e926c4a617f2f3..1f1d18e240cd1795e1447b177a7c35faa78683d7 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -64,6 +64,13 @@ def _mx_activations(inputs, attrs):
     raise RuntimeError("Do not support act_type: {}".format(act_type))
 
 
+def _mx_compare(new_op, wrapper):
+    def impl(inputs, attrs):
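+        # MXNet comparisons return 0/1 tensors in the input dtype, while the
+        # corresponding Relay ops return bool, so cast back to the lhs dtype.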
+        dtype = ir_pass.infer_type(inputs[0]).checked_type.dtype
+        return wrapper(new_op)(inputs, attrs).astype(dtype)
+    return impl
+
+
 def _mx_conv2d(inputs, attrs):
     kernel_size = attrs.get_int_tuple("kernel")
     if len(kernel_size) != 2:
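
Note: Relay's comparison ops yield bool tensors, whereas MXNet's return 0/1
values in the operand dtype; _mx_compare infers the dtype of the first input
and casts the result back so downstream arithmetic sees the expected type.
A minimal numpy sketch of the same idea (helper name hypothetical, not part
of the frontend):

    import numpy as np

    def compare_like_mxnet(lhs, rhs, cmp):
        # cmp yields bool (Relay behavior); the cast restores the lhs
        # dtype (MXNet behavior).
        return cmp(lhs, rhs).astype(lhs.dtype)

    x = np.array([1.0, 2.0], dtype='float32')
    print(compare_like_mxnet(x, x, np.equal))  # [1. 1.] as float32, not bool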
@@ -333,32 +340,52 @@ _identity_list = [
 ]
 
 _convert_map = {
-    "_copy"         : _rename(_op.copy),
-    "relu"          : _rename(_op.nn.relu),
-    "broadcast_add" : _rename(_op.add),
-    "broadcast_sub" : _rename(_op.subtract),
-    "broadcast_mul" : _rename(_op.multiply),
-    "broadcast_div" : _rename(_op.divide),
-    "elemwise_add"  : _rename(_op.add),
-    "elemwise_sub"  : _rename(_op.subtract),
-    "elemwise_mul"  : _rename(_op.multiply),
-    "elemwise_div"  : _rename(_op.divide),
-    "flatten"       : _rename(_op.nn.batch_flatten),
-    "Flatten"       : _rename(_op.nn.batch_flatten),
-    "_plus_scalar"  : _binop_scalar(_op.add),
-    "__add_scalar__": _binop_scalar(_op.add),
-    "__sub_scalar__": _binop_scalar(_op.subtract),
-    "_minus_scalar" : _binop_scalar(_op.subtract),
-    "__mul_scalar__": _binop_scalar(_op.multiply),
-    "_mul_scalar"   : _binop_scalar(_op.multiply),
-    "__div_scalar__": _binop_scalar(_op.divide),
-    "_div_scalar"   : _binop_scalar(_op.divide),
-    "__pow_scalar__": _binop_scalar(_op.power),
-    "_rminus_scalar": _rbinop_scalar(_op.subtract),
-    "__rsub_scalar__": _rbinop_scalar(_op.subtract),
-    "_rdiv_scalar"  : _rbinop_scalar(_op.divide),
-    "__rdiv_scalar__"  : _rbinop_scalar(_op.divide),
-    "__rpow_scalar__": _rbinop_scalar(_op.power),
+    "_copy"                  : _rename(_op.copy),
+    "relu"                   : _rename(_op.nn.relu),
+    "broadcast_add"          : _rename(_op.add),
+    "broadcast_sub"          : _rename(_op.subtract),
+    "broadcast_mul"          : _rename(_op.multiply),
+    "broadcast_div"          : _rename(_op.divide),
+    "broadcast_mod"          : _rename(_op.mod),
+    "broadcast_maximum"      : _rename(_op.maximum),
+    "broadcast_minimum"      : _rename(_op.minimum),
+    "broadcast_equal"        : _mx_compare(_op.equal, _rename),
+    "broadcast_not_equal"    : _mx_compare(_op.not_equal, _rename),
+    "broadcast_greater"      : _mx_compare(_op.greater, _rename),
+    "broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename),
+    "broadcast_lesser"       : _mx_compare(_op.less, _rename),
+    "broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename),
+    "elemwise_add"           : _rename(_op.add),
+    "elemwise_sub"           : _rename(_op.subtract),
+    "elemwise_mul"           : _rename(_op.multiply),
+    "elemwise_div"           : _rename(_op.divide),
+    "_maximum"               : _rename(_op.maximum),
+    "_minimum"               : _rename(_op.minimum),
+    "flatten"                : _rename(_op.nn.batch_flatten),
+    "Flatten"                : _rename(_op.nn.batch_flatten),
+    "__add_scalar__"         : _binop_scalar(_op.add),
+    "_plus_scalar"           : _binop_scalar(_op.add),
+    "__sub_scalar__"         : _binop_scalar(_op.subtract),
+    "_minus_scalar"          : _binop_scalar(_op.subtract),
+    "__mul_scalar__"         : _binop_scalar(_op.multiply),
+    "_mul_scalar"            : _binop_scalar(_op.multiply),
+    "__div_scalar__"         : _binop_scalar(_op.divide),
+    "_div_scalar"            : _binop_scalar(_op.divide),
+    "__pow_scalar__"         : _binop_scalar(_op.power),
+    "_power_scalar"          : _binop_scalar(_op.power),
+    "__rsub_scalar__"        : _rbinop_scalar(_op.subtract),
+    "_rminus_scalar"         : _rbinop_scalar(_op.subtract),
+    "__rdiv_scalar__"        : _rbinop_scalar(_op.divide),
+    "_rdiv_scalar"           : _rbinop_scalar(_op.divide),
+    "__rpow_scalar__"        : _rbinop_scalar(_op.power),
+    "_equal_scalar"          : _mx_compare(_op.equal, _binop_scalar),
+    "_not_equal_scalar"      : _mx_compare(_op.not_equal, _binop_scalar),
+    "_greater_scalar"        : _mx_compare(_op.greater, _binop_scalar),
+    "_greater_equal_scalar"  : _mx_compare(_op.greater_equal, _binop_scalar),
+    "_lesser_scalar"         : _mx_compare(_op.less, _binop_scalar),
+    "_lesser_equal_scalar"   : _mx_compare(_op.less_equal, _binop_scalar),
+    "_maximum_scalar"        : _binop_scalar(_op.maximum),
+    "_minimum_scalar"        : _binop_scalar(_op.minimum),
     # reduction ops
     "max"           : _reduce(_op.max),
     "min"           : _reduce(_op.min),
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 671316079308d2157a52ec93a68ffd0742b52085..ee47d72046ed4792f1e5f9df9b84a74bda01ea59 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -1,4 +1,5 @@
 import numpy as np
+import operator
 
 import tvm
 from tvm.contrib import graph_runtime
@@ -256,6 +257,85 @@ def test_forward_arange():
     verify(20, 1, -1)
     verify(20, 1, -1.5)
 
+def _mx_symbol(F, op_name, inputs):
+    op = getattr(F, op_name)
+    return op(*inputs)
+
+def test_forward_broadcast_ops():
+    for op in ["broadcast_add", "broadcast_sub", "broadcast_mul",
+               "broadcast_div", "broadcast_mod", "broadcast_maximum",
+               "broadcast_minimum", "broadcast_equal", "broadcast_not_equal",
+               "broadcast_greater", "broadcast_greater_equal",
+               "broadcast_lesser", "broadcast_lesser_equal"]:
+        a_shape = (3, 4, 5)
+        b_shape = (4, 5)
+        if op == "broadcast_mod":
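+            # use positive int32 operands so mod has no zero divisor and
+            # the comparison against the MXNet reference is exact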
+            dtype = 'int32'
+            a_np = np.random.randint(1, 100, size=a_shape).astype(dtype)
+            b_np = np.random.randint(1, 100, size=b_shape).astype(dtype)
+        else:
+            dtype = 'float32'
+            a_np = np.random.uniform(size=a_shape).astype(dtype)
+            b_np = np.random.uniform(size=b_shape).astype(dtype)
+        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
+        ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
+        shapes = {'a': a_shape, 'b': b_shape}
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(a_np, b_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
+def test_forward_elemwise_ops():
+    for op in ["elemwise_add", "elemwise_sub", "elemwise_mul",
+               "elemwise_div", "maximum", "minimum"]:
+        shape = (3, 4, 5)
+        dtype = 'float32'
+        a_np = np.random.uniform(size=shape).astype(dtype)
+        b_np = np.random.uniform(size=shape).astype(dtype)
+        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
+        ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
+        shapes = {'a': shape, 'b': shape}
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(a_np, b_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
+def test_forward_scalar_ops():
+    for op in [operator.add, operator.sub, operator.mul, operator.truediv,
+               operator.pow, operator.lt, operator.le, operator.eq,
+               operator.ne, operator.gt, operator.ge]:
+        dtype = 'float32'
+        a_shape = (3, 4, 5)
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        b_scalar = 2.3
+        mx_sym = op(mx.sym.var('a'), b_scalar)
+        ref_res = op(mx.nd.array(a_np), b_scalar)
+        shapes = {'a': a_shape}
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(a_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    for op in ["maximum", "minimum"]:
+        dtype = 'float32'
+        a_shape = (3, 4, 5)
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        b_scalar = 2.3
+        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), b_scalar])
+        ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), b_scalar])
+        shapes = {'a': a_shape}
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(a_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
 
 if __name__ == '__main__':
     test_forward_mlp()
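
Note: test_forward_scalar_ops drives its first loop with functions from the
operator module; because MXNet symbols overload the arithmetic and comparison
operators, operator.gt(mx.sym.var('a'), 2.3) builds the same graph as
a > 2.3, which the frontend converts via _greater_scalar. maximum and minimum
get a separate loop since Python has no operator-module equivalent for them.
A minimal illustration of the overloading trick, with numpy arrays standing
in for symbols:

    import operator
    import numpy as np

    a = np.array([1.0, 3.0], dtype='float32')
    for op in [operator.add, operator.gt]:
        # op(a, 2.3) dispatches to a.__add__ / a.__gt__, mirroring how the
        # test builds mx.sym expressions from generic operator functions.
        print(op.__name__, op(a, 2.3))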
@@ -280,3 +360,6 @@ if __name__ == '__main__':
     test_forward_argmin()
     test_forward_where()
     test_forward_arange()
+    test_forward_broadcast_ops()
+    test_forward_elemwise_ops()
+    test_forward_scalar_ops()