From d1497debf2aeb27235afb3d73a970388ac8100ab Mon Sep 17 00:00:00 2001
From: Edward Yang
Date: Thu, 21 Mar 2019 09:06:30 -0700
Subject: [PATCH] Fix B903 lint: save memory for data classes with slots/namedtuple (#18184)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18184
ghimport-source-id: 2ce860b07c58d06dc10cd7e5b97d4ef7c709a50d

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18184 Fix B903 lint: save memory for data classes with slots/namedtuple**
* #18181 Fix B902 lint error: invalid first argument.
* #18178 Fix B006 lint errors: using mutable structure in default argument.
* #18177 Fix lstrip bug revealed by B005 lint

Signed-off-by: Edward Z. Yang

Differential Revision: D14530872

fbshipit-source-id: e26cecab3a8545e7638454c28e654e7b82a3c08a
---
 .flake8                            | 2 +-
 aten/src/ATen/function_wrapper.py  | 2 ++
 test/common_methods_invocations.py | 5 ++---
 test/test_jit.py                   | 8 ++++----
 torch/autograd/profiler.py         | 8 ++------
 torch/jit/__init__.py              | 5 +----
 6 files changed, 12 insertions(+), 18 deletions(-)

diff --git a/.flake8 b/.flake8
index 180d8f9..9137fa6 100644
--- a/.flake8
+++ b/.flake8
@@ -6,5 +6,5 @@
 max-line-length = 120
 ignore = E203,E305,E402,E501,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504,C408,
     # ignores below are temporary, fix them and remove please!
-    B007,B008,B903
+    B007,B008
 exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index c3182ee..61eff13 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -194,6 +194,8 @@ CALL_TEMPLATE = CodeTemplate("${cname}(${actuals})")
 class NYIError(Exception):
     """Indicates we don't support this declaration yet"""
 
+    __slots__ = ['reason']
+
     def __init__(self, reason):
         self.reason = reason
 
diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py
index 756bba9..abb7856 100644
--- a/test/common_methods_invocations.py
+++ b/test/common_methods_invocations.py
@@ -2,6 +2,7 @@ import torch
 from torch._six import inf, nan, istuple
 from functools import reduce, wraps
 from operator import mul, itemgetter
+import collections
 from torch.autograd import Variable, Function, detect_anomaly
 from torch.testing import make_non_contiguous
 from common_utils import (skipIfNoLapack,
@@ -72,9 +73,7 @@ def prod_zeros(dim_size, dim_select):
     return result
 
 
-class non_differentiable(object):
-    def __init__(self, tensor):
-        self.tensor = tensor
+non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])
 
 
 class dont_convert(tuple):
diff --git a/test/test_jit.py b/test/test_jit.py
index c4ab4e8..0a194bc 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -13942,7 +13942,7 @@ class TestClassType(JitTestCase):
         self.assertEqual(fn(input), input)
 
     def test_get_attr(self):
-        @torch.jit.script
+        @torch.jit.script  # noqa: B903
         class FooTest:
             def __init__(self, x):
                 self.foo = x
@@ -14005,7 +14005,7 @@ class TestClassType(JitTestCase):
 
     def test_type_annotations(self):
         with self.assertRaisesRegex(RuntimeError, "expected a value of type bool"):
-            @torch.jit.script
+            @torch.jit.script  # noqa: B903
             class FooTest:
                 def __init__(self, x):
                     # type: (bool) -> None
@@ -14026,7 +14026,7 @@ class TestClassType(JitTestCase):
                 self.attr = x
 
     def test_class_type_as_param(self):
-        @torch.jit.script
+        @torch.jit.script  # noqa: B903
         class FooTest:
             def __init__(self, x):
                 self.attr = x
@@ -14094,7 +14094,7 @@ class TestClassType(JitTestCase):
         self.assertEqual(input, output)
 
     def test_save_load_with_classes_nested(self):
-        @torch.jit.script
+        @torch.jit.script  # noqa: B903
         class FooNestedTest:
             def __init__(self, y):
                 self.y = y
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
index 574eddf..5670286 100644
--- a/torch/autograd/profiler.py
+++ b/torch/autograd/profiler.py
@@ -3,7 +3,7 @@ import re
 import os
 import sys
 import itertools
-from collections import defaultdict
+from collections import defaultdict, namedtuple
 
 import torch
 from torch._six import FileNotFoundError
@@ -366,11 +366,7 @@ class Interval(object):
         return self.end - self.start
 
 
-class Kernel(object):
-    def __init__(self, name, device, interval):
-        self.name = name
-        self.device = device
-        self.interval = interval
+Kernel = namedtuple('Kernel', ['name', 'device', 'interval'])
 
 
 # TODO: record TID too
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index ae8e7b9..c9d9d6f 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -1564,10 +1564,7 @@ def annotate(the_type, the_value):
     return the_value
 
 
-class Attribute(object):
-    def __init__(self, value, the_type):
-        self.value = value
-        self.type = the_type
+Attribute = collections.namedtuple('Attribute', ['value', 'type'])
 
 
 if not torch._C._jit_init():
-- 
2.7.4
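
Note on the lint being fixed: flake8-bugbear's B903 flags classes whose __init__ does nothing but copy its arguments onto attributes, and suggests either a namedtuple (for write-once data) or __slots__ (when the class must stay mutable) so instances avoid a per-instance __dict__ and use less memory, which is exactly the trade this patch makes. A minimal sketch of the two options follows; the Point and Record names are illustrative only and do not appear in the patch.

import collections

# Option 1: a write-once data holder becomes a namedtuple
# (this is what the patch does for Kernel, Attribute, and non_differentiable).
Point = collections.namedtuple('Point', ['x', 'y'])

p = Point(1.0, 2.0)      # construct positionally or with keywords
assert p.x == 1.0        # attribute access works as before
# p.x = 3.0              # would raise AttributeError: namedtuples are immutable


# Option 2: a class that must stay mutable declares __slots__ so attributes
# live in fixed slots instead of a per-instance __dict__
# (this is what the patch does for NYIError).
class Record(object):
    __slots__ = ['reason']

    def __init__(self, reason):
        self.reason = reason


r = Record('not implemented')
r.reason = 'still assignable'   # slot attributes remain mutable
# r.extra = 1                   # would raise AttributeError: not listed in __slots__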