scale=scale)
return _impl
+def _get_dims(data):
+ import torch
+ if isinstance(data, _expr.Expr):
+ dims = _infer_shape(data)
+ elif isinstance(data, list):
+ dims = data
+ elif isinstance(data, (torch.Tensor, np.ndarray)):
+ dims = data.shape
+ else:
+ msg = "Data type %s could not be parsed" % type(data)
+ raise AssertionError(msg)
+ return dims
+
+def _layer_norm():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ ndims = len(_get_dims(inputs[1]))
+ assert ndims == 1, "Support only normalization over last one dimension."
+
+ return _op.nn.layer_norm(data,
+ gamma=inputs[1],
+ beta=inputs[2],
+ axis=-1,
+ epsilon=float(inputs[4]),
+ center=False,
+ scale=False)
+ return _impl
+
def _transpose():
def _impl(inputs, input_types):
data = inputs[0]
"aten::contiguous" : _contiguous(),
"aten::batch_norm" : _batch_norm(),
"aten::instance_norm" : _instance_norm(),
+ "aten::layer_norm" : _layer_norm(),
"aten::transpose" : _transpose(),
"aten::transpose_" : _transpose(),
"aten::t" : _transpose(),
(torch.nn.InstanceNorm3d(16), inp_3d)]:
verify_model(ins_norm.eval(), input_data=inp)
+def test_forward_layernorm():
+ inp = torch.rand((20, 5, 10, 10))
+ verify_model(torch.nn.LayerNorm(10).eval(), input_data=inp)
def test_forward_transpose():
torch.set_grad_enabled(False)
test_forward_contiguous()
test_forward_batchnorm()
test_forward_instancenorm()
+ test_forward_layernorm()
test_forward_transpose()
test_forward_size()
test_forward_view()