self.checkScript(foo, [torch.rand(2, 3)])
+ def test_list_python_op(self):
+ def python_list_op(lst):
+ # type: (List[Tensor]) -> Tensor
+ return lst[0]
+
+ def fn(lst):
+ # type: (List[Tensor]) -> Tensor
+ return python_list_op(lst)
+
+ self.checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],))
+
class MnistNet(nn.Module):
def __init__(self):
return ann.__module__ == 'typing' and \
(getattr(ann, '__origin__', None) is typing.Tuple or
getattr(ann, '__origin__', None) is tuple)
+
+ def is_list(ann):
+ return ann.__module__ == 'typing' and \
+ (getattr(ann, '__origin__', None) is typing.List or
+ getattr(ann, '__origin__', None) is list)
except ImportError:
# A minimal polyfill for versions of Python that don't have typing.
# Note that this means that they also don't support the fancy annotation syntax, so
def is_tuple(ann):
return isinstance(ann, TupleInstance)
+ def is_list(ann):
+ return isinstance(ann, ListInstance)
+
# allows BroadcastingList instance to be subscriptable
class BroadcastingListCls(object):
return types;
});
py::class_<ListType, Type, std::shared_ptr<ListType>>(m, "ListType")
+ .def(
+ py::init([](TypePtr a) { return ListType::create(a); }))
.def_static("ofInts", &ListType::ofInts)
.def_static("ofTensors", &ListType::ofTensors)
.def("getElementType", &ListType::getElementType);
import ast
import inspect
import torch
-from .._jit_internal import List, Tuple, BroadcastingList1, BroadcastingList2, BroadcastingList3, is_tuple
-from torch._C import DynamicType, TupleType, FloatType, IntType
+from .._jit_internal import List, BroadcastingList1, BroadcastingList2, BroadcastingList3, Tuple, is_tuple, is_list
+from torch._C import DynamicType, TupleType, FloatType, IntType, ListType
from textwrap import dedent
'Tensor': torch.Tensor,
'typing': Module('typing', {'Tuple': Tuple}),
'Tuple': Tuple,
+ 'List': List,
}
try:
arg_ann = eval(arg_ann_str, _eval_env)
- except SyntaxError:
- raise RuntimeError("Failed to parse the argument list of a type annotation")
+ except (NameError, SyntaxError) as e:
+ raise RuntimeError("Failed to parse the argument list of a type annotation: {}".format(str(e)))
if not isinstance(arg_ann, tuple):
arg_ann = (arg_ann,)
try:
ret_ann = eval(ret_ann_str, _eval_env)
- except SyntaxError:
- raise RuntimeError("Failed to parse the return type of a type annotation")
+ except (NameError, SyntaxError) as e:
+ raise RuntimeError("Failed to parse the return type of a type annotation: {}".format(str(e)))
arg_types = [ann_to_type(ann) for ann in arg_ann]
return arg_types, ann_to_type(ret_ann)
return DynamicType.get()
elif is_tuple(ann):
return TupleType([ann_to_type(a) for a in ann.__args__])
+ elif is_list(ann):
+ return ListType(ann_to_type(ann.__args__[0]))
elif ann is float:
return FloatType.get()
elif ann is int: