From: David Riazati Date: Tue, 12 Feb 2019 05:48:58 +0000 (-0800) Subject: Allow calling a Python function with a dict X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1349 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d26645354138892581135bf5a2ec4c91e95c6842;p=platform%2Fupstream%2Fpytorch.git Allow calling a Python function with a dict Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16989 Differential Revision: D14037896 Pulled By: driazati fbshipit-source-id: 5f26d2d8fabf0f267909a3383f19d984645f94d0 --- diff --git a/test/test_jit.py b/test/test_jit.py index 051b4ea..f2ab4bf 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -9933,6 +9933,18 @@ a") self.checkScript(list_of_dicts, ()) + def dict_to_python(self): + def python_lookup(my_dict, keys): + # type: (Dict[str, int], List[str]) -> List[int] + return [my_dict[k] for k in keys] + + def fn(my_dict, keys): + # type: (Dict[str, int], List[str]) -> List[int] + return python_lookup(my_dict, keys) + + a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2} + self.checkScript(fn, (a_dict, ('a', 'c'))) + class MnistNet(nn.Module): def __init__(self): diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 07fcb83..ccda08d 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -162,7 +162,7 @@ def ignore(fn): try: import typing - from typing import Tuple, List + from typing import Tuple, List, Dict def is_tuple(ann): # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule @@ -174,6 +174,11 @@ try: return ann.__module__ == 'typing' and \ (getattr(ann, '__origin__', None) is typing.List or getattr(ann, '__origin__', None) is list) + + def is_dict(ann): + return ann.__module__ == 'typing' and \ + (getattr(ann, '__origin__', None) is typing.Dict or + getattr(ann, '__origin__', None) is dict) except ImportError: # A minimal polyfill for versions of Python that don't have typing. # Note that this means that they also don't support the fancy annotation syntax, so @@ -196,8 +201,17 @@ except ImportError: def __getitem__(self, types): return TupleInstance(types) + class DictInstance(object): + def __init__(self, types): + setattr(self, '__args__', types) + + class DictCls(object): + def __getitem__(self, types): + return DictInstance(types) + Tuple = TupleCls() List = ListCls() + Dict = DictCls() def is_tuple(ann): return isinstance(ann, TupleInstance) @@ -205,6 +219,9 @@ except ImportError: def is_list(ann): return isinstance(ann, ListInstance) + def is_dict(ann): + return isinstance(ann, DictInstance) + # allows BroadcastingList instance to be subscriptable class BroadcastingListCls(object): diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index b193af5..c2c350d 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -622,6 +622,8 @@ void initPythonIRBindings(PyObject* module_) { .def_static("get", &TensorType::get); py::class_>(m, "BoolType") .def_static("get", &BoolType::get); + py::class_>(m, "StringType") + .def_static("get", &StringType::get); py::class_>(m, "TupleType") .def( @@ -638,6 +640,10 @@ void initPythonIRBindings(PyObject* module_) { .def_static("ofInts", &ListType::ofInts) .def_static("ofTensors", &ListType::ofTensors) .def("getElementType", &ListType::getElementType); + py::class_>(m, "DictType") + .def(py::init([](TypePtr key, TypePtr value) { + return DictType::create(key, value); + })); py::class_(m, "Use") .def_readonly("user", &Use::user) diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index aa7880e..22e5c3f 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -3,8 +3,10 @@ import sys import ast import inspect import torch -from .._jit_internal import List, BroadcastingList1, BroadcastingList2, BroadcastingList3, Tuple, is_tuple, is_list -from torch._C import DynamicType, TupleType, FloatType, IntType, ListType +from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \ + BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict +from torch._C import DynamicType, TupleType, FloatType, IntType, \ + ListType, StringType, DictType from textwrap import dedent @@ -29,6 +31,7 @@ _eval_env = { 'typing': Module('typing', {'Tuple': Tuple}), 'Tuple': Tuple, 'List': List, + 'Dict': Dict, } @@ -167,8 +170,14 @@ def ann_to_type(ann): return TupleType([ann_to_type(a) for a in ann.__args__]) elif is_list(ann): return ListType(ann_to_type(ann.__args__[0])) + elif is_dict(ann): + key = ann_to_type(ann.__args__[0]) + value = ann_to_type(ann.__args__[1]) + return DictType(key, value) elif ann is float: return FloatType.get() elif ann is int: return IntType.get() - raise ValueError("The only supported annotations kinds are Tensor and Tuple[...]") + elif ann is str: + return StringType.get() + raise ValueError("Unknown type annotation: '{}'".format(ann.__name__))