[jit] Better checking for overload function declarations. (#59956)
authorZhengxu Chen <zhxchen17@fb.com>
Thu, 5 Aug 2021 21:19:56 +0000 (14:19 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 5 Aug 2021 21:21:48 +0000 (14:21 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59956

Issue #50175. Basically two things need to be checked and are lacking currently:
1. Overload declarations should always have a single `pass` statement as the body.
2. There should be always an implementation provided for decls which doesn't
   have the torch.jit._overload decorator. So in this case we need to check
   whether we are actually compiling a function body with decorator ahead.

Test Plan:
python test/test_jit.py TestScript.test_function_overloads

Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D29106555

fbshipit-source-id: 2d9d7df2fb51ab6db0e1b726f9644e4cfbf733d6

test/test_jit.py
torch/_jit_internal.py
torch/_sources.py [new file with mode: 0644]
torch/_utils_internal.py
torch/fx/experimental/rewriter.py
torch/jit/_recursive.py
torch/jit/_script.py
torch/jit/annotations.py
torch/jit/frontend.py
torch/serialization.py

index 807f82f..6a83f7d 100644 (file)
@@ -14485,6 +14485,47 @@ dedent """
         with self.assertRaisesRegex(Exception, "Parameters not specified"):
             torch.jit.script(test)
 
+    def test_function_overload_misuse(self):
+        with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
+            @torch.jit._overload
+            def wrong_decl_body(x: str) -> str:
+                return x + "0"
+
+        with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
+            class MyClass:
+                @torch.jit._overload_method
+                def method(self):
+                    return 0
+
+        @torch.jit._overload
+        def null_overload(x: int) -> int: ...  # noqa: E704
+
+        @torch.jit._overload
+        def null_overload(x: str) -> str:  # noqa: F811
+            pass
+
+        def null_overload_driver():
+            return null_overload(0)
+
+        with self.assertRaisesRegex(RuntimeError, 'Implementation for the function ".+" is missing.'):
+            torch.jit.script(null_overload_driver)
+
+        class OverloadMisuse(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            @torch.jit._overload_method
+            def forward(self, x: int):
+                pass
+
+            @torch.jit._overload_method
+            def forward(self, x: Tensor):  # noqa: F811
+                pass
+
+        with self.assertRaisesRegex(RuntimeError, 'Implementation for the method ".+" is missing.'):
+            m = torch.jit.script(OverloadMisuse())
+
+
     def test_script_method_torch_function_overload(self):
         class MyCustomTensor(torch.Tensor):
             pass
index 3754b2f..cd980e9 100644 (file)
@@ -18,13 +18,12 @@ import builtins
 import typing
 import io
 import pickle
-import functools
 # This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
 # Explicitly ask to import `torch.distributed.__init__` first.
 # Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
 import torch.distributed.rpc
-from torch._utils_internal import get_source_lines_and_file
 from torch._C import Future as CFuture
+from torch._sources import get_source_lines_and_file, parse_def, fake_range
 from torch.futures import Future
 import torch.package._mangling as package_mangling
 from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union  # noqa: F401
@@ -716,7 +715,50 @@ def copy_torchscript_modifier(orig, new) -> None:
 # qualified_name => list[overload_functions]
 _overloaded_fns : Dict[str, List[Callable]] = {}  # noqa: T484
 
+
+_OVERLOAD_EXAMPLE = '''
+Example usage of overload function:
+@torch.jit._overload
+def my_function(x: type0) -> type0: # decl 1
+    pass
+
+@torch.jit._overload
+def my_function(x: type1) -> type1: # decl 2
+    pass
+
+def my_function(x):                 # implementation
+    if isinstance(x, type0):
+        return x
+    elif isinstance(x, type1):
+        return x
+'''
+
+def get_overload_no_implementation_error_message(kind, obj):
+    sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
+    return (
+        f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
+        f'sure a definition is provided and defined after all overload declarations.\n'
+        f'File "{filename}", line {file_lineno}:\n' + ''.join(sourcelines) + "\n" + _OVERLOAD_EXAMPLE
+    )
+
+def _check_overload_body(func):
+    parsed_def = parse_def(func)
+    body = parsed_def.ast.body[0].body
+
+    def is_pass(x):
+        return isinstance(x, ast.Pass)
+
+    def is_ellipsis(x):
+        return isinstance(x, ast.Expr) and isinstance(x.value, ast.Ellipsis)
+
+    if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
+        msg = "Only `pass` statement or `...` can be the body of overload declaration:\n"
+        msg += '\n'.join(parsed_def.source.split("\n")[:3])
+        msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
+        raise RuntimeError(msg)
+
 def _overload(func):
+    _check_overload_body(func)
     qual_name = _qualified_name(func)
     global _overloaded_fns
     fn_overload_list = _overloaded_fns.get(qual_name)
@@ -762,6 +804,7 @@ _overloaded_methods : Dict[str, Dict[str, List[Callable]]] = {}  # noqa: T484
 _overloaded_method_class_fileno = {}
 
 def _overload_method(func):
+    _check_overload_body(func)
     qual_name = _qualified_name(func)
     global _overloaded_methods
     class_name_map = _overloaded_methods.get(qual_name, None)
@@ -994,22 +1037,6 @@ def _qualified_name(obj) -> str:
     return module_name + "." + name
 
 
-# Thin wrapper around SourceRangeFactory to store extra metadata
-# about the function-to-be-compiled.
-class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
-    def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
-        super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
-        self.uses_true_division = uses_true_division
-        self.filename = filename
-
-@functools.lru_cache(maxsize=None)
-def make_source_context(*args):
-    return SourceContext(*args)
-
-def fake_range():
-    return SourceContext('', None, 0, 0).make_raw_range(0, 1)
-
-
 def _try_get_dispatched_fn(fn):
     if not callable(fn):
         return None
diff --git a/torch/_sources.py b/torch/_sources.py
new file mode 100644 (file)
index 0000000..2464949
--- /dev/null
@@ -0,0 +1,104 @@
+import ast
+import functools
+import inspect
+from textwrap import dedent
+from typing import Any, Optional, Tuple, List, NamedTuple
+from torch._C import ErrorReport
+from torch._C._jit_tree_views import SourceRangeFactory
+
+def get_source_lines_and_file(
+    obj: Any,
+    error_msg: Optional[str] = None,
+) -> Tuple[List[str], int, Optional[str]]:
+    """
+    Wrapper around inspect.getsourcelines and inspect.getsourcefile.
+
+    Returns: (sourcelines, file_lino, filename)
+    """
+    filename = None  # in case getsourcefile throws
+    try:
+        filename = inspect.getsourcefile(obj)
+        sourcelines, file_lineno = inspect.getsourcelines(obj)
+    except OSError as e:
+        msg = (f"Can't get source for {obj}. TorchScript requires source access in "
+               "order to carry out compilation, make sure original .py files are "
+               "available.")
+        if error_msg:
+            msg += '\n' + error_msg
+        raise OSError(msg) from e
+
+    return sourcelines, file_lineno, filename
+
+
+def normalize_source_lines(sourcelines: List[str]) -> List[str]:
+    """
+    This helper function accepts a list of source lines. It finds the
+    indentation level of the function definition (`def`), then it indents
+    all lines in the function body to a point at or greater than that
+    level. This allows for comments and continued string literals that
+    are at a lower indentation than the rest of the code.
+    Args:
+        sourcelines: function source code, separated into lines by
+                        the '\n' character
+    Returns:
+        A list of source lines that have been correctly aligned
+    """
+
+    def remove_prefix(text, prefix):
+        return text[text.startswith(prefix) and len(prefix):]
+
+    # Find the line and line number containing the function definition
+    for i, l in enumerate(sourcelines):
+        if l.lstrip().startswith("def"):
+            idx = i
+            break
+    fn_def = sourcelines[idx]
+
+    # Get a string representing the amount of leading whitespace
+    whitespace = fn_def.split("def")[0]
+
+    # Add this leading whitespace to all lines before and after the `def`
+    aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
+    aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]
+
+    # Put it together again
+    aligned_prefix.append(fn_def)
+    return aligned_prefix + aligned_suffix
+
+
+# Thin wrapper around SourceRangeFactory to store extra metadata
+# about the function-to-be-compiled.
+class SourceContext(SourceRangeFactory):
+    def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
+        super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
+        self.uses_true_division = uses_true_division
+        self.filename = filename
+
+
+@functools.lru_cache(maxsize=None)
+def make_source_context(*args):
+    return SourceContext(*args)
+
+
+def fake_range():
+    return SourceContext('', None, 0, 0).make_raw_range(0, 1)
+
+
+class ParsedDef(NamedTuple):
+    ast: ast.Module
+    ctx: SourceContext
+    source: str
+    filename: Optional[str]
+    file_lineno: int
+
+def parse_def(fn):
+    sourcelines, file_lineno, filename = get_source_lines_and_file(fn, ErrorReport.call_stack())
+    sourcelines = normalize_source_lines(sourcelines)
+    source = ''.join(sourcelines)
+    dedent_src = dedent(source)
+    py_ast = ast.parse(dedent_src)
+    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
+        raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
+    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
+    ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True)
+    return ParsedDef(py_ast, ctx, source, filename, file_lineno)
index 5e0a8de..1bbfe32 100644 (file)
@@ -1,9 +1,7 @@
-
 import os
-import inspect
 import sys
 import tempfile
-from typing import Any, List, Optional, Tuple
+
 
 # this arbitrary-looking assortment of functionality is provided here
 # to have a central place for overrideable behavior. The motivating
@@ -44,30 +42,6 @@ def resolve_library_path(path: str) -> str:
     return os.path.realpath(path)
 
 
-def get_source_lines_and_file(
-    obj: Any,
-    error_msg: Optional[str] = None,
-) -> Tuple[List[str], int, Optional[str]]:
-    """
-    Wrapper around inspect.getsourcelines and inspect.getsourcefile.
-
-    Returns: (sourcelines, file_lino, filename)
-    """
-    filename = None  # in case getsourcefile throws
-    try:
-        filename = inspect.getsourcefile(obj)
-        sourcelines, file_lineno = inspect.getsourcelines(obj)
-    except OSError as e:
-        msg = (f"Can't get source for {obj}. TorchScript requires source access in "
-               "order to carry out compilation, make sure original .py files are "
-               "available.")
-        if error_msg:
-            msg += '\n' + error_msg
-        raise OSError(msg) from e
-
-    return sourcelines, file_lineno, filename
-
-
 TEST_MASTER_ADDR = '127.0.0.1'
 TEST_MASTER_PORT = 29500
 # USE_GLOBAL_DEPS controls whether __init__.py tries to load
index b462026..b3f71d5 100644 (file)
@@ -6,7 +6,7 @@ from types import FunctionType
 from typing import cast, Union, Callable, Dict, Optional, Any
 from torch.fx._symbolic_trace import Tracer
 from torch.fx.graph import Graph
-from torch.jit.frontend import normalize_source_lines
+from torch._sources import normalize_source_lines
 import torch
 
 class AST_Rewriter(ast.NodeTransformer):
index 636ccd2..d85d43e 100644 (file)
@@ -8,6 +8,7 @@ import warnings
 from typing import Dict, List, Set, Type
 
 import torch._jit_internal as _jit_internal
+from torch._sources import fake_range
 from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def, get_class_properties
 from torch.jit._builtins import _find_builtin
 from torch.jit._check import AttributeTypeIsSupportedChecker
@@ -148,10 +149,10 @@ def infer_concrete_type_builder(nn_module, share_types=True):
         inferred = False
         try:
             if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]:
-                ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range())
+                ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], fake_range())
                 attr_type = torch._C.InferredType(ann_to_type)
             elif isinstance(item, torch.jit.Attribute):
-                ann_to_type = torch.jit.annotations.ann_to_type(item.type, _jit_internal.fake_range())
+                ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
                 attr_type = torch._C.InferredType(ann_to_type)
             else:
                 attr_type = torch._C._jit_try_infer_type(item)
@@ -620,6 +621,10 @@ def get_overload_annotations(mod, jit_ignored_properties):
             if method_overloads is None:
                 continue
 
+            if item.__func__ in method_overloads:
+                raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
+                    'method', item.__func__))
+
             names = [name + "__" + str(i) for i in range(len(method_overloads))]
             overloads[item] = list(zip(names, method_overloads))
 
@@ -639,7 +644,7 @@ def get_overload_name_mapping(overload_info):
     return overload_name_mappings
 
 def _check_no_signature(func):
-    signature = torch.jit.annotations.get_signature(func, None, _jit_internal.fake_range(), inspect.ismethod(func))
+    signature = torch.jit.annotations.get_signature(func, None, fake_range(), inspect.ismethod(func))
     if signature is None:
         qual_name = _jit_internal._qualified_name(func)
         raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))
index abd9750..0c3e5ef 100644 (file)
@@ -1337,6 +1337,10 @@ def _get_overloads(obj):
     if uncompiled_overloads is None:
         return existing_compiled_fns
 
+    if obj in uncompiled_overloads:
+        raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
+            'function', obj))
+
     compiled_fns = []
     for overload_fn in uncompiled_overloads:
         compiled_fns.append(
index ced9958..f2cf789 100644 (file)
@@ -16,7 +16,7 @@ from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \
 
 
 from textwrap import dedent
-from torch._utils_internal import get_source_lines_and_file
+from torch._sources import get_source_lines_and_file
 from typing import Type
 
 if torch.distributed.rpc.is_available():
index a78f8ab..b0228b1 100644 (file)
@@ -17,9 +17,9 @@ from torch._C._jit_tree_views import (
     SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
     DictComp,
 )
-from torch._utils_internal import get_source_lines_and_file
+from torch._sources import get_source_lines_and_file, parse_def, make_source_context
 from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
-from torch._jit_internal import make_source_context, should_drop, is_static_fn, FunctionModifiers  # noqa: F401
+from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers  # noqa: F401
 import torch.jit.annotations
 
 _IS_ASTUNPARSE_INSTALLED = False
@@ -215,42 +215,6 @@ def get_jit_class_def(cls, self_name):
     return build_class_def(ctx, class_ast, methods, properties, self_name, assigns)
 
 
-def normalize_source_lines(sourcelines: List[str]) -> List[str]:
-    """
-    This helper function accepts a list of source lines. It finds the
-    indentation level of the function definition (`def`), then it indents
-    all lines in the function body to a point at or greater than that
-    level. This allows for comments and continued string literals that
-    are at a lower indentation than the rest of the code.
-    Args:
-        sourcelines: function source code, separated into lines by
-                        the '\n' character
-    Returns:
-        A list of source lines that have been correctly aligned
-    """
-
-    def remove_prefix(text, prefix):
-        return text[text.startswith(prefix) and len(prefix):]
-
-    # Find the line and line number containing the function definition
-    for i, l in enumerate(sourcelines):
-        if l.lstrip().startswith("def"):
-            idx = i
-            break
-    fn_def = sourcelines[idx]
-
-    # Get a string representing the amount of leading whitespace
-    whitespace = fn_def.split("def")[0]
-
-    # Add this leading whitespace to all lines before and after the `def`
-    aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
-    aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]
-
-    # Put it together again
-    aligned_prefix.append(fn_def)
-    return aligned_prefix + aligned_suffix
-
-
 def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
     """
     Build a JIT AST (TreeView) from the given function.
@@ -266,17 +230,9 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
             but we want the result AST to have the name "forward".
         self_name: If this function is a method, what the type name of `self` is.
     """
-    sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
-    sourcelines = normalize_source_lines(sourcelines)
-    source = ''.join(sourcelines)
-    dedent_src = dedent(source)
-    py_ast = ast.parse(dedent_src)
-    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
-        raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
-    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
-    type_line = torch.jit.annotations.get_type_line(source)
-    ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True)
-    fn_def = py_ast.body[0]
+    parsed_def = parse_def(fn)
+    type_line = torch.jit.annotations.get_type_line(parsed_def.source)
+    fn_def = parsed_def.ast.body[0]
 
     if is_classmethod:
         arg_name = fn_def.args.args[0].arg
@@ -288,7 +244,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
     if should_drop(fn):
         unused_fn_def = ast.parse("def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")")
         if len(unused_fn_def.body) != 1 or not isinstance(unused_fn_def.body[0], ast.FunctionDef):
-            raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
+            raise RuntimeError(f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}")
         unused_def = unused_fn_def.body[0]
         fn_def.body = unused_def.body
         # kwarg/vararg not supported by `build_def`
@@ -305,7 +261,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
         qualname = get_qualified_name(fn)
         pdt_arg_types = type_trace_db.get_args_types(qualname)
 
-    return build_def(ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
+    return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
 
 # TODO: more robust handling of recognizing ignore context manager
 def is_torch_jit_ignore_context_manager(stmt):
index d84ae9a..4443561 100644 (file)
@@ -11,7 +11,7 @@ import warnings
 from contextlib import closing, contextmanager
 from ._utils import _import_dotted_name
 from ._six import string_classes as _string_classes
-from torch._utils_internal import get_source_lines_and_file
+from torch._sources import get_source_lines_and_file
 from torch.types import Storage
 from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
 import copyreg