extract TestAutogradComplex into its own test file (#63400)
authorMichael Dagitses <mikeyd@fb.com>
Thu, 2 Sep 2021 11:04:59 +0000 (04:04 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 11:34:35 +0000 (04:34 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63400

This is the first step to break up test_autograd.py for #63205.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D30541499

Pulled By: dagitses

fbshipit-source-id: 8d9d32007938b9eade0e88f95a6a3190e7e2ef01

test/autograd/test_complex.py [new file with mode: 0644]
test/test_autograd.py
tools/testing/modulefinder_determinator.py

diff --git a/test/autograd/test_complex.py b/test/autograd/test_complex.py
new file mode 100644 (file)
index 0000000..74fcfda
--- /dev/null
@@ -0,0 +1,103 @@
+import torch
+
+from torch.testing._internal.common_utils import TestCase, run_tests, gradcheck
+
+
+class TestAutogradComplex(TestCase):
+    def test_view_func_for_complex_views(self):
+        # case 1: both parent and child have view_func
+        x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
+        y = x.detach().requires_grad_(True)
+
+        x0 = x.clone()
+        x1 = torch.view_as_complex(x0)
+        x2 = torch.view_as_real(x1)
+        x2.mul_(2)
+        x2.sum().backward()
+
+        y0 = y.clone()
+        y0.mul_(2)
+        y0.sum().backward()
+
+        self.assertEqual(x.grad, y.grad)
+
+        # case 2: parent has view_func but child does not
+        x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
+        y = x.detach().requires_grad_(True)
+
+        def fn(a):
+            b = a.clone()
+            b1 = torch.view_as_complex(b)
+            b2 = b1.reshape(b1.numel())
+            return b2
+
+        x0 = fn(x)
+        x0.mul_(2)
+        x0.sum().backward()
+
+        y0 = fn(y)
+        y1 = y0.mul(2)
+        y1.sum().backward()
+
+        self.assertEqual(x.grad, y.grad)
+
+        # case 3: parent does not have a view_func but child does
+        x = torch.randn(10, dtype=torch.cdouble, requires_grad=True)
+        y = x.detach().requires_grad_(True)
+
+        def fn(a, dim0_size=5):
+            b = a.clone()
+            b1 = b.reshape(dim0_size, 2)
+            b2 = torch.view_as_real(b1)
+            return b2
+
+        x0 = fn(x)
+        x0.mul_(2)
+        x0.sum().backward()
+
+        y0 = fn(y)
+        y1 = y0.mul(2)
+        y1.sum().backward()
+
+        self.assertEqual(x.grad, y.grad)
+
+    def test_view_with_multi_output(self):
+        x = torch.randn(2, 2, 2, dtype=torch.double)
+
+        x1 = torch.view_as_complex(x)
+        # Taking an invalid view should always be allowed as long as it is not
+        # modified inplace
+        res = x1.unbind(0)
+
+        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
+            res[0] += torch.rand(2, requires_grad=True)
+
+        x.requires_grad_(True)
+        x1 = torch.view_as_complex(x)
+        # Taking an invalid view should always be allowed as long as it is not
+        # modified inplace
+        res = x1.unbind(0)
+
+        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
+            res[0] += torch.rand(2, requires_grad=True)
+
+    def as_identity(self):
+        # view_as_real and view_as_complex behavior should be like an identity
+        def func(z):
+            z_ = torch.view_as_complex(z)
+            z_select = torch.select(z_, z_.dim() - 1, 0)
+            z_select_real = torch.view_as_real(z_select)
+            return z_select_real.sum()
+
+        z = torch.randn(10, 2, 2, dtype=torch.double, requires_grad=True)
+        gradcheck(func, [z])
+        func(z).backward()
+
+        z1 = z.clone().detach().requires_grad_(True)
+        torch.select(z1, z1.dim() - 2, 0).sum().backward()
+
+        self.assertEqual(z.grad, z1.grad)
+
+
+if __name__ == '__main__':
+    run_tests()
index ebe3aa5..fde64b0 100644 (file)
@@ -28,7 +28,6 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack,
                                                   suppress_warnings, slowTest,
-                                                  load_tests,
                                                   IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
                                                   TEST_WITH_ROCM, disable_gc,
                                                   gradcheck, gradgradcheck)
@@ -44,11 +43,6 @@ from torch.testing._internal.common_device_type import (instantiate_device_type_
                                                         deviceCountAtLeast, skipCUDAIfCudnnVersionLessThan,
                                                         skipCUDAIf, skipMeta)
 
-
-# load_tests from common_utils is used to automatically filter tests for
-# sharding on sandcastle. This line silences flake warnings
-load_tests = load_tests
-
 import pickle
 
 PRECISION = 1e-4
@@ -6173,101 +6167,6 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
         test_case.assertEqual(self_variable.size(), self_variable.grad.size())
 
 
-class TestAutogradComplex(TestCase):
-    def test_view_func_for_complex_views(self):
-        # case 1: both parent and child have view_func
-        x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
-        y = x.detach().requires_grad_(True)
-
-        x0 = x.clone()
-        x1 = torch.view_as_complex(x0)
-        x2 = torch.view_as_real(x1)
-        x2.mul_(2)
-        x2.sum().backward()
-
-        y0 = y.clone()
-        y0.mul_(2)
-        y0.sum().backward()
-
-        self.assertEqual(x.grad, y.grad)
-
-        # case 2: parent has view_func but child does not
-        x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
-        y = x.detach().requires_grad_(True)
-
-        def fn(a):
-            b = a.clone()
-            b1 = torch.view_as_complex(b)
-            b2 = b1.reshape(b1.numel())
-            return b2
-
-        x0 = fn(x)
-        x0.mul_(2)
-        x0.sum().backward()
-
-        y0 = fn(y)
-        y1 = y0.mul(2)
-        y1.sum().backward()
-
-        self.assertEqual(x.grad, y.grad)
-
-        # case 3: parent does not have a view_func but child does
-        x = torch.randn(10, dtype=torch.cdouble, requires_grad=True)
-        y = x.detach().requires_grad_(True)
-
-        def fn(a, dim0_size=5):
-            b = a.clone()
-            b1 = b.reshape(dim0_size, 2)
-            b2 = torch.view_as_real(b1)
-            return b2
-
-        x0 = fn(x)
-        x0.mul_(2)
-        x0.sum().backward()
-
-        y0 = fn(y)
-        y1 = y0.mul(2)
-        y1.sum().backward()
-
-        self.assertEqual(x.grad, y.grad)
-
-    def test_view_with_multi_output(self):
-        x = torch.randn(2, 2, 2, dtype=torch.double)
-
-        x1 = torch.view_as_complex(x)
-        # Taking an invalid view should always be allowed as long as it is not
-        # modified inplace
-        res = x1.unbind(0)
-
-        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
-            res[0] += torch.rand(2, requires_grad=True)
-
-        x.requires_grad_(True)
-        x1 = torch.view_as_complex(x)
-        # Taking an invalid view should always be allowed as long as it is not
-        # modified inplace
-        res = x1.unbind(0)
-
-        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
-            res[0] += torch.rand(2, requires_grad=True)
-
-    def as_identity(self):
-        # view_as_real and view_as_complex behavior should be like an identity
-        def func(z):
-            z_ = torch.view_as_complex(z)
-            z_select = torch.select(z_, z_.dim() - 1, 0)
-            z_select_real = torch.view_as_real(z_select)
-            return z_select_real.sum()
-
-        z = torch.randn(10, 2, 2, dtype=torch.double, requires_grad=True)
-        gradcheck(func, [z])
-        func(z).backward()
-
-        z1 = z.clone().detach().requires_grad_(True)
-        torch.select(z1, z1.dim() - 2, 0).sum().backward()
-
-        self.assertEqual(z.grad, z1.grad)
-
 class TestAutogradFunctional(TestCase):
     def _assert_same_struct(self, res, base):
         # base and res should be Tensors or tuple of Tensors with the same size
@@ -9640,6 +9539,11 @@ class TestMultithreadAutograd(TestCase):
         torch.autograd.gradcheck(fn, [inp_r, inp_c], check_forward_ad=True)
         torch.autograd.gradcheck(fn, [inp_c, inp_r], check_forward_ad=True)
 
+# Import test cases from below autograd/ here. These are found
+# implicitly by the loader, so Flake8 thinks they are unused, hence
+# the suppressions.
+
+from autograd.test_complex import TestAutogradComplex  # noqa: F401
 
 # e.g., TestAutogradDeviceTypeCPU and TestAutogradDeviceTypeCUDA
 instantiate_device_type_tests(
index b6c94e7..32dc103 100644 (file)
@@ -48,7 +48,10 @@ TARGET_DET_LIST = [
     "distributed/test_pg_wrapper",
     "distributed/test_store",
     "distributions/test_distributions",
-    "test_autograd",
+    # test_autograd.py is not slow, so it does not belong here. But
+    # note that if you try to add it back it will run into
+    # https://bugs.python.org/issue40350 because it imports files
+    # under test/autograd/.
     "test_binary_ufuncs",
     "test_cpp_extensions_aot_ninja",
     "test_cpp_extensions_aot_no_ninja",