[Frontend][MXNet] Fix mxnet converter for hybridblock and add div_sqrt_dim (#3701)
authorHaichen Shen <shenhaichen@gmail.com>
Wed, 7 Aug 2019 04:27:06 +0000 (21:27 -0700)
committerJared Roesch <roeschinc@gmail.com>
Wed, 7 Aug 2019 04:27:06 +0000 (21:27 -0700)
* Fix mxnet converter for hybrid block

* tweak

* fix rebase

* fix

* add test

python/tvm/relay/frontend/mxnet.py
tests/python/frontend/mxnet/test_forward.py
topi/python/topi/cuda/batch_matmul.py

index 8d13d6c..3486263 100644 (file)
@@ -715,7 +715,7 @@ def _mx_topk(inputs, attrs):
     return _op.topk(inputs[0], **new_attrs)
 
 
-def _mx_SequenceMask(inputs, attrs):
+def _mx_sequence_mask(inputs, attrs):
     assert len(inputs) == 1 or len(inputs) == 2
     new_attrs = {}
     use_sequence_length = attrs.get_bool('use_sequence_length', False)
@@ -727,6 +727,15 @@ def _mx_SequenceMask(inputs, attrs):
         return inputs[0]
 
 
+def _mx_contrib_div_sqrt_dim(inputs, _):
+    assert len(inputs) == 1
+    ndim = len(_infer_type(inputs[0]).checked_type.shape)
+    dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim-1, dtype="int32"))
+    sqrt_dim = _op.sqrt(dim.astype('float32'))
+    out = inputs[0] / sqrt_dim
+    return out
+
+
 def _mx_rnn_param_concat(inputs, _):
     # We don't need to concatenate RNN params because we will unravel the RNN op
     return [inputs]
@@ -1014,11 +1023,12 @@ _convert_map = {
     "Embedding"     : _mx_embedding,
     "argsort"       : _mx_argsort,
     "topk"          : _mx_topk,
-    "SequenceMask"  : _mx_SequenceMask,
+    "SequenceMask"  : _mx_sequence_mask,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
     "LinearRegressionOutput" : _mx_linear_regression_output,
     "smooth_l1"     : _mx_smooth_l1,
+    "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
     # vision
     "_contrib_BilinearResize2D" : _mx_resize,
     "_contrib_MultiBoxPrior" : _mx_multibox_prior,
@@ -1183,8 +1193,10 @@ def from_mxnet(symbol,
         params = {}
         for k, v in symbol.collect_params().items():
             params[k] = _nd.array(v.data().asnumpy())
-        data = mx.sym.Variable("data")
-        sym = symbol(data)
+        inputs = []
+        for name in shape:
+            inputs.append(mx.sym.Variable(name))
+        sym = symbol(*inputs)
         if isinstance(sym, (list, tuple)):
             sym = mx.sym.Group(sym)
         shape, dtype = _update_shape_dtype(shape, dtype, params)
index 09ae02b..451679c 100644 (file)
@@ -714,6 +714,19 @@ def test_forward_sequence_mask():
     verify((5, 4, 3), False, 1.0, 1, 'float64', 'float64')
     verify((5, 4, 3, 2), True, 1.0, 0, 'float32', 'float32')
 
+def test_forward_contrib_div_sqrt_dim():
+    def verify(shape):
+        x_np = np.random.uniform(size=shape).astype("float32")
+        ref_res = mx.nd.contrib.div_sqrt_dim(mx.nd.array(x_np))
+        mx_sym = mx.sym.contrib.div_sqrt_dim(mx.sym.var("x"))
+        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(x_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((3, 4))
+    verify((3, 4, 5))
 
 if __name__ == '__main__':
     test_forward_mlp()
@@ -759,3 +772,4 @@ if __name__ == '__main__':
     test_forward_argsort()
     test_forward_topk()
     test_forward_sequence_mask()
+    test_forward_contrib_div_sqrt_dim()
index c973f1d..b5dd802 100644 (file)
@@ -38,6 +38,7 @@ def schedule_batch_matmul(outs):
     s: Schedule
         The computation schedule for the op.
     """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
 
     def _schedule(op):
@@ -49,6 +50,9 @@ def schedule_batch_matmul(outs):
         BB = s.cache_read(B, "shared", [C])
         BL = s.cache_read(BB, "local", [C])
         CC = s.cache_write(C, "local")
+        if op not in s.outputs:
+            s[C].compute_inline()
+            C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
         y_bn = get_max_power2_factor(M, 64)