[Frontend][MXNet] Support a few contrib ops in mxnet (#5819)
authorHaichen Shen <shenhaichen@gmail.com>
Wed, 17 Jun 2020 20:15:14 +0000 (13:15 -0700)
committerGitHub <noreply@github.com>
Wed, 17 Jun 2020 20:15:14 +0000 (13:15 -0700)
* support for bert in mxnet1.6 and gluonnlp0.9

* fix converter

* Add test cases

* add a todo

python/tvm/relay/frontend/mxnet.py
python/tvm/relay/frontend/nnvm_common.py
tests/python/frontend/mxnet/test_forward.py

index f77c3b5..2454a55 100644 (file)
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name, import-self, len-as-condition, no-else-return, too-many-lines
 """MXNet symbol frontend."""
 import json
+import math
 import numpy as np
 import tvm
 from tvm.ir import IRModule
@@ -655,6 +656,15 @@ def _mx_leaky_relu(inputs, attrs):
         upper_bound = attrs.get_float("upper_bound")
         alpha = (lower_bound + upper_bound) / 2.0
         return _op.nn.leaky_relu(inputs[0], alpha=alpha)
+    if act_type == "gelu":
+        # 0.5 * x * (1 + erf(x / sqrt(2)))
+        sqrt2 = _expr.const(math.sqrt(2), dtype="float32")
+        erf = _op.erf(_op.divide(inputs[0], sqrt2))
+        one = _expr.const(1, dtype="float32")
+        erf_plus_one = _op.add(one, erf)
+        half = _expr.const(0.5, dtype="float32")
+        half_x = _op.multiply(inputs[0], half)
+        return _op.multiply(half_x, erf_plus_one)
     raise tvm.error.OpNotImplemented(
         'Operator {} is not supported for frontend MXNet.'.format(act_type))
 
@@ -784,6 +794,42 @@ def _mx_make_loss(inputs, attrs):
     return inputs[0]
 
 
+def _mx_contrib_arange_like(inputs, attrs):
+    assert len(inputs) == 1
+    if attrs.get_int("repeat", 1) != 1:
+        raise tvm.error.OpAttributeUnimplemented(
+            'Attribute "repeat" is not supported in operator arange_like.')
+    ty = _infer_type(inputs[0]).checked_type
+    assert ty
+    shape, dtype = get_const_tuple(ty.shape), ty.dtype
+    axis = attrs.get_int("axis", None)
+    if axis is None:
+        n_elems = 1
+        for dim in shape:
+            if not isinstance(dim, int):
+                raise tvm.error.OpError("Don't support arange_like with symbolic input shape.")
+            n_elems *= dim
+    else:
+        axis = axis + len(shape) if axis < 0 else axis
+        assert 0 <= axis < len(shape)
+        n_elems = shape[axis]
+        if not isinstance(n_elems, int):
+            raise tvm.error.OpError("Don't support arange_like with symbolic input shape.")
+        shape = (n_elems,)
+    start = attrs.get_float("start", 0.)
+    step = attrs.get_float("step", 1.)
+    stop = start + step * n_elems
+    new_attrs = {}
+    new_attrs["start"] = _expr.const(start, dtype=dtype)
+    new_attrs["stop"] = _expr.const(stop, dtype=dtype)
+    new_attrs["step"] = _expr.const(step, dtype=dtype)
+    new_attrs["dtype"] = dtype
+    ret = _op.arange(**new_attrs)
+    if len(shape) > 1:
+        ret = _op.reshape(ret, shape)
+    return ret
+
+
 def _mx_repeat(inputs, attrs):
     assert len(inputs) == 1
     new_attrs = {}
@@ -1278,6 +1324,56 @@ def _mx_contrib_fifo_buffer(inputs, attrs):
     return _op.nn.fifo_buffer(*inputs, **new_attrs)
 
 
+def _mx_contrib_interleaved_matmul_selfatt_qk(inputs, attrs):
+    """
+    tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
+    q_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3))
+    q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True)
+    q_proj = mx.nd.contrib.div_sqrt_dim(q_proj)
+    k_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3))
+    k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
+    output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)
+    """
+    assert len(inputs) == 1
+    qkv = inputs[0]
+    num_heads = attrs.get_int('heads')
+    qkv = _op.reshape(qkv, newshape=(0, 0, num_heads, 3, -1))
+    q_proj = _op.take(qkv, _expr.const(0, "int32"), axis=3)
+    q_proj = _op.transpose(q_proj, axes=[1, 2, 0, 3])
+    q_proj = _op.reverse_reshape(q_proj, newshape=(-1, 0, 0))
+    q_proj = _mx_contrib_div_sqrt_dim([q_proj], None)
+    k_proj = _op.take(qkv, _expr.const(1, "int32"), axis=3)
+    k_proj = _op.transpose(k_proj, axes=[1, 2, 0, 3])
+    k_proj = _op.reverse_reshape(k_proj, newshape=(-1, 0, 0))
+    ret = _op.nn.batch_matmul(q_proj, k_proj)
+    return ret
+
+
+def _mx_contrib_interleaved_matmul_selfatt_valatt(inputs, attrs):
+    """
+    tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
+    v_proj = mx.nd.transpose(tmp[:,:,:,2,:], axes=(1, 2, 0, 3))
+    v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
+    output = mx.nd.batch_dot(attention, v_proj)
+    output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
+    output = mx.nd.transpose(output, axes=(2, 0, 1, 3))
+    output = mx.nd.reshape(output, shape=(0, 0, -1))
+    """
+    assert len(inputs) == 2
+    qkv, att = inputs
+    num_heads = attrs.get_int("heads")
+    qkv = _op.reshape(qkv, newshape=(0, 0, num_heads, 3, -1))
+    v_proj = _op.take(qkv, _expr.const(2, "int32"), axis=3)
+    v_proj = _op.transpose(v_proj, axes=(1, 2, 0, 3))
+    v_proj = _op.reverse_reshape(v_proj, newshape=(-1, 0, 0))
+    v_proj = _op.transpose(v_proj, axes=[0, 2, 1])
+    out = _op.nn.batch_matmul(att, v_proj)
+    out = _op.reverse_reshape(out, newshape=(-1, num_heads, 0, 0))
+    out = _op.transpose(out, axes=(2, 0, 1, 3))
+    out = _op.reshape(out, newshape=(0, 0, -1))
+    return out
+
+
 def _mx_cond(inputs, attrs, subgraphs):
     assert len(subgraphs) == 3
     cond_input_locs = json.loads(attrs.get_str("cond_input_locs"))
@@ -2110,6 +2206,7 @@ _convert_map = {
     "smooth_l1"     : _mx_smooth_l1,
     "make_loss"     : _mx_make_loss,
     "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
+    "_contrib_arange_like": _mx_contrib_arange_like,
     "one_hot"           : _mx_one_hot,
     "depth_to_space"    : _mx_depth_to_space,
     "space_to_depth"    : _mx_space_to_depth,
@@ -2130,6 +2227,8 @@ _convert_map = {
     # NLP
     "RNN"               : _mx_rnn_layer,
     "_rnn_param_concat" : _mx_rnn_param_concat,
+    "_contrib_interleaved_matmul_selfatt_qk" : _mx_contrib_interleaved_matmul_selfatt_qk,
+    "_contrib_interleaved_matmul_selfatt_valatt" : _mx_contrib_interleaved_matmul_selfatt_valatt,
     # control flow
     "_cond"             : _mx_cond,
     # Depricated:
index 072c7ad..a2eea94 100644 (file)
@@ -57,7 +57,8 @@ def _init_op(new_op):
 def _softmax_op(new_op):
     """softmax/log_softmax"""
     def _impl(inputs, attrs, _dtype='float32'):
-        assert len(inputs) == 1
+        # TODO(@icemelon9): currently ignore the 2nd input to softmax for mxnet 1.6
+        # assert len(inputs) == 1
         axis = attrs.get_int("axis", -1)
         return new_op(inputs[0], axis=axis)
     return _impl
index 8b3e04b..ae5ed45 100644 (file)
@@ -133,6 +133,12 @@ def test_forward_prelu():
     mx_sym = mx.sym.LeakyReLU(data, act_type='prelu')
     verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
 
+def test_forward_gelu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    mx_sym = mx.sym.LeakyReLU(data, act_type='gelu')
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
 def test_forward_softrelu():
     data = mx.sym.var('data')
     data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
@@ -1228,6 +1234,78 @@ def test_forward_correlation():
     verify((5, 1, 11, 11), kernel_size = 5, max_displacement = 1, stride1 = 1, stride2 = 1, pad_size = 2, is_multiply = False)
 
 
+def test_forward_arange_like():
+    def verify(data_shape, start=None, step=None, axis=None):
+        attrs = {}
+        if start is not None:
+            attrs['start'] = start
+        if step is not None:
+            attrs['step'] = step
+        if axis is not None:
+            attrs['axis'] = axis
+        data = mx.sym.var('data')
+        data_np = np.random.uniform(size=data_shape).astype("float32")
+        ref_res = mx.nd.contrib.arange_like(mx.nd.array(data_np), **attrs)
+        
+        mx_sym = mx.sym.contrib.arange_like(data, **attrs)
+        mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()()
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
+    verify(data_shape=(3,), start=0., step=1.)
+    verify(data_shape=(3, 4, 5), start=0., step=1.)
+    verify(data_shape=(3, 4, 5), start=0., step=1., axis=-1)
+    verify(data_shape=(3, 4, 5), start=2., step=3., axis=1)
+
+
+def test_forward_interleaved_matmul_selfatt_qk():
+    def verify(batch, seq_length, num_heads, head_dim):
+        data_shape = (seq_length, batch, num_heads * head_dim * 3)
+        data = mx.sym.var('data')
+        data_np = np.random.uniform(size=data_shape).astype('float32')
+        ref_res = mx.nd.contrib.interleaved_matmul_selfatt_qk(
+            mx.nd.array(data_np), heads=num_heads)
+
+        mx_sym = mx.sym.contrib.interleaved_matmul_selfatt_qk(data, heads=num_heads)
+        mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(data_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+
+    verify(1, 10, 3, 16)
+    verify(3, 10, 6, 8)
+
+
+def test_forward_interleaved_matmul_selfatt_valatt():
+    def verify(batch, seq_length, num_heads, head_dim):
+        data_shape = (seq_length, batch, num_heads * head_dim * 3)
+        weight_shape = (batch * num_heads, seq_length, seq_length)
+        data = mx.sym.var('data')
+        weight = mx.sym.var('weight')
+        data_np = np.random.uniform(size=data_shape).astype('float32')
+        weight_np = np.random.uniform(size=weight_shape).astype('float32')
+        ref_res = mx.nd.contrib.interleaved_matmul_selfatt_valatt(
+            mx.nd.array(data_np), mx.nd.array(weight_np), heads=num_heads)
+
+        mx_sym = mx.sym.contrib.interleaved_matmul_selfatt_valatt(
+            data, weight, heads=num_heads)
+        mod, _ = relay.frontend.from_mxnet(
+            mx_sym, {"data": data_shape, "weight": weight_shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(data=data_np, weight=weight_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+
+    verify(1, 10, 4, 16)
+    verify(3, 10, 6, 8)
+
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -1236,6 +1314,7 @@ if __name__ == '__main__':
     test_forward_elu()
     test_forward_rrelu()
     test_forward_prelu()
+    test_forward_gelu()
     test_forward_softrelu()
     test_forward_softmin()
     test_forward_fc_flatten()
@@ -1297,3 +1376,6 @@ if __name__ == '__main__':
     test_forward_correlation()
     test_forward_grid_generator()
     test_forward_bilinear_sampler()
+    test_forward_arange_like()
+    test_forward_interleaved_matmul_selfatt_qk()
+    test_forward_interleaved_matmul_selfatt_valatt()