with self.assertRaisesRegex(Exception, "Parameters not specified"):
torch.jit.script(test)
+ def test_function_overload_misuse(self):
+ with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
+ @torch.jit._overload
+ def wrong_decl_body(x: str) -> str:
+ return x + "0"
+
+ with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
+ class MyClass:
+ @torch.jit._overload_method
+ def method(self):
+ return 0
+
+ @torch.jit._overload
+ def null_overload(x: int) -> int: ... # noqa: E704
+
+ @torch.jit._overload
+ def null_overload(x: str) -> str: # noqa: F811
+ pass
+
+ def null_overload_driver():
+ return null_overload(0)
+
+ with self.assertRaisesRegex(RuntimeError, 'Implementation for the function ".+" is missing.'):
+ torch.jit.script(null_overload_driver)
+
+ class OverloadMisuse(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ @torch.jit._overload_method
+ def forward(self, x: int):
+ pass
+
+ @torch.jit._overload_method
+ def forward(self, x: Tensor): # noqa: F811
+ pass
+
+ with self.assertRaisesRegex(RuntimeError, 'Implementation for the method ".+" is missing.'):
+ m = torch.jit.script(OverloadMisuse())
+
+
def test_script_method_torch_function_overload(self):
class MyCustomTensor(torch.Tensor):
pass
import typing
import io
import pickle
-import functools
# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
# Explicitly ask to import `torch.distributed.__init__` first.
# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
import torch.distributed.rpc
-from torch._utils_internal import get_source_lines_and_file
from torch._C import Future as CFuture
+from torch._sources import get_source_lines_and_file, parse_def, fake_range
from torch.futures import Future
import torch.package._mangling as package_mangling
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union # noqa: F401
# qualified_name => list[overload_functions]
_overloaded_fns : Dict[str, List[Callable]] = {} # noqa: T484
+
+_OVERLOAD_EXAMPLE = '''
+Example usage of overload function:
+@torch.jit._overload
+def my_function(x: type0) -> type0: # decl 1
+ pass
+
+@torch.jit._overload
+def my_function(x: type1) -> type1: # decl 2
+ pass
+
+def my_function(x): # implementation
+ if isinstance(x, type0):
+ return x
+ elif isinstance(x, type1):
+ return x
+'''
+
+def get_overload_no_implementation_error_message(kind, obj):
+ sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
+ return (
+ f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
+ f'sure a definition is provided and defined after all overload declarations.\n'
+ f'File "{filename}", line {file_lineno}:\n' + ''.join(sourcelines) + "\n" + _OVERLOAD_EXAMPLE
+ )
+
+def _check_overload_body(func):
+ parsed_def = parse_def(func)
+ body = parsed_def.ast.body[0].body
+
+ def is_pass(x):
+ return isinstance(x, ast.Pass)
+
+ def is_ellipsis(x):
+ return isinstance(x, ast.Expr) and isinstance(x.value, ast.Ellipsis)
+
+ if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
+ msg = "Only `pass` statement or `...` can be the body of overload declaration:\n"
+ msg += '\n'.join(parsed_def.source.split("\n")[:3])
+ msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
+ raise RuntimeError(msg)
+
def _overload(func):
+ _check_overload_body(func)
qual_name = _qualified_name(func)
global _overloaded_fns
fn_overload_list = _overloaded_fns.get(qual_name)
_overloaded_method_class_fileno = {}
def _overload_method(func):
+ _check_overload_body(func)
qual_name = _qualified_name(func)
global _overloaded_methods
class_name_map = _overloaded_methods.get(qual_name, None)
return module_name + "." + name
-# Thin wrapper around SourceRangeFactory to store extra metadata
-# about the function-to-be-compiled.
-class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
- def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
- super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
- self.uses_true_division = uses_true_division
- self.filename = filename
-
-@functools.lru_cache(maxsize=None)
-def make_source_context(*args):
- return SourceContext(*args)
-
-def fake_range():
- return SourceContext('', None, 0, 0).make_raw_range(0, 1)
-
-
def _try_get_dispatched_fn(fn):
if not callable(fn):
return None
--- /dev/null
+import ast
+import functools
+import inspect
+from textwrap import dedent
+from typing import Any, Optional, Tuple, List, NamedTuple
+from torch._C import ErrorReport
+from torch._C._jit_tree_views import SourceRangeFactory
+
+def get_source_lines_and_file(
+ obj: Any,
+ error_msg: Optional[str] = None,
+) -> Tuple[List[str], int, Optional[str]]:
+ """
+ Wrapper around inspect.getsourcelines and inspect.getsourcefile.
+
+ Returns: (sourcelines, file_lino, filename)
+ """
+ filename = None # in case getsourcefile throws
+ try:
+ filename = inspect.getsourcefile(obj)
+ sourcelines, file_lineno = inspect.getsourcelines(obj)
+ except OSError as e:
+ msg = (f"Can't get source for {obj}. TorchScript requires source access in "
+ "order to carry out compilation, make sure original .py files are "
+ "available.")
+ if error_msg:
+ msg += '\n' + error_msg
+ raise OSError(msg) from e
+
+ return sourcelines, file_lineno, filename
+
+
+def normalize_source_lines(sourcelines: List[str]) -> List[str]:
+ """
+ This helper function accepts a list of source lines. It finds the
+ indentation level of the function definition (`def`), then it indents
+ all lines in the function body to a point at or greater than that
+ level. This allows for comments and continued string literals that
+ are at a lower indentation than the rest of the code.
+ Args:
+ sourcelines: function source code, separated into lines by
+ the '\n' character
+ Returns:
+ A list of source lines that have been correctly aligned
+ """
+
+ def remove_prefix(text, prefix):
+ return text[text.startswith(prefix) and len(prefix):]
+
+ # Find the line and line number containing the function definition
+ for i, l in enumerate(sourcelines):
+ if l.lstrip().startswith("def"):
+ idx = i
+ break
+ fn_def = sourcelines[idx]
+
+ # Get a string representing the amount of leading whitespace
+ whitespace = fn_def.split("def")[0]
+
+ # Add this leading whitespace to all lines before and after the `def`
+ aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
+ aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]
+
+ # Put it together again
+ aligned_prefix.append(fn_def)
+ return aligned_prefix + aligned_suffix
+
+
+# Thin wrapper around SourceRangeFactory to store extra metadata
+# about the function-to-be-compiled.
+class SourceContext(SourceRangeFactory):
+ def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
+ super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
+ self.uses_true_division = uses_true_division
+ self.filename = filename
+
+
+@functools.lru_cache(maxsize=None)
+def make_source_context(*args):
+ return SourceContext(*args)
+
+
+def fake_range():
+ return SourceContext('', None, 0, 0).make_raw_range(0, 1)
+
+
+class ParsedDef(NamedTuple):
+ ast: ast.Module
+ ctx: SourceContext
+ source: str
+ filename: Optional[str]
+ file_lineno: int
+
+def parse_def(fn):
+ sourcelines, file_lineno, filename = get_source_lines_and_file(fn, ErrorReport.call_stack())
+ sourcelines = normalize_source_lines(sourcelines)
+ source = ''.join(sourcelines)
+ dedent_src = dedent(source)
+ py_ast = ast.parse(dedent_src)
+ if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
+ raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
+ leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
+ ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True)
+ return ParsedDef(py_ast, ctx, source, filename, file_lineno)
-
import os
-import inspect
import sys
import tempfile
-from typing import Any, List, Optional, Tuple
+
# this arbitrary-looking assortment of functionality is provided here
# to have a central place for overrideable behavior. The motivating
return os.path.realpath(path)
-def get_source_lines_and_file(
- obj: Any,
- error_msg: Optional[str] = None,
-) -> Tuple[List[str], int, Optional[str]]:
- """
- Wrapper around inspect.getsourcelines and inspect.getsourcefile.
-
- Returns: (sourcelines, file_lino, filename)
- """
- filename = None # in case getsourcefile throws
- try:
- filename = inspect.getsourcefile(obj)
- sourcelines, file_lineno = inspect.getsourcelines(obj)
- except OSError as e:
- msg = (f"Can't get source for {obj}. TorchScript requires source access in "
- "order to carry out compilation, make sure original .py files are "
- "available.")
- if error_msg:
- msg += '\n' + error_msg
- raise OSError(msg) from e
-
- return sourcelines, file_lineno, filename
-
-
TEST_MASTER_ADDR = '127.0.0.1'
TEST_MASTER_PORT = 29500
# USE_GLOBAL_DEPS controls whether __init__.py tries to load
from typing import cast, Union, Callable, Dict, Optional, Any
from torch.fx._symbolic_trace import Tracer
from torch.fx.graph import Graph
-from torch.jit.frontend import normalize_source_lines
+from torch._sources import normalize_source_lines
import torch
class AST_Rewriter(ast.NodeTransformer):
from typing import Dict, List, Set, Type
import torch._jit_internal as _jit_internal
+from torch._sources import fake_range
from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def, get_class_properties
from torch.jit._builtins import _find_builtin
from torch.jit._check import AttributeTypeIsSupportedChecker
inferred = False
try:
if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]:
- ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range())
+ ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], fake_range())
attr_type = torch._C.InferredType(ann_to_type)
elif isinstance(item, torch.jit.Attribute):
- ann_to_type = torch.jit.annotations.ann_to_type(item.type, _jit_internal.fake_range())
+ ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
attr_type = torch._C.InferredType(ann_to_type)
else:
attr_type = torch._C._jit_try_infer_type(item)
if method_overloads is None:
continue
+ if item.__func__ in method_overloads:
+ raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
+ 'method', item.__func__))
+
names = [name + "__" + str(i) for i in range(len(method_overloads))]
overloads[item] = list(zip(names, method_overloads))
return overload_name_mappings
def _check_no_signature(func):
- signature = torch.jit.annotations.get_signature(func, None, _jit_internal.fake_range(), inspect.ismethod(func))
+ signature = torch.jit.annotations.get_signature(func, None, fake_range(), inspect.ismethod(func))
if signature is None:
qual_name = _jit_internal._qualified_name(func)
raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))
if uncompiled_overloads is None:
return existing_compiled_fns
+ if obj in uncompiled_overloads:
+ raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
+ 'function', obj))
+
compiled_fns = []
for overload_fn in uncompiled_overloads:
compiled_fns.append(
from textwrap import dedent
-from torch._utils_internal import get_source_lines_and_file
+from torch._sources import get_source_lines_and_file
from typing import Type
if torch.distributed.rpc.is_available():
SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
DictComp,
)
-from torch._utils_internal import get_source_lines_and_file
+from torch._sources import get_source_lines_and_file, parse_def, make_source_context
from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
-from torch._jit_internal import make_source_context, should_drop, is_static_fn, FunctionModifiers # noqa: F401
+from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers # noqa: F401
import torch.jit.annotations
_IS_ASTUNPARSE_INSTALLED = False
return build_class_def(ctx, class_ast, methods, properties, self_name, assigns)
-def normalize_source_lines(sourcelines: List[str]) -> List[str]:
- """
- This helper function accepts a list of source lines. It finds the
- indentation level of the function definition (`def`), then it indents
- all lines in the function body to a point at or greater than that
- level. This allows for comments and continued string literals that
- are at a lower indentation than the rest of the code.
- Args:
- sourcelines: function source code, separated into lines by
- the '\n' character
- Returns:
- A list of source lines that have been correctly aligned
- """
-
- def remove_prefix(text, prefix):
- return text[text.startswith(prefix) and len(prefix):]
-
- # Find the line and line number containing the function definition
- for i, l in enumerate(sourcelines):
- if l.lstrip().startswith("def"):
- idx = i
- break
- fn_def = sourcelines[idx]
-
- # Get a string representing the amount of leading whitespace
- whitespace = fn_def.split("def")[0]
-
- # Add this leading whitespace to all lines before and after the `def`
- aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
- aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]
-
- # Put it together again
- aligned_prefix.append(fn_def)
- return aligned_prefix + aligned_suffix
-
-
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
"""
Build a JIT AST (TreeView) from the given function.
but we want the result AST to have the name "forward".
self_name: If this function is a method, what the type name of `self` is.
"""
- sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
- sourcelines = normalize_source_lines(sourcelines)
- source = ''.join(sourcelines)
- dedent_src = dedent(source)
- py_ast = ast.parse(dedent_src)
- if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
- raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
- leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
- type_line = torch.jit.annotations.get_type_line(source)
- ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True)
- fn_def = py_ast.body[0]
+ parsed_def = parse_def(fn)
+ type_line = torch.jit.annotations.get_type_line(parsed_def.source)
+ fn_def = parsed_def.ast.body[0]
if is_classmethod:
arg_name = fn_def.args.args[0].arg
if should_drop(fn):
unused_fn_def = ast.parse("def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")")
if len(unused_fn_def.body) != 1 or not isinstance(unused_fn_def.body[0], ast.FunctionDef):
- raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
+ raise RuntimeError(f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}")
unused_def = unused_fn_def.body[0]
fn_def.body = unused_def.body
# kwarg/vararg not supported by `build_def`
qualname = get_qualified_name(fn)
pdt_arg_types = type_trace_db.get_args_types(qualname)
- return build_def(ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
+ return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
# TODO: more robust handling of recognizing ignore context manager
def is_torch_jit_ignore_context_manager(stmt):
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
-from torch._utils_internal import get_source_lines_and_file
+from torch._sources import get_source_lines_and_file
from torch.types import Storage
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
import copyreg