[PYTORCH]LayerNorm support added (#5249)
author    Samuel <siju.samuel@huawei.com>
Mon, 6 Apr 2020 20:31:19 +0000 (02:01 +0530)
committer GitHub <noreply@github.com>
Mon, 6 Apr 2020 20:31:19 +0000 (05:31 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 977a899..708e025 100644
@@ -503,6 +503,34 @@ def _instance_norm():
                                     scale=scale)
     return _impl
 
+def _get_dims(data):
+    import torch
+    if isinstance(data, _expr.Expr):
+        dims = _infer_shape(data)
+    elif isinstance(data, list):
+        dims = data
+    elif isinstance(data, (torch.Tensor, np.ndarray)):
+        dims = data.shape
+    else:
+        msg = "Data type %s could not be parsed" % type(data)
+        raise AssertionError(msg)
+    return dims
+
+def _layer_norm():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        ndims = len(_get_dims(inputs[1]))
+        assert ndims == 1, "Only normalization over the last dimension is supported."
+
+        return _op.nn.layer_norm(data,
+                                 gamma=inputs[2],
+                                 beta=inputs[3],
+                                 axis=-1,
+                                 epsilon=float(inputs[4]),
+                                 center=True,
+                                 scale=True)
+    return _impl
+
 def _transpose():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1050,6 +1078,7 @@ _convert_map = {
     "aten::contiguous"                      : _contiguous(),
     "aten::batch_norm"                      : _batch_norm(),
     "aten::instance_norm"                   : _instance_norm(),
+    "aten::layer_norm"                      : _layer_norm(),
     "aten::transpose"                       : _transpose(),
     "aten::transpose_"                      : _transpose(),
     "aten::t"                               : _transpose(),
index e7c2e08..fa32dca 100644
@@ -561,6 +561,9 @@ def test_forward_instancenorm():
                           (torch.nn.InstanceNorm3d(16), inp_3d)]:
         verify_model(ins_norm.eval(), input_data=inp)
 
+def test_forward_layernorm():
+    inp = torch.rand((20, 5, 10, 10))
+    verify_model(torch.nn.LayerNorm(10).eval(), input_data=inp)
 
 def test_forward_transpose():
     torch.set_grad_enabled(False)
@@ -1132,6 +1135,7 @@ if __name__ == "__main__":
     test_forward_contiguous()
     test_forward_batchnorm()
     test_forward_instancenorm()
+    test_forward_layernorm()
     test_forward_transpose()
     test_forward_size()
     test_forward_view()
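
For reference, the semantics that the converted _op.nn.layer_norm(axis=-1, center=True, scale=True) is expected to reproduce for the test above, written out with NumPy over the last axis. This is an illustrative sketch, not part of the test suite:

    import numpy as np
    import torch

    ln = torch.nn.LayerNorm(10).eval()
    x = torch.rand((20, 5, 10, 10))

    # Normalize over the last dimension with biased variance, then apply the
    # affine weight (gamma) and bias (beta), as PyTorch's LayerNorm does.
    xn = x.numpy()
    mean = xn.mean(axis=-1, keepdims=True)
    var = xn.var(axis=-1, keepdims=True)
    ref = (xn - mean) / np.sqrt(var + ln.eps)
    ref = ref * ln.weight.detach().numpy() + ln.bias.detach().numpy()

    np.testing.assert_allclose(ln(x).detach().numpy(), ref, rtol=1e-5, atol=1e-5)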