use flake8-mypy (#17721)
authorElias Ellison <eellison@fb.com>
Thu, 7 Mar 2019 17:12:35 +0000 (09:12 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 7 Mar 2019 17:15:54 +0000 (09:15 -0800)
Summary:
Use flake8 installed with mypy checks so that our linter matches fbcode. Mypy type errors also provide valuable signal
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17721

Differential Revision: D14357778

Pulled By: eellison

fbshipit-source-id: d8c9ea3fe3b5f550c3b70fe259e0eabf95e4c92d

.flake8
.travis.aten.yml
.travis.yml
test/test_jit.py
torch/_jit_internal.py
torch/jit/quantized.py

diff --git a/.flake8 b/.flake8
index 8510253..9048c94 100644 (file)
--- a/.flake8
+++ b/.flake8
@@ -1,4 +1,4 @@
 [flake8]
 max-line-length = 120
 ignore = E203,E305,E402,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504
-exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install
+exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include
index 0e9d802..2425845 100644 (file)
@@ -27,5 +27,5 @@ matrix:
     include:
         env: LINT_CHECK
         python: "2.7"
-        install: pip install flake8
+        install: pip install flake8-mypy
         script: flake8
index 8d1da41..6aa7077 100644 (file)
@@ -29,7 +29,7 @@ matrix:
         python: "3.7"
         dist: xenial    # required for Python 3.7 (travis-ci/travis-ci#9069)
         sudo: required  # required for Python 3.7 (travis-ci/travis-ci#9069)
-        install: pip install flake8
+        install: pip install flake8-mypy
         script: flake8
       - name: "MyPy typecheck"
         python: "3.6"
index 7c7824e..9879eeb 100644 (file)
@@ -5,11 +5,13 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.parallel as dp
 import torch.optim as optim
+import torch.cuda
 import torch.jit.quantized
 from contextlib import contextmanager
 from itertools import product, chain
 import torch.jit.frontend
 from torch.autograd import Variable, Function
+from torch.nn import Module
 from torch.autograd.function import traceable
 from torch.testing import assert_allclose
 from torch.onnx import OperatorExportTypes
@@ -44,9 +46,11 @@ from torch._C import TensorType, TupleType, FloatType, IntType, \
     ListType, StringType, DictType
 from copy import deepcopy
 import random
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Tuple
 from torch.jit.frontend import NotSupportedError
 from torch.jit import BatchTensor
+from torch import Tensor
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3
 
 # For testing truediv in python 2
 from test_module.future_div import div_int_future, div_float_future
@@ -96,7 +100,7 @@ if WINDOWS:
         finally:
             os.unlink(f.name)
 else:
-    @contextmanager
+    @contextmanager  # noqa: T484
     def TemporaryFileName():
         with tempfile.NamedTemporaryFile() as f:
             yield f.name
@@ -2262,7 +2266,7 @@ class TestJit(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
 
             @torch.jit.script
-            def hints_bad_types(x, a=10, b=0.5):
+            def hints_bad_types(x, a=10, b=0.5):  # noqa: T484
                 # type: (Tensor, float, int) -> Tensor
                 return x + a + b
 
@@ -3113,7 +3117,7 @@ class TestScript(JitTestCase):
             def sum_list(a):
                 # type: (int) -> int
                 sum = 0
-                for i in a:
+                for i in a:  # noqa: T484
                     sum += i
 
                 return sum
@@ -4727,23 +4731,23 @@ a")
                 x = 1
             else:
                 x = torch.jit._unwrap_optional(x)
-            return x
+            return x  # noqa: T484
 
         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
             def or_error(x, y):
-                # type: (Optional[int], Optional[int]) -> int
+                # type: (Optional[int], Optional[int]) -> None
                 if x is None or y is None:
-                    print(x + y)
+                    print(x + y)  # noqa: T484
 
         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
             def and_error(x, y):
-                # type: (Optional[int], Optional[int]) -> int
+                # type: (Optional[int], Optional[int]) -> None
                 if x is None and y is None:
                     pass
                 else:
-                    print(x + y)
+                    print(x + y)  # noqa: T484
 
         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
@@ -4751,7 +4755,7 @@ a")
                 # type: (Optional[int]) -> None
                 x_none = x is not None
                 if x_none:
-                    print(x + 1)
+                    print(x + 1)  # noqa: T484
 
         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
@@ -4759,7 +4763,7 @@ a")
                 # type: (Optional[int], Optional[int]) -> None
                 x_none = x is not None
                 if y is not None and x_none:
-                    print(x + y)
+                    print(x + y)  # noqa: T484
 
     def test_while_write_outer_then_read(self):
         def func(a, b):
@@ -5057,10 +5061,11 @@ a")
         self.checkScript(multiple_returns, [a], optimize=True)
 
         with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
-            @torch.jit.script
+            torch.jit.CompilationUnit('''
             def no_return_bad_annotation(a):
                 # type: (Tensor) -> Tensor
                 a + 1
+            ''')
 
     def test_error(self):
         @torch.jit.script
@@ -5654,8 +5659,6 @@ a")
                 hiddens = hx
 
             if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
-                from typing import Tuple
-
                 class ScriptWrapper(torch.jit.ScriptModule):
                     def __init__(self, cell):
                         super(ScriptWrapper, self).__init__()
@@ -6650,7 +6653,7 @@ a")
         with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
             def foo():
                 # type: () -> Tensor
-                return ((3, 4),)
+                return ((3, 4),)  # noqa: T484
 
             @torch.jit.script
             def bar():
@@ -6769,7 +6772,7 @@ a")
                 if x:
                     y = [1]
                 else:
-                    y = [None]
+                    y = [None]  # noqa: T484
                 return y[0]
 
         @torch.jit.script
@@ -6815,18 +6818,18 @@ a")
             print(int_fn((1, 1, 1)))
 
         with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
-            @torch.jit.script
+            @torch.jit.script  # noqa: T484
             def fn(x):
-                # type: (BroadcastingListx[int]) -> List[int]
+                # type: (BroadcastingListx[int]) -> List[int]  # noqa: T484
                 return x
 
-        # TODO: the type comment in this seems to trip up flake8 for some reason
-        # even though we have a noqa comment. Figure out why
+        # using CU so that flake8 error on int[2] is not raised (noqa not working)
         with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
-            @torch.jit.script
-            def nested(x, y):
-                # type: (int, Tuple[int, int[2]]) -> List[int] # noqa: T484
-                return x
+            cu = torch.jit.CompilationUnit('''
+                def nested(x, y):
+                    # type: (int, Tuple[int, int[2]]) -> List[int]
+                    return x  # noqa: T484
+            ''')
 
     def test_ntuple_builtins(self):
         from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
@@ -8349,7 +8352,7 @@ a")
         with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
             def somefunc():
                 # type: () -> Tuple[Tuple[Tensor, Tensor]]
-                return torch.zeros(3, 4), torch.zeros(4, 5)
+                return torch.zeros(3, 4), torch.zeros(4, 5)  # noqa: T484
 
             @torch.jit.script
             def wrong_return_type():
@@ -9029,7 +9032,7 @@ a")
         def test(x):
             # type: (Optional[int]) -> int
             x = torch.jit._unwrap_optional(x)
-            x = x + x
+            x = x + x  # noqa: T484
             return x
 
         self.checkScript(test, (3,))
@@ -9082,14 +9085,14 @@ a")
             @torch.jit.script
             def return_tup(x):
                 # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
-                return x, x
+                return x, x  # noqa: T484
 
     def test_annotated_script_fn_arg_mismatch(self):
         with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
             @torch.jit.script
             def tuple_arg(x):
                 # type: (Tuple[Tensor, Tensor]) -> Tensor
-                return x + 1
+                return x + 1  # noqa: T484
 
     def test_script_non_tensor_args_outputs(self):
         @torch.jit.script
@@ -13122,11 +13125,11 @@ class TestAsync(JitTestCase):
         self.assertEqual(y, y_hat)
 
     def test_async_script_capture(self):
-        class Module(torch.jit.ScriptModule):
+        class Mod(torch.jit.ScriptModule):
             __constants__ = ['const']
 
             def __init__(self):
-                super(Module, self).__init__(False)
+                super(Mod, self).__init__(False)
                 self.const = 42
                 self.param = nn.Parameter(torch.randn(2, 2))
 
@@ -13144,7 +13147,7 @@ class TestAsync(JitTestCase):
         x1 = torch.rand(3, 4)
         x2 = torch.rand(5, 6)
 
-        m = Module()
+        m = Mod()
         y, y_hat = m.wait_script(x1, x2)
 
         self.assertEqual(y, y_hat)
@@ -13244,9 +13247,9 @@ class TestAsync(JitTestCase):
             def forward(self, x):
                 return (torch.neg(x), x)
 
-        class Module(torch.jit.ScriptModule):
+        class Mod(torch.jit.ScriptModule):
             def __init__(self):
-                super(Module, self).__init__(False)
+                super(Mod, self).__init__(False)
                 x = torch.rand(3, 3)
                 self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
 
@@ -13266,10 +13269,10 @@ class TestAsync(JitTestCase):
                 # return a nested structure of tensors
                 return (tensor_list, tensor_tuple, tensor_tuple[1])
 
-        class Tuple(nn.Module):
+        class TupleCl(nn.Module):
             def __init__(self):
-                super(Tuple, self).__init__()
-                self.module = Module()
+                super(TupleCl, self).__init__()
+                self.module = Mod()
 
             def forward(self, x):
                 z = torch.neg(x)
@@ -13278,7 +13281,7 @@ class TestAsync(JitTestCase):
                 return tuple(list)
 
         x = torch.rand(3, 3)
-        module = torch.jit.trace(Tuple(), (x), _force_outplace=True)
+        module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
 
         # Make sure we have forks
         self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
@@ -13632,16 +13635,16 @@ class TestClassType(JitTestCase):
             @torch.jit.script
             class FooTest:
                 def __init__(self, x):
-                    # type: (int)
+                    # type: (int) -> None
                     self.foo = x
 
                 def incFooTest(self, y):
-                    # type: (int)
+                    # type: (int) -> None
                     self.foo = self.foo + y
 
             @torch.jit.script
             def fn(x):
-                # type: (int)
+                # type: (int) -> int
                 foo = FooTest(x)
                 foo.incFooTest(2)
                 return foo.foo
@@ -13689,7 +13692,7 @@ class TestClassType(JitTestCase):
                 @torch.jit.script
                 class FooTest:
                     def __init__(self, x):
-                        # type: (bool)
+                        # type: (bool) -> None
                         self.foo = x
 
                 @torch.jit.script
@@ -13718,7 +13721,7 @@ class TestClassType(JitTestCase):
 
             @torch.jit.script
             def fn(foo):
-                # type: (FooTest)
+                # type: (FooTest) -> Tensor
                 return foo.attr
 
             @torch.jit.script
index 28053c5..3667cfe 100644 (file)
@@ -9,24 +9,24 @@ import inspect
 from torch._six import builtins
 
 # Tracks standalone weak script functions
-compiled_weak_fns = weakref.WeakKeyDictionary()
+compiled_weak_fns = weakref.WeakKeyDictionary()  # noqa: T484
 
 # Tracks which methods should be converted to strong methods
-weak_script_methods = weakref.WeakKeyDictionary()
+weak_script_methods = weakref.WeakKeyDictionary()  # noqa: T484
 
 # Converted modules and their corresponding WeakScriptModuleProxy objects
-weak_modules = weakref.WeakKeyDictionary()
+weak_modules = weakref.WeakKeyDictionary()  # noqa: T484
 
 # Types that have been declared as weak modules
-weak_types = weakref.WeakKeyDictionary()
+weak_types = weakref.WeakKeyDictionary()  # noqa: T484
 
 # Wrapper functions that can call either of 2 functions depending on a boolean
 # argument
-boolean_dispatched = weakref.WeakKeyDictionary()
+boolean_dispatched = weakref.WeakKeyDictionary()  # noqa: T484
 
 # Python Op functions that should be ignored by the compiler. These will be replaced
 # with an operator that always throws an error
-ignored_fns = weakref.WeakSet()
+ignored_fns = weakref.WeakSet()  # noqa: T484
 
 COMPILATION_PENDING = object()
 COMPILED = object()
@@ -223,9 +223,9 @@ except ImportError:
         def __getitem__(self, types):
             return DictInstance(types)
 
-    Tuple = TupleCls()
-    List = ListCls()
-    Dict = DictCls()
+    Tuple = TupleCls()  # noqa: T484
+    List = ListCls()  # noqa: T484
+    Dict = DictCls()  # noqa: T484
 
     def is_tuple(ann):
         return isinstance(ann, TupleInstance)
index ec4e89a..b5775b9 100644 (file)
@@ -1,7 +1,9 @@
 import torch
 import copy
 import numbers
-from typing import Tuple
+from typing import Tuple, Optional
+from torch import Tensor
+from torch.jit import ScriptModule
 
 from torch.nn.utils.rnn import PackedSequence
 from torch.nn import _VF