From e62189ad698303dfd5d577ee437437182ae77918 Mon Sep 17 00:00:00 2001
From: Zhengxu Chen
Date: Thu, 5 Aug 2021 14:19:56 -0700
Subject: [PATCH] [jit] Better checking for overload function declarations.
 (#59956)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59956

Issue #50175. Two checks are currently missing:
1. An overload declaration should always have a single `pass` statement (or `...`) as its body.
2. An implementation without the torch.jit._overload decorator should always be provided for the declarations, so when compiling we need to check whether the function body being compiled still carries the decorator ahead of it, i.e. only declarations were given and the implementation is missing.

Test Plan: python test/test_jit.py TestScript.test_function_overloads

Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D29106555

fbshipit-source-id: 2d9d7df2fb51ab6db0e1b726f9644e4cfbf733d6
---
 test/test_jit.py                  |  41 ++++++++++++++++
 torch/_jit_internal.py            |  63 ++++++++++++++++-------
 torch/_sources.py                 | 104 ++++++++++++++++++++++++++++++++++++++
 torch/_utils_internal.py          |  28 +---------
 torch/fx/experimental/rewriter.py |   2 +-
 torch/jit/_recursive.py           |  11 ++--
 torch/jit/_script.py              |   4 ++
 torch/jit/annotations.py          |   2 +-
 torch/jit/frontend.py             |  58 +++------------------
 torch/serialization.py            |   2 +-
 10 files changed, 213 insertions(+), 102 deletions(-)
 create mode 100644 torch/_sources.py

diff --git a/test/test_jit.py b/test/test_jit.py
index 807f82f..6a83f7d 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -14485,6 +14485,47 @@ dedent
         """
         with self.assertRaisesRegex(Exception, "Parameters not specified"):
             torch.jit.script(test)
 
+    def test_function_overload_misuse(self):
+        with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
+            @torch.jit._overload
+            def wrong_decl_body(x: str) -> str:
+                return x + "0"
+
+        with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
+            class MyClass:
+                @torch.jit._overload_method
+                def method(self):
+                    return 0
+
+        @torch.jit._overload
+        def null_overload(x: int) -> int: ...  # noqa: E704
+
+        @torch.jit._overload
+        def null_overload(x: str) -> str:  # noqa: F811
+            pass
+
+        def null_overload_driver():
+            return null_overload(0)
+
+        with self.assertRaisesRegex(RuntimeError, 'Implementation for the function ".+" is missing.'):
+            torch.jit.script(null_overload_driver)
+
+        class OverloadMisuse(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            @torch.jit._overload_method
+            def forward(self, x: int):
+                pass
+
+            @torch.jit._overload_method
+            def forward(self, x: Tensor):  # noqa: F811
+                pass
+
+        with self.assertRaisesRegex(RuntimeError, 'Implementation for the method ".+" is missing.'):
+            m = torch.jit.script(OverloadMisuse())
+
+
     def test_script_method_torch_function_overload(self):
         class MyCustomTensor(torch.Tensor):
             pass
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index 3754b2f..cd980e9 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -18,13 +18,12 @@ import builtins
 import typing
 import io
 import pickle
-import functools
 # This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
 # Explicitly ask to import `torch.distributed.__init__` first.
 # Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
 import torch.distributed.rpc
-from torch._utils_internal import get_source_lines_and_file
 from torch._C import Future as CFuture
+from torch._sources import get_source_lines_and_file, parse_def, fake_range
 from torch.futures import Future
 import torch.package._mangling as package_mangling
 from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union  # noqa: F401
@@ -716,7 +715,50 @@ def copy_torchscript_modifier(orig, new) -> None:
 
 # qualified_name => list[overload_functions]
 _overloaded_fns : Dict[str, List[Callable]] = {}  # noqa: T484
+
+_OVERLOAD_EXAMPLE = '''
+Example usage of overload function:
+@torch.jit._overload
+def my_function(x: type0) -> type0: # decl 1
+    pass
+
+@torch.jit._overload
+def my_function(x: type1) -> type1: # decl 2
+    pass
+
+def my_function(x): # implementation
+    if isinstance(x, type0):
+        return x
+    elif isinstance(x, type1):
+        return x
+'''
+
+def get_overload_no_implementation_error_message(kind, obj):
+    sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
+    return (
+        f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
+        f'sure a definition is provided and defined after all overload declarations.\n'
+        f'File "{filename}", line {file_lineno}:\n' + ''.join(sourcelines) + "\n" + _OVERLOAD_EXAMPLE
+    )
+
+def _check_overload_body(func):
+    parsed_def = parse_def(func)
+    body = parsed_def.ast.body[0].body
+
+    def is_pass(x):
+        return isinstance(x, ast.Pass)
+
+    def is_ellipsis(x):
+        return isinstance(x, ast.Expr) and isinstance(x.value, ast.Ellipsis)
+
+    if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
+        msg = "Only `pass` statement or `...` can be the body of overload declaration:\n"
+        msg += '\n'.join(parsed_def.source.split("\n")[:3])
+        msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
+        raise RuntimeError(msg)
+
 def _overload(func):
+    _check_overload_body(func)
     qual_name = _qualified_name(func)
     global _overloaded_fns
     fn_overload_list = _overloaded_fns.get(qual_name)
@@ -762,6 +804,7 @@ _overloaded_methods : Dict[str, Dict[str, List[Callable]]] = {}  # noqa: T484
 _overloaded_method_class_fileno = {}
 
 def _overload_method(func):
+    _check_overload_body(func)
     qual_name = _qualified_name(func)
     global _overloaded_methods
     class_name_map = _overloaded_methods.get(qual_name, None)
@@ -994,22 +1037,6 @@ def _qualified_name(obj) -> str:
     return module_name + "." + name
 
 
-# Thin wrapper around SourceRangeFactory to store extra metadata
-# about the function-to-be-compiled.
-class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
-    def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
-        super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
-        self.uses_true_division = uses_true_division
-        self.filename = filename
-
-@functools.lru_cache(maxsize=None)
-def make_source_context(*args):
-    return SourceContext(*args)
-
-def fake_range():
-    return SourceContext('', None, 0, 0).make_raw_range(0, 1)
-
-
 def _try_get_dispatched_fn(fn):
     if not callable(fn):
         return None
diff --git a/torch/_sources.py b/torch/_sources.py
new file mode 100644
index 0000000..2464949
--- /dev/null
+++ b/torch/_sources.py
@@ -0,0 +1,104 @@
+import ast
+import functools
+import inspect
+from textwrap import dedent
+from typing import Any, Optional, Tuple, List, NamedTuple
+from torch._C import ErrorReport
+from torch._C._jit_tree_views import SourceRangeFactory
+
+def get_source_lines_and_file(
+    obj: Any,
+    error_msg: Optional[str] = None,
+) -> Tuple[List[str], int, Optional[str]]:
+    """
+    Wrapper around inspect.getsourcelines and inspect.getsourcefile.
+
+    Returns: (sourcelines, file_lino, filename)
+    """
+    filename = None  # in case getsourcefile throws
+    try:
+        filename = inspect.getsourcefile(obj)
+        sourcelines, file_lineno = inspect.getsourcelines(obj)
+    except OSError as e:
+        msg = (f"Can't get source for {obj}. TorchScript requires source access in "
+               "order to carry out compilation, make sure original .py files are "
+               "available.")
+        if error_msg:
+            msg += '\n' + error_msg
+        raise OSError(msg) from e
+
+    return sourcelines, file_lineno, filename
+
+
+def normalize_source_lines(sourcelines: List[str]) -> List[str]:
+    """
+    This helper function accepts a list of source lines. It finds the
+    indentation level of the function definition (`def`), then it indents
+    all lines in the function body to a point at or greater than that
+    level. This allows for comments and continued string literals that
+    are at a lower indentation than the rest of the code.
+    Args:
+        sourcelines: function source code, separated into lines by
+                     the '\n' character
+    Returns:
+        A list of source lines that have been correctly aligned
+    """
+
+    def remove_prefix(text, prefix):
+        return text[text.startswith(prefix) and len(prefix):]
+
+    # Find the line and line number containing the function definition
+    for i, l in enumerate(sourcelines):
+        if l.lstrip().startswith("def"):
+            idx = i
+            break
+    fn_def = sourcelines[idx]
+
+    # Get a string representing the amount of leading whitespace
+    whitespace = fn_def.split("def")[0]
+
+    # Add this leading whitespace to all lines before and after the `def`
+    aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
+    aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]
+
+    # Put it together again
+    aligned_prefix.append(fn_def)
+    return aligned_prefix + aligned_suffix
+
+
+# Thin wrapper around SourceRangeFactory to store extra metadata
+# about the function-to-be-compiled.
+class SourceContext(SourceRangeFactory):
+    def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
+        super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
+        self.uses_true_division = uses_true_division
+        self.filename = filename
+
+
+@functools.lru_cache(maxsize=None)
+def make_source_context(*args):
+    return SourceContext(*args)
+
+
+def fake_range():
+    return SourceContext('', None, 0, 0).make_raw_range(0, 1)
+
+
+class ParsedDef(NamedTuple):
+    ast: ast.Module
+    ctx: SourceContext
+    source: str
+    filename: Optional[str]
+    file_lineno: int
+
+def parse_def(fn):
+    sourcelines, file_lineno, filename = get_source_lines_and_file(fn, ErrorReport.call_stack())
+    sourcelines = normalize_source_lines(sourcelines)
+    source = ''.join(sourcelines)
+    dedent_src = dedent(source)
+    py_ast = ast.parse(dedent_src)
+    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
+        raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
+    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
+    ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True)
+    return ParsedDef(py_ast, ctx, source, filename, file_lineno)
diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py
index 5e0a8de..1bbfe32 100644
--- a/torch/_utils_internal.py
+++ b/torch/_utils_internal.py
@@ -1,9 +1,7 @@
-
 import os
-import inspect
 import sys
 import tempfile
-from typing import Any, List, Optional, Tuple
+
 
 # this arbitrary-looking assortment of functionality is provided here
 # to have a central place for overrideable behavior. The motivating
@@ -44,30 +42,6 @@ def resolve_library_path(path: str) -> str:
     return os.path.realpath(path)
 
 
-def get_source_lines_and_file(
-    obj: Any,
-    error_msg: Optional[str] = None,
-) -> Tuple[List[str], int, Optional[str]]:
-    """
-    Wrapper around inspect.getsourcelines and inspect.getsourcefile.
-
-    Returns: (sourcelines, file_lino, filename)
-    """
-    filename = None  # in case getsourcefile throws
-    try:
-        filename = inspect.getsourcefile(obj)
-        sourcelines, file_lineno = inspect.getsourcelines(obj)
-    except OSError as e:
-        msg = (f"Can't get source for {obj}. TorchScript requires source access in "
-               "order to carry out compilation, make sure original .py files are "
-               "available.")
-        if error_msg:
-            msg += '\n' + error_msg
-        raise OSError(msg) from e
-
-    return sourcelines, file_lineno, filename
-
-
 TEST_MASTER_ADDR = '127.0.0.1'
 TEST_MASTER_PORT = 29500
 # USE_GLOBAL_DEPS controls whether __init__.py tries to load
diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py
index b462026..b3f71d5 100644
--- a/torch/fx/experimental/rewriter.py
+++ b/torch/fx/experimental/rewriter.py
@@ -6,7 +6,7 @@ from types import FunctionType
 from typing import cast, Union, Callable, Dict, Optional, Any
 from torch.fx._symbolic_trace import Tracer
 from torch.fx.graph import Graph
-from torch.jit.frontend import normalize_source_lines
+from torch._sources import normalize_source_lines
 import torch
 
 class AST_Rewriter(ast.NodeTransformer):
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
index 636ccd2..d85d43e 100644
--- a/torch/jit/_recursive.py
+++ b/torch/jit/_recursive.py
@@ -8,6 +8,7 @@ import warnings
 from typing import Dict, List, Set, Type
 
 import torch._jit_internal as _jit_internal
+from torch._sources import fake_range
 from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def, get_class_properties
 from torch.jit._builtins import _find_builtin
 from torch.jit._check import AttributeTypeIsSupportedChecker
@@ -148,10 +149,10 @@ def infer_concrete_type_builder(nn_module, share_types=True):
         inferred = False
         try:
             if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]:
-                ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range())
+                ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], fake_range())
                 attr_type = torch._C.InferredType(ann_to_type)
             elif isinstance(item, torch.jit.Attribute):
-                ann_to_type = torch.jit.annotations.ann_to_type(item.type, _jit_internal.fake_range())
+                ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
                 attr_type = torch._C.InferredType(ann_to_type)
             else:
                 attr_type = torch._C._jit_try_infer_type(item)
@@ -620,6 +621,10 @@ def get_overload_annotations(mod, jit_ignored_properties):
             if method_overloads is None:
                 continue
 
+            if item.__func__ in method_overloads:
+                raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
+                    'method', item.__func__))
+
             names = [name + "__" + str(i) for i in range(len(method_overloads))]
             overloads[item] = list(zip(names, method_overloads))
 
@@ -639,7 +644,7 @@ def get_overload_name_mapping(overload_info):
     return overload_name_mappings
 
 def _check_no_signature(func):
-    signature = torch.jit.annotations.get_signature(func, None, _jit_internal.fake_range(), inspect.ismethod(func))
+    signature = torch.jit.annotations.get_signature(func, None, fake_range(), inspect.ismethod(func))
     if signature is None:
         qual_name = _jit_internal._qualified_name(func)
         raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index abd9750..0c3e5ef 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -1337,6 +1337,10 @@ def _get_overloads(obj):
     if uncompiled_overloads is None:
         return existing_compiled_fns
 
+    if obj in uncompiled_overloads:
+        raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
+            'function', obj))
+
     compiled_fns = []
     for overload_fn in uncompiled_overloads:
         compiled_fns.append(
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index ced9958..f2cf789 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -16,7 +16,7 @@ from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \
 
 from textwrap import dedent
 
-from torch._utils_internal import get_source_lines_and_file
+from torch._sources import get_source_lines_and_file
 from typing import Type
 
 if torch.distributed.rpc.is_available():
diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py
index a78f8ab..b0228b1 100644
--- a/torch/jit/frontend.py
+++ b/torch/jit/frontend.py
@@ -17,9 +17,9 @@ from torch._C._jit_tree_views import (
     SliceExpr, Subscript, TernaryIf, With, WithItem, Property, DictComp,
 )
-from torch._utils_internal import get_source_lines_and_file
+from torch._sources import get_source_lines_and_file, parse_def, make_source_context
 from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
-from torch._jit_internal import make_source_context, should_drop, is_static_fn, FunctionModifiers  # noqa: F401
+from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers  # noqa: F401
 import torch.jit.annotations
 
 _IS_ASTUNPARSE_INSTALLED = False
@@ -215,42 +215,6 @@ def get_jit_class_def(cls, self_name):
     return build_class_def(ctx, class_ast, methods, properties, self_name, assigns)
 
 
-def normalize_source_lines(sourcelines: List[str]) -> List[str]:
-    """
-    This helper function accepts a list of source lines. It finds the
-    indentation level of the function definition (`def`), then it indents
-    all lines in the function body to a point at or greater than that
-    level. This allows for comments and continued string literals that
-    are at a lower indentation than the rest of the code.
-    Args:
-        sourcelines: function source code, separated into lines by
-                     the '\n' character
-    Returns:
-        A list of source lines that have been correctly aligned
-    """
-
-    def remove_prefix(text, prefix):
-        return text[text.startswith(prefix) and len(prefix):]
-
-    # Find the line and line number containing the function definition
-    for i, l in enumerate(sourcelines):
-        if l.lstrip().startswith("def"):
-            idx = i
-            break
-    fn_def = sourcelines[idx]
-
-    # Get a string representing the amount of leading whitespace
-    whitespace = fn_def.split("def")[0]
-
-    # Add this leading whitespace to all lines before and after the `def`
-    aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
-    aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]
-
-    # Put it together again
-    aligned_prefix.append(fn_def)
-    return aligned_prefix + aligned_suffix
-
-
 def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
     """
     Build a JIT AST (TreeView) from the given function.
 
     Args:
         fn: A function object to compile
         def_name: The name to give to the resulting AST object. This is not
             always the same as `fn.__name__`, for example:
                 def _forward(self):
                     ...
                 forward = _forward
             but we want the result AST to have the name "forward".
         self_name: If this function is a method, what the type name of `self` is.
""" - sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack()) - sourcelines = normalize_source_lines(sourcelines) - source = ''.join(sourcelines) - dedent_src = dedent(source) - py_ast = ast.parse(dedent_src) - if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): - raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}") - leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0]) - type_line = torch.jit.annotations.get_type_line(source) - ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True) - fn_def = py_ast.body[0] + parsed_def = parse_def(fn) + type_line = torch.jit.annotations.get_type_line(parsed_def.source) + fn_def = parsed_def.ast.body[0] if is_classmethod: arg_name = fn_def.args.args[0].arg @@ -288,7 +244,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False): if should_drop(fn): unused_fn_def = ast.parse("def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")") if len(unused_fn_def.body) != 1 or not isinstance(unused_fn_def.body[0], ast.FunctionDef): - raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}") + raise RuntimeError(f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}") unused_def = unused_fn_def.body[0] fn_def.body = unused_def.body # kwarg/vararg not supported by `build_def` @@ -305,7 +261,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False): qualname = get_qualified_name(fn) pdt_arg_types = type_trace_db.get_args_types(qualname) - return build_def(ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types) + return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types) # TODO: more robust handling of recognizing ignore context manager def is_torch_jit_ignore_context_manager(stmt): diff --git a/torch/serialization.py b/torch/serialization.py index d84ae9a..4443561 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -11,7 +11,7 @@ import warnings from contextlib import closing, contextmanager from ._utils import _import_dotted_name from ._six import string_classes as _string_classes -from torch._utils_internal import get_source_lines_and_file +from torch._sources import get_source_lines_and_file from torch.types import Storage from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO import copyreg -- 2.7.4