return _op.topk(inputs[0], **new_attrs)
-def _mx_SequenceMask(inputs, attrs):
+def _mx_sequence_mask(inputs, attrs):
    assert len(inputs) == 1 or len(inputs) == 2
    new_attrs = {}
    use_sequence_length = attrs.get_bool('use_sequence_length', False)
    new_attrs['mask_value'] = attrs.get_float('value', 0.0)
    new_attrs['axis'] = attrs.get_int('axis', 0)
    if use_sequence_length:
        return _op.sequence_mask(*inputs, **new_attrs)
    return inputs[0]
+def _mx_contrib_div_sqrt_dim(inputs, _):
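+    # out = data / sqrt(data.shape[-1]): rescale by the square root of the last-axis size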
+ assert len(inputs) == 1
+ ndim = len(_infer_type(inputs[0]).checked_type.shape)
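+    # pick the size of the last axis from the runtime shape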
+ dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim-1, dtype="int32"))
+ sqrt_dim = _op.sqrt(dim.astype('float32'))
+ out = inputs[0] / sqrt_dim
+ return out
+
+
def _mx_rnn_param_concat(inputs, _):
# We don't need to concatenate RNN params because we will unravel the RNN op
return [inputs]
"Embedding" : _mx_embedding,
"argsort" : _mx_argsort,
"topk" : _mx_topk,
- "SequenceMask" : _mx_SequenceMask,
+ "SequenceMask" : _mx_sequence_mask,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
"LinearRegressionOutput" : _mx_linear_regression_output,
"smooth_l1" : _mx_smooth_l1,
+    "_contrib_div_sqrt_dim" : _mx_contrib_div_sqrt_dim,
# vision
"_contrib_BilinearResize2D" : _mx_resize,
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
params = {}
for k, v in symbol.collect_params().items():
params[k] = _nd.array(v.data().asnumpy())
- data = mx.sym.Variable("data")
- sym = symbol(data)
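+    # create one input variable per entry in the shape dict instead of
+    # assuming a single input named "data"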
+ inputs = []
+ for name in shape:
+ inputs.append(mx.sym.Variable(name))
+ sym = symbol(*inputs)
if isinstance(sym, (list, tuple)):
sym = mx.sym.Group(sym)
shape, dtype = _update_shape_dtype(shape, dtype, params)
verify((5, 4, 3), False, 1.0, 1, 'float64', 'float64')
verify((5, 4, 3, 2), True, 1.0, 0, 'float32', 'float32')
+def test_forward_contrib_div_sqrt_dim():
+ def verify(shape):
+ x_np = np.random.uniform(size=shape).astype("float32")
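+        # reference result from MXNet's own implementation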
+ ref_res = mx.nd.contrib.div_sqrt_dim(mx.nd.array(x_np))
+ mx_sym = mx.sym.contrib.div_sqrt_dim(mx.sym.var("x"))
+ mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
+ for target, ctx in ctx_list():
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+ op_res = intrp.evaluate()(x_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+ verify((3, 4))
+ verify((3, 4, 5))
if __name__ == '__main__':
test_forward_mlp()
test_forward_argsort()
test_forward_topk()
test_forward_sequence_mask()
+ test_forward_contrib_div_sqrt_dim()
s: Schedule
The computation schedule for the op.
"""
+ outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(op):
BB = s.cache_read(B, "shared", [C])
BL = s.cache_read(BB, "local", [C])
CC = s.cache_write(C, "local")
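+        # if this result is fused into a following elementwise op (i.e. it is
+        # not a final output), inline it and schedule the final output instead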
+ if op not in s.outputs:
+ s[C].compute_inline()
+ C = s.outputs[0].output(0)
b, y, x = s[C].op.axis
y_bn = get_max_power2_factor(M, 64)