OpInfo for `nn.functional.layer_norm` (#63276)
author Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
Wed, 1 Sep 2021 15:48:25 +0000 (08:48 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 16:31:45 +0000 (09:31 -0700)
Summary:
Please see https://github.com/facebookresearch/functorch/issues/78 and https://github.com/pytorch/pytorch/issues/54261.

Note:

* This PR also adds a NumPy reference implementation (`reference_layer_norm`) used by the OpInfo reference test. It is inspired by the existing `test_LayerNorm_numeric` test in `test_nn.py`, which this PR removes in favor of the OpInfo-based coverage. A minimal sketch of the check appears below.
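As a rough, self-contained sketch (not part of the diff; shapes and variable names below are made up for illustration), this is what the OpInfo reference test ends up checking: the NumPy reference must agree with the ATen kernel within the overridden tolerances.

```python
import numpy as np
import torch
import torch.nn.functional as F

# Shapes here are illustrative only; any case from sample_inputs_layer_norm would do.
normalized_shape = (2, 3)
x = torch.randn(2, 2, 3, dtype=torch.float32)
weight = torch.randn(normalized_shape)
bias = torch.randn(normalized_shape)
eps = 1e-5

y = F.layer_norm(x, normalized_shape, weight, bias, eps)

# NumPy reference, mirroring reference_layer_norm from this PR: normalize over the
# flattened normalized dimensions, then apply the elementwise affine transform.
feature_size = int(np.prod(normalized_shape))
x_view = x.numpy().reshape(-1, feature_size)
mean = x_view.mean(axis=-1, keepdims=True)
var = x_view.var(axis=-1, ddof=0, keepdims=True)
y_ref = (x_view - mean) / np.sqrt(var + eps)
y_ref = y_ref * weight.numpy().reshape(-1) + bias.numpy().reshape(-1)
y_ref = y_ref.reshape(*x.shape)

# Same tolerances as the float32 DecorateInfo override added below.
torch.testing.assert_close(y, torch.from_numpy(y_ref).to(y.dtype), rtol=1e-3, atol=1e-5)
```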

cc: mruberry zou3519

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63276

Reviewed By: ejguan

Differential Revision: D30452483

Pulled By: zou3519

fbshipit-source-id: 2578d01ca34e031668a41bd284db60c31ae1fba8
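For orientation before the diff: the samples produced by `sample_inputs_layer_norm` pass `normalized_shape`, `weight`, and `bias` positionally to the functional form, with `eps` as a keyword argument. A minimal sketch of one such call (values are illustrative, not taken verbatim from the tests):

```python
import torch
import torch.nn.functional as F

# Mirrors one case from sample_inputs_layer_norm below: weight and bias take the
# same shape as normalized_shape, and eps is passed as a kwarg.
input_shape, normalized_shape = (2, 2, 3), (2, 3)
x = torch.randn(input_shape, requires_grad=True)
weight = torch.randn(normalized_shape, requires_grad=True)
bias = torch.randn(normalized_shape, requires_grad=True)

out = F.layer_norm(x, normalized_shape, weight, bias, eps=0.5)
out.sum().backward()  # the same samples also feed the autograd/gradcheck OpInfo tests
assert out.shape == x.shape
```

These samples are consumed by the generic OpInfo suites, including the `TestCommon.test_reference_testing` test referenced by the `DecorateInfo` tolerance override below.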

test/test_nn.py
torch/testing/_internal/common_methods_invocations.py

diff --git a/test/test_nn.py b/test/test_nn.py
index 5008c72..e60ff69 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -13282,32 +13282,6 @@ class TestNNDeviceType(NNTestCase):
             self._test_LayerNorm_cuda_half(device)
 
     @onlyOnCPUAndCUDA
-    def test_LayerNorm_numeric(self, device):
-        def layer_norm_ref(X, gamma, beta, normalized_shape, eps):
-            feature_size = np.prod(normalized_shape)
-            X_view = X.view(-1, feature_size)
-            mean = X_view.mean(dim=-1, keepdim=True)
-            var = X_view.var(dim=-1, unbiased=False, keepdim=True)
-            Y = (X_view - mean) / torch.sqrt(var + eps)
-            Y = Y * gamma.view(-1) + beta.view(-1)
-            return Y.view(*X.size())
-
-        normalized_shape = [256, 256, 144]
-        layer_norm = nn.LayerNorm(normalized_shape).float().to(device)
-        X = torch.rand(2, *normalized_shape, dtype=torch.float32,
-                       device=device)
-
-        Y = layer_norm(X)
-        Y_ref = layer_norm_ref(X, layer_norm.weight.data, layer_norm.bias.data,
-                               normalized_shape, layer_norm.eps)
-        self.assertEqual(Y, Y_ref, rtol=0, atol=1e-5)
-
-        if self.device_type == 'cuda':
-            layer_norm.cpu()
-            Y_cpu = layer_norm(X.cpu())
-            self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5)
-
-    @onlyOnCPUAndCUDA
     def test_GroupNorm_general(self, device):
         self._test_GroupNorm_general(device)
 
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 3579310..fe8e36f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2548,6 +2548,42 @@ def sample_inputs_conv_transpose2d(op_info, device, dtype, requires_grad, **kwar
 
     return list(generator())
 
+def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as input shape, normalized_shape and a kwarg dict for eps
+    cases: Tuple[Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
+        ((1, 2, 3), (1, 2, 3), {'eps': 0.5}),
+        ((2, 2, 3), (2, 3), {'eps': -0.5}),
+        ((1,), (1,), {}),
+        ((1, 2), (2,), {}),
+        ((0, 1), (1,), {}),
+    )
+
+    def generator():
+        for input_shape, normalized_shape, kwargs in cases:
+            # Shape of weight and bias should be the same as normalized_shape
+            weight = make_arg(normalized_shape)
+            bias = make_arg(normalized_shape)
+            yield SampleInput(
+                make_arg(input_shape),
+                args=(normalized_shape, weight, bias),
+                kwargs=kwargs
+            )
+        # Without any optional args
+        yield SampleInput(make_arg((1, 2)), args=((2,),))
+
+        # TODO: @krshrimali, once to_numpy method in SampleInput class is modified to take None inputs,
+        # enable these inputs; see https://github.com/pytorch/pytorch/pull/63276#discussion_r691950400
+
+        # With weight and a `None` bias
+        # yield SampleInput(make_arg((1, 2)), args=((2,), make_arg((2,)), None))
+
+        # With `None` weight and bias (tests failing for this, see the link above)
+        # yield SampleInput(make_arg((1, 2)), args=((2,), None, make_arg((2,))))
+
+    return list(generator())
+
 def sample_inputs_hardswish(self, device, dtype, requires_grad):
     N = 5
     # make sure we are testing -3 -> 3 range. default is -10 -> 10 so maybe unnecessary ?
@@ -5595,6 +5631,21 @@ def reference_mse_loss(input, target, reduction="mean"):
         return se
 
 
+def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5):
+    feature_size = np.prod(normalized_shape)
+    inp_view = inp.reshape(-1, feature_size)  # type: ignore[call-overload]
+    mean = inp_view.mean(axis=-1, keepdims=True)
+    var = inp_view.var(axis=-1, ddof=0, keepdims=True)
+    Y = (inp_view - mean) / np.sqrt(var + eps)
+    if weight is None and bias is not None:
+        Y = Y + bias.reshape(-1)
+    elif weight is not None and bias is None:
+        Y = Y * weight.reshape(-1)
+    elif weight is not None and bias is not None:
+        Y = Y * weight.reshape(-1) + bias.reshape(-1)
+    return Y.reshape(*inp.shape)
+
+
 def gradcheck_wrapper_hermitian_input(op, input, *args, **kwargs):
     """Gradcheck wrapper for functions that take Hermitian matrices as input.
 
@@ -7235,6 +7286,20 @@ op_db: List[OpInfo] = [
                SkipInfo('TestJit', 'test_variant_consistency_jit'),
            ),
            supports_out=False,),
+    OpInfo('nn.functional.layer_norm',
+           aten_name='layer_norm',
+           aliases=('layer_norm',),
+           ref=reference_layer_norm,
+           dtypes=floating_types_and(torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
+                   'TestCommon', 'test_reference_testing'
+               ),
+           ],
+           sample_inputs_func=sample_inputs_layer_norm,),
     OpInfo('nn.functional.pad',
            variant_test_name='constant',
            aten_name='constant_pad_nd',