Allow calling a Python function with a dict
authorDavid Riazati <davidriazati@fb.com>
Tue, 12 Feb 2019 05:48:58 +0000 (21:48 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 12 Feb 2019 05:52:44 +0000 (21:52 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16989

Differential Revision: D14037896

Pulled By: driazati

fbshipit-source-id: 5f26d2d8fabf0f267909a3383f19d984645f94d0

test/test_jit.py
torch/_jit_internal.py
torch/csrc/jit/python_ir.cpp
torch/jit/annotations.py

index 051b4ea..f2ab4bf 100644 (file)
@@ -9933,6 +9933,18 @@ a")
 
         self.checkScript(list_of_dicts, ())
 
+    def dict_to_python(self):
+        def python_lookup(my_dict, keys):
+            # type: (Dict[str, int], List[str]) -> List[int]
+            return [my_dict[k] for k in keys]
+
+        def fn(my_dict, keys):
+            # type: (Dict[str, int], List[str]) -> List[int]
+            return python_lookup(my_dict, keys)
+
+        a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2}
+        self.checkScript(fn, (a_dict, ('a', 'c')))
+
 
 class MnistNet(nn.Module):
     def __init__(self):
index 07fcb83..ccda08d 100644 (file)
@@ -162,7 +162,7 @@ def ignore(fn):
 
 try:
     import typing
-    from typing import Tuple, List
+    from typing import Tuple, List, Dict
 
     def is_tuple(ann):
         # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
@@ -174,6 +174,11 @@ try:
         return ann.__module__ == 'typing' and \
             (getattr(ann, '__origin__', None) is typing.List or
              getattr(ann, '__origin__', None) is list)
+
+    def is_dict(ann):
+        return ann.__module__ == 'typing' and \
+            (getattr(ann, '__origin__', None) is typing.Dict or
+             getattr(ann, '__origin__', None) is dict)
 except ImportError:
     # A minimal polyfill for versions of Python that don't have typing.
     # Note that this means that they also don't support the fancy annotation syntax, so
@@ -196,8 +201,17 @@ except ImportError:
         def __getitem__(self, types):
             return TupleInstance(types)
 
+    class DictInstance(object):
+        def __init__(self, types):
+            setattr(self, '__args__', types)
+
+    class DictCls(object):
+        def __getitem__(self, types):
+            return DictInstance(types)
+
     Tuple = TupleCls()
     List = ListCls()
+    Dict = DictCls()
 
     def is_tuple(ann):
         return isinstance(ann, TupleInstance)
@@ -205,6 +219,9 @@ except ImportError:
     def is_list(ann):
         return isinstance(ann, ListInstance)
 
+    def is_dict(ann):
+        return isinstance(ann, DictInstance)
+
 
 # allows BroadcastingList instance to be subscriptable
 class BroadcastingListCls(object):
index b193af5..c2c350d 100644 (file)
@@ -622,6 +622,8 @@ void initPythonIRBindings(PyObject* module_) {
       .def_static("get", &TensorType::get);
   py::class_<BoolType, Type, std::shared_ptr<BoolType>>(m, "BoolType")
       .def_static("get", &BoolType::get);
+  py::class_<StringType, Type, std::shared_ptr<StringType>>(m, "StringType")
+      .def_static("get", &StringType::get);
 
   py::class_<TupleType, Type, std::shared_ptr<TupleType>>(m, "TupleType")
       .def(
@@ -638,6 +640,10 @@ void initPythonIRBindings(PyObject* module_) {
       .def_static("ofInts", &ListType::ofInts)
       .def_static("ofTensors", &ListType::ofTensors)
       .def("getElementType", &ListType::getElementType);
+  py::class_<DictType, Type, std::shared_ptr<DictType>>(m, "DictType")
+      .def(py::init([](TypePtr key, TypePtr value) {
+        return DictType::create(key, value);
+      }));
 
   py::class_<Use>(m, "Use")
       .def_readonly("user", &Use::user)
index aa7880e..22e5c3f 100644 (file)
@@ -3,8 +3,10 @@ import sys
 import ast
 import inspect
 import torch
-from .._jit_internal import List, BroadcastingList1, BroadcastingList2, BroadcastingList3, Tuple, is_tuple, is_list
-from torch._C import DynamicType, TupleType, FloatType, IntType, ListType
+from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
+    BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict
+from torch._C import DynamicType, TupleType, FloatType, IntType, \
+    ListType, StringType, DictType
 from textwrap import dedent
 
 
@@ -29,6 +31,7 @@ _eval_env = {
     'typing': Module('typing', {'Tuple': Tuple}),
     'Tuple': Tuple,
     'List': List,
+    'Dict': Dict,
 }
 
 
@@ -167,8 +170,14 @@ def ann_to_type(ann):
         return TupleType([ann_to_type(a) for a in ann.__args__])
     elif is_list(ann):
         return ListType(ann_to_type(ann.__args__[0]))
+    elif is_dict(ann):
+        key = ann_to_type(ann.__args__[0])
+        value = ann_to_type(ann.__args__[1])
+        return DictType(key, value)
     elif ann is float:
         return FloatType.get()
     elif ann is int:
         return IntType.get()
-    raise ValueError("The only supported annotations kinds are Tensor and Tuple[...]")
+    elif ann is str:
+        return StringType.get()
+    raise ValueError("Unknown type annotation: '{}'".format(ann.__name__))