Allow List arguments to Python Ops (#15721)
authorDavid Riazati <davidriazati@fb.com>
Mon, 7 Jan 2019 21:49:20 +0000 (13:49 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 7 Jan 2019 21:51:53 +0000 (13:51 -0800)
Summary:
Adds `List` to eval environment for type lines and allows `List` to be used on PythonOps (follows the same style as the `Tuple` code), fixes #15661
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15721

Differential Revision: D13578540

Pulled By: driazati

fbshipit-source-id: fce54dc3c0931d8b017b2e3483f0ac53826dda94

test/test_jit.py
torch/_jit_internal.py
torch/csrc/jit/python_ir.cpp
torch/jit/annotations.py

index 6d5b640..ac48e67 100644 (file)
@@ -9137,6 +9137,17 @@ a")
 
         self.checkScript(foo, [torch.rand(2, 3)])
 
+    def test_list_python_op(self):
+        def python_list_op(lst):
+            # type: (List[Tensor]) -> Tensor
+            return lst[0]
+
+        def fn(lst):
+            # type: (List[Tensor]) -> Tensor
+            return python_list_op(lst)
+
+        self.checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],))
+
 
 class MnistNet(nn.Module):
     def __init__(self):
index 154336c..45800c8 100644 (file)
@@ -160,6 +160,11 @@ try:
         return ann.__module__ == 'typing' and \
             (getattr(ann, '__origin__', None) is typing.Tuple or
              getattr(ann, '__origin__', None) is tuple)
+
+    def is_list(ann):
+        return ann.__module__ == 'typing' and \
+            (getattr(ann, '__origin__', None) is typing.List or
+             getattr(ann, '__origin__', None) is list)
 except ImportError:
     # A minimal polyfill for versions of Python that don't have typing.
     # Note that this means that they also don't support the fancy annotation syntax, so
@@ -188,6 +193,9 @@ except ImportError:
     def is_tuple(ann):
         return isinstance(ann, TupleInstance)
 
+    def is_list(ann):
+        return isinstance(ann, ListInstance)
+
 
 # allows BroadcastingList instance to be subscriptable
 class BroadcastingListCls(object):
index f829452..6d3ab9a 100644 (file)
@@ -541,6 +541,8 @@ void initPythonIRBindings(PyObject* module_) {
         return types;
       });
   py::class_<ListType, Type, std::shared_ptr<ListType>>(m, "ListType")
+      .def(
+          py::init([](TypePtr a) { return ListType::create(a); }))
       .def_static("ofInts", &ListType::ofInts)
       .def_static("ofTensors", &ListType::ofTensors)
       .def("getElementType", &ListType::getElementType);
index 39de241..01ee46f 100644 (file)
@@ -3,8 +3,8 @@ import sys
 import ast
 import inspect
 import torch
-from .._jit_internal import List, Tuple, BroadcastingList1, BroadcastingList2, BroadcastingList3, is_tuple
-from torch._C import DynamicType, TupleType, FloatType, IntType
+from .._jit_internal import List, BroadcastingList1, BroadcastingList2, BroadcastingList3, Tuple, is_tuple, is_list
+from torch._C import DynamicType, TupleType, FloatType, IntType, ListType
 from textwrap import dedent
 
 
@@ -28,6 +28,7 @@ _eval_env = {
     'Tensor': torch.Tensor,
     'typing': Module('typing', {'Tuple': Tuple}),
     'Tuple': Tuple,
+    'List': List,
 }
 
 
@@ -88,16 +89,16 @@ def parse_type_line(type_line):
 
     try:
         arg_ann = eval(arg_ann_str, _eval_env)
-    except SyntaxError:
-        raise RuntimeError("Failed to parse the argument list of a type annotation")
+    except (NameError, SyntaxError) as e:
+        raise RuntimeError("Failed to parse the argument list of a type annotation: {}".format(str(e)))
 
     if not isinstance(arg_ann, tuple):
         arg_ann = (arg_ann,)
 
     try:
         ret_ann = eval(ret_ann_str, _eval_env)
-    except SyntaxError:
-        raise RuntimeError("Failed to parse the return type of a type annotation")
+    except (NameError, SyntaxError) as e:
+        raise RuntimeError("Failed to parse the return type of a type annotation: {}".format(str(e)))
 
     arg_types = [ann_to_type(ann) for ann in arg_ann]
     return arg_types, ann_to_type(ret_ann)
@@ -162,6 +163,8 @@ def ann_to_type(ann):
         return DynamicType.get()
     elif is_tuple(ann):
         return TupleType([ann_to_type(a) for a in ann.__args__])
+    elif is_list(ann):
+        return ListType(ann_to_type(ann.__args__[0]))
     elif ann is float:
         return FloatType.get()
     elif ann is int: