self.checkScript(list_of_dicts, ())
+ def dict_to_python(self):
+ def python_lookup(my_dict, keys):
+ # type: (Dict[str, int], List[str]) -> List[int]
+ return [my_dict[k] for k in keys]
+
+ def fn(my_dict, keys):
+ # type: (Dict[str, int], List[str]) -> List[int]
+ return python_lookup(my_dict, keys)
+
+ a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2}
+ self.checkScript(fn, (a_dict, ('a', 'c')))
+
class MnistNet(nn.Module):
def __init__(self):
try:
import typing
- from typing import Tuple, List
+ from typing import Tuple, List, Dict
def is_tuple(ann):
# For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
return ann.__module__ == 'typing' and \
(getattr(ann, '__origin__', None) is typing.List or
getattr(ann, '__origin__', None) is list)
+
+ def is_dict(ann):
+ return ann.__module__ == 'typing' and \
+ (getattr(ann, '__origin__', None) is typing.Dict or
+ getattr(ann, '__origin__', None) is dict)
except ImportError:
# A minimal polyfill for versions of Python that don't have typing.
# Note that this means that they also don't support the fancy annotation syntax, so
def __getitem__(self, types):
return TupleInstance(types)
+ class DictInstance(object):
+ def __init__(self, types):
+ setattr(self, '__args__', types)
+
+ class DictCls(object):
+ def __getitem__(self, types):
+ return DictInstance(types)
+
Tuple = TupleCls()
List = ListCls()
+ Dict = DictCls()
def is_tuple(ann):
return isinstance(ann, TupleInstance)
def is_list(ann):
return isinstance(ann, ListInstance)
+ def is_dict(ann):
+ return isinstance(ann, DictInstance)
+
# allows BroadcastingList instance to be subscriptable
class BroadcastingListCls(object):
.def_static("get", &TensorType::get);
py::class_<BoolType, Type, std::shared_ptr<BoolType>>(m, "BoolType")
.def_static("get", &BoolType::get);
+ py::class_<StringType, Type, std::shared_ptr<StringType>>(m, "StringType")
+ .def_static("get", &StringType::get);
py::class_<TupleType, Type, std::shared_ptr<TupleType>>(m, "TupleType")
.def(
.def_static("ofInts", &ListType::ofInts)
.def_static("ofTensors", &ListType::ofTensors)
.def("getElementType", &ListType::getElementType);
+ py::class_<DictType, Type, std::shared_ptr<DictType>>(m, "DictType")
+ .def(py::init([](TypePtr key, TypePtr value) {
+ return DictType::create(key, value);
+ }));
py::class_<Use>(m, "Use")
.def_readonly("user", &Use::user)
import ast
import inspect
import torch
-from .._jit_internal import List, BroadcastingList1, BroadcastingList2, BroadcastingList3, Tuple, is_tuple, is_list
-from torch._C import DynamicType, TupleType, FloatType, IntType, ListType
+from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
+ BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict
+from torch._C import DynamicType, TupleType, FloatType, IntType, \
+ ListType, StringType, DictType
from textwrap import dedent
'typing': Module('typing', {'Tuple': Tuple}),
'Tuple': Tuple,
'List': List,
+ 'Dict': Dict,
}
return TupleType([ann_to_type(a) for a in ann.__args__])
elif is_list(ann):
return ListType(ann_to_type(ann.__args__[0]))
+ elif is_dict(ann):
+ key = ann_to_type(ann.__args__[0])
+ value = ann_to_type(ann.__args__[1])
+ return DictType(key, value)
elif ann is float:
return FloatType.get()
elif ann is int:
return IntType.get()
- raise ValueError("The only supported annotations kinds are Tensor and Tuple[...]")
+ elif ann is str:
+ return StringType.get()
+ raise ValueError("Unknown type annotation: '{}'".format(ann.__name__))