Restore LayerNorm numerics test (#64385)
authorRichard Zou <zou3519@gmail.com>
Wed, 1 Sep 2021 22:12:05 +0000 (15:12 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 22:32:49 +0000 (15:32 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64385

It was deleted in https://github.com/pytorch/pytorch/pull/63276.

The numerics test was meant to check LayerNorm behavior on large inputs,
but we deleted it without realizing that.

Test Plan: - wait for tests.

Reviewed By: ngimel

Differential Revision: D30702950

Pulled By: zou3519

fbshipit-source-id: a480e26c45ec38fb628938b70416cdb22d976a46

test/test_nn.py

index e60ff69..5008c72 100644 (file)
@@ -13282,6 +13282,32 @@ class TestNNDeviceType(NNTestCase):
             self._test_LayerNorm_cuda_half(device)
 
     @onlyOnCPUAndCUDA
+    def test_LayerNorm_numeric(self, device):
+        def layer_norm_ref(X, gamma, beta, normalized_shape, eps):
+            feature_size = np.prod(normalized_shape)
+            X_view = X.view(-1, feature_size)
+            mean = X_view.mean(dim=-1, keepdim=True)
+            var = X_view.var(dim=-1, unbiased=False, keepdim=True)
+            Y = (X_view - mean) / torch.sqrt(var + eps)
+            Y = Y * gamma.view(-1) + beta.view(-1)
+            return Y.view(*X.size())
+
+        normalized_shape = [256, 256, 144]
+        layer_norm = nn.LayerNorm(normalized_shape).float().to(device)
+        X = torch.rand(2, *normalized_shape, dtype=torch.float32,
+                       device=device)
+
+        Y = layer_norm(X)
+        Y_ref = layer_norm_ref(X, layer_norm.weight.data, layer_norm.bias.data,
+                               normalized_shape, layer_norm.eps)
+        self.assertEqual(Y, Y_ref, rtol=0, atol=1e-5)
+
+        if self.device_type == 'cuda':
+            layer_norm.cpu()
+            Y_cpu = layer_norm(X.cpu())
+            self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5)
+
+    @onlyOnCPUAndCUDA
     def test_GroupNorm_general(self, device):
         self._test_GroupNorm_general(device)