import torch.nn.functional as F
import torch.nn.parallel as dp
import torch.optim as optim
+import torch.cuda
import torch.jit.quantized
from contextlib import contextmanager
from itertools import product, chain
import torch.jit.frontend
from torch.autograd import Variable, Function
+from torch.nn import Module
from torch.autograd.function import traceable
from torch.testing import assert_allclose
from torch.onnx import OperatorExportTypes
ListType, StringType, DictType
from copy import deepcopy
import random
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Tuple
from torch.jit.frontend import NotSupportedError
from torch.jit import BatchTensor
+from torch import Tensor
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3
# For testing truediv in Python 2
from test_module.future_div import div_int_future, div_float_future
finally:
os.unlink(f.name)
else:
- @contextmanager
+ @contextmanager # noqa: T484
def TemporaryFileName():
with tempfile.NamedTemporaryFile() as f:
yield f.name
with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
@torch.jit.script
- def hints_bad_types(x, a=10, b=0.5):
+ def hints_bad_types(x, a=10, b=0.5): # noqa: T484
# type: (Tensor, float, int) -> Tensor
return x + a + b
def sum_list(a):
# type: (int) -> int
sum = 0
- for i in a:
+ for i in a: # noqa: T484
sum += i
return sum
x = 1
else:
x = torch.jit._unwrap_optional(x)
- return x
+ return x # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
def or_error(x, y):
- # type: (Optional[int], Optional[int]) -> int
+ # type: (Optional[int], Optional[int]) -> None
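+ # inside this branch at least one of x and y is None, so the add below is invalid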
if x is None or y is None:
- print(x + y)
+ print(x + y) # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
def and_error(x, y):
- # type: (Optional[int], Optional[int]) -> int
+ # type: (Optional[int], Optional[int]) -> None
if x is None and y is None:
pass
else:
- print(x + y)
+ print(x + y) # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
# type: (Optional[int]) -> None
x_none = x is not None
if x_none:
- print(x + 1)
+ print(x + 1) # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
# type: (Optional[int], Optional[int]) -> None
x_none = x is not None
if y is not None and x_none:
- print(x + y)
+ print(x + y) # noqa: T484
def test_while_write_outer_then_read(self):
def func(a, b):
self.checkScript(multiple_returns, [a], optimize=True)
with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
- @torch.jit.script
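+ # the body never returns, so compiling should fail against the Tensor annotation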
+ torch.jit.CompilationUnit('''
def no_return_bad_annotation(a):
# type: (Tensor) -> Tensor
a + 1
+ ''')
def test_error(self):
@torch.jit.script
hiddens = hx
if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
- from typing import Tuple
-
class ScriptWrapper(torch.jit.ScriptModule):
def __init__(self, cell):
super(ScriptWrapper, self).__init__()
with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
def foo():
# type: () -> Tensor
- return ((3, 4),)
+ return ((3, 4),) # noqa: T484
@torch.jit.script
def bar():
if x:
y = [1]
else:
- y = [None]
+ y = [None] # noqa: T484
return y[0]
@torch.jit.script
print(int_fn((1, 1, 1)))
with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
- @torch.jit.script
+ @torch.jit.script # noqa: T484
def fn(x):
- # type: (BroadcastingListx[int]) -> List[int]
+ # type: (BroadcastingListx[int]) -> List[int] # noqa: T484
return x
- # TODO: the type comment in this seems to trip up flake8 for some reason
- # even though we have a noqa comment. Figure out why
+ # use a CompilationUnit so flake8 does not error on the int[2] type comment (noqa is not honored there)
with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
- @torch.jit.script
- def nested(x, y):
- # type: (int, Tuple[int, int[2]]) -> List[int] # noqa: T484
- return x
+ cu = torch.jit.CompilationUnit('''
+ def nested(x, y):
+ # type: (int, Tuple[int, int[2]]) -> List[int]
+ return x # noqa: T484
+ ''')
def test_ntuple_builtins(self):
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
def somefunc():
# type: () -> Tuple[Tuple[Tensor, Tensor]]
- return torch.zeros(3, 4), torch.zeros(4, 5)
+ return torch.zeros(3, 4), torch.zeros(4, 5) # noqa: T484
@torch.jit.script
def wrong_return_type():
def test(x):
# type: (Optional[int]) -> int
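+ # _unwrap_optional raises if x is None, so x can be treated as an int below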
x = torch.jit._unwrap_optional(x)
- x = x + x
+ x = x + x # noqa: T484
return x
self.checkScript(test, (3,))
@torch.jit.script
def return_tup(x):
# type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
- return x, x
+ return x, x # noqa: T484
def test_annotated_script_fn_arg_mismatch(self):
with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
@torch.jit.script
def tuple_arg(x):
# type: (Tuple[Tensor, Tensor]) -> Tensor
- return x + 1
+ return x + 1 # noqa: T484
def test_script_non_tensor_args_outputs(self):
@torch.jit.script
self.assertEqual(y, y_hat)
def test_async_script_capture(self):
- class Module(torch.jit.ScriptModule):
+ class Mod(torch.jit.ScriptModule):
__constants__ = ['const']
def __init__(self):
- super(Module, self).__init__(False)
+ super(Mod, self).__init__(False)
self.const = 42
self.param = nn.Parameter(torch.randn(2, 2))
x1 = torch.rand(3, 4)
x2 = torch.rand(5, 6)
- m = Module()
+ m = Mod()
y, y_hat = m.wait_script(x1, x2)
self.assertEqual(y, y_hat)
def forward(self, x):
return (torch.neg(x), x)
- class Module(torch.jit.ScriptModule):
+ class Mod(torch.jit.ScriptModule):
def __init__(self):
- super(Module, self).__init__(False)
+ super(Mod, self).__init__(False)
x = torch.rand(3, 3)
self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
# return a nested structure of tensors
return (tensor_list, tensor_tuple, tensor_tuple[1])
- class Tuple(nn.Module):
+ class TupleCl(nn.Module):
def __init__(self):
- super(Tuple, self).__init__()
- self.module = Module()
+ super(TupleCl, self).__init__()
+ self.module = Mod()
def forward(self, x):
z = torch.neg(x)
return tuple(list)
x = torch.rand(3, 3)
- module = torch.jit.trace(Tuple(), (x), _force_outplace=True)
+ module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
# Make sure the traced graph contains both fork nodes
self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
@torch.jit.script
class FooTest:
def __init__(self, x):
- # type: (int)
+ # type: (int) -> None
self.foo = x
def incFooTest(self, y):
- # type: (int)
+ # type: (int) -> None
self.foo = self.foo + y
@torch.jit.script
def fn(x):
- # type: (int)
+ # type: (int) -> int
foo = FooTest(x)
foo.incFooTest(2)
return foo.foo
@torch.jit.script
class FooTest:
def __init__(self, x):
- # type: (bool)
+ # type: (bool) -> None
self.foo = x
@torch.jit.script
@torch.jit.script
def fn(foo):
- # type: (FooTest)
+ # type: (FooTest) -> Tensor
return foo.attr
@torch.jit.script