    raise RuntimeError("Do not support act_type: {}".format(act_type))
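+# MXNet comparison ops return results in the input dtype (e.g. 0/1 in
+# float32), while the Relay comparisons yield bool, so cast the result
+# back to the input dtype to match MXNet's behavior.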
+def _mx_compare(new_op, wrapper):
+    def impl(inputs, attrs):
+        dtype = ir_pass.infer_type(inputs[0]).checked_type.dtype
+        return wrapper(new_op)(inputs, attrs).astype(dtype)
+    return impl
+
+
def _mx_conv2d(inputs, attrs):
    kernel_size = attrs.get_int_tuple("kernel")
    if len(kernel_size) != 2:
]
_convert_map = {
- "_copy" : _rename(_op.copy),
- "relu" : _rename(_op.nn.relu),
- "broadcast_add" : _rename(_op.add),
- "broadcast_sub" : _rename(_op.subtract),
- "broadcast_mul" : _rename(_op.multiply),
- "broadcast_div" : _rename(_op.divide),
- "elemwise_add" : _rename(_op.add),
- "elemwise_sub" : _rename(_op.subtract),
- "elemwise_mul" : _rename(_op.multiply),
- "elemwise_div" : _rename(_op.divide),
- "flatten" : _rename(_op.nn.batch_flatten),
- "Flatten" : _rename(_op.nn.batch_flatten),
- "_plus_scalar" : _binop_scalar(_op.add),
- "__add_scalar__": _binop_scalar(_op.add),
- "__sub_scalar__": _binop_scalar(_op.subtract),
- "_minus_scalar" : _binop_scalar(_op.subtract),
- "__mul_scalar__": _binop_scalar(_op.multiply),
- "_mul_scalar" : _binop_scalar(_op.multiply),
- "__div_scalar__": _binop_scalar(_op.divide),
- "_div_scalar" : _binop_scalar(_op.divide),
- "__pow_scalar__": _binop_scalar(_op.power),
- "_rminus_scalar": _rbinop_scalar(_op.subtract),
- "__rsub_scalar__": _rbinop_scalar(_op.subtract),
- "_rdiv_scalar" : _rbinop_scalar(_op.divide),
- "__rdiv_scalar__" : _rbinop_scalar(_op.divide),
- "__rpow_scalar__": _rbinop_scalar(_op.power),
+ "_copy" : _rename(_op.copy),
+ "relu" : _rename(_op.nn.relu),
+ "broadcast_add" : _rename(_op.add),
+ "broadcast_sub" : _rename(_op.subtract),
+ "broadcast_mul" : _rename(_op.multiply),
+ "broadcast_div" : _rename(_op.divide),
+ "broadcast_mod" : _rename(_op.mod),
+ "broadcast_maximum" : _rename(_op.maximum),
+ "broadcast_minimum" : _rename(_op.minimum),
+ "broadcast_equal" : _mx_compare(_op.equal, _rename),
+ "broadcast_not_equal" : _mx_compare(_op.not_equal, _rename),
+ "broadcast_greater" : _mx_compare(_op.greater, _rename),
+ "broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename),
+ "broadcast_lesser" : _mx_compare(_op.less, _rename),
+ "broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename),
+ "elemwise_add" : _rename(_op.add),
+ "elemwise_sub" : _rename(_op.subtract),
+ "elemwise_mul" : _rename(_op.multiply),
+ "elemwise_div" : _rename(_op.divide),
+ "_maximum" : _rename(_op.maximum),
+ "_minimum" : _rename(_op.minimum),
+ "flatten" : _rename(_op.nn.batch_flatten),
+ "Flatten" : _rename(_op.nn.batch_flatten),
+ "__add_scalar__" : _binop_scalar(_op.add),
+ "_plus_scalar" : _binop_scalar(_op.add),
+ "__sub_scalar__" : _binop_scalar(_op.subtract),
+ "_minus_scalar" : _binop_scalar(_op.subtract),
+ "__mul_scalar__" : _binop_scalar(_op.multiply),
+ "_mul_scalar" : _binop_scalar(_op.multiply),
+ "__div_scalar__" : _binop_scalar(_op.divide),
+ "_div_scalar" : _binop_scalar(_op.divide),
+ "__pow_scalar__" : _binop_scalar(_op.power),
+ "_power_scalar" : _binop_scalar(_op.power),
+ "__rsub_scalar__" : _rbinop_scalar(_op.subtract),
+ "_rminus_scalar" : _rbinop_scalar(_op.subtract),
+ "__rdiv_scalar__" : _rbinop_scalar(_op.divide),
+ "_rdiv_scalar" : _rbinop_scalar(_op.divide),
+ "__rpow_scalar__" : _rbinop_scalar(_op.power),
+ "_equal_scalar" : _mx_compare(_op.equal, _binop_scalar),
+ "_not_equal_scalar" : _mx_compare(_op.not_equal, _binop_scalar),
+ "_greater_scalar" : _mx_compare(_op.greater, _binop_scalar),
+ "_greater_equal_scalar" : _mx_compare(_op.greater_equal, _binop_scalar),
+ "_lesser_scalar" : _mx_compare(_op.less, _binop_scalar),
+ "_lesser_equal_scalar" : _mx_compare(_op.less_equal, _binop_scalar),
+ "_maximum_scalar" : _binop_scalar(_op.maximum),
+ "_minimum_scalar" : _binop_scalar(_op.minimum),
    # reduction ops
    "max" : _reduce(_op.max),
    "min" : _reduce(_op.min),
import numpy as np
+import operator
import tvm
from tvm.contrib import graph_runtime
    verify(20, 1, -1)
    verify(20, 1, -1.5)
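+# Look an op up by name on an MXNet namespace (mx.sym or mx.nd) and apply
+# it, so each test builds the symbol and its reference result the same way.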
+def _mx_symbol(F, op_name, inputs):
+    op = getattr(F, op_name)
+    return op(*inputs)
+
+def test_forward_broadcast_ops():
+    for op in ["broadcast_add", "broadcast_sub", "broadcast_mul",
+               "broadcast_div", "broadcast_mod", "broadcast_maximum",
+               "broadcast_minimum", "broadcast_equal", "broadcast_not_equal",
+               "broadcast_greater", "broadcast_greater_equal",
+               "broadcast_lesser", "broadcast_lesser_equal"]:
+        a_shape = (3, 4, 5)
+        b_shape = (4, 5)
+        if op == "broadcast_mod":
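+            # integers for mod, presumably to avoid floating-point
+            # precision mismatches between MXNet and TVM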
+            dtype = 'int32'
+            a_np = np.random.randint(1, 100, size=a_shape).astype(dtype)
+            b_np = np.random.randint(1, 100, size=b_shape).astype(dtype)
+        else:
+            dtype = 'float32'
+            a_np = np.random.uniform(size=a_shape).astype(dtype)
+            b_np = np.random.uniform(size=b_shape).astype(dtype)
+        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
+        ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
+        shapes = {'a': a_shape, 'b': b_shape}
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(a_np, b_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
+def test_forward_elemwise_ops():
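+    # unlike the broadcast_* variants, elemwise_* ops expect
+    # identically shaped operands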
+ for op in ["elemwise_add", "elemwise_sub", "elemwise_mul",
+ "elemwise_div", "maximum", "minimum"]:
+ shape = (3, 4, 5)
+ dtype = 'float32'
+ a_np = np.random.uniform(size=shape).astype(dtype)
+ b_np = np.random.uniform(size=shape).astype(dtype)
+ mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
+ ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
+ shapes = {'a': shape, 'b': shape}
+ new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+ for target, ctx in ctx_list():
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(kind, ctx=ctx, target=target)
+ op_res = intrp.evaluate(new_sym)(a_np, b_np)
+ tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
+def test_forward_scalar_ops():
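+    # Python operators applied to a Symbol dispatch to MXNet's *_scalar
+    # ops (e.g. a + 2.3 becomes _plus_scalar), exercising the scalar
+    # converters registered above.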
+    for op in [operator.add, operator.sub, operator.mul, operator.truediv,
+               operator.pow, operator.lt, operator.le, operator.eq,
+               operator.ne, operator.gt, operator.ge]:
+        dtype = 'float32'
+        a_shape = (3, 4, 5)
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        b_scalar = 2.3
+        mx_sym = op(mx.sym.var('a'), b_scalar)
+        ref_res = op(mx.nd.array(a_np), b_scalar)
+        shapes = {'a': a_shape}
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(a_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
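+    # maximum/minimum have no Python operator form; with a scalar operand
+    # they lower to _maximum_scalar/_minimum_scalar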
+ for op in ["maximum", "minimum"]:
+ dtype='float32'
+ a_shape = (3, 4, 5)
+ a_np = np.random.uniform(size=a_shape).astype(dtype)
+ b_scalar = 2.3
+ mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), b_scalar])
+ ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), b_scalar])
+ shapes = {'a': a_shape}
+ new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+ for target, ctx in ctx_list():
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(kind, ctx=ctx, target=target)
+ op_res = intrp.evaluate(new_sym)(a_np)
+ tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
if __name__ == '__main__':
    test_forward_mlp()
    test_forward_argmin()
    test_forward_where()
    test_forward_arange()
+    test_forward_broadcast_ops()
+    test_forward_elemwise_ops()
+    test_forward_scalar_ops()