exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
from copy import deepcopy
import random
-from typing import List, Optional
+from typing import List, Dict, Optional
from torch.jit.frontend import NotSupportedError
from torch.jit import BatchTensor
self.checkScript(list_of_dicts, ())
+ def test_dict_mutability(self):
+ @torch.jit.script
+ def fn():
+ # type: () -> Dict[str, int]
+ a = torch.jit.annotate(Dict[str, int], {})
+ a['ok'] = 10
+ return a
+
+ self.assertEqual(fn(), {'ok': 10})
+
def dict_to_python(self):
def python_lookup(my_dict, keys):
# type: (Dict[str, int], List[str]) -> List[int]
auto key_type = parseType().first;
L.expect(',');
auto value_type = parseType().first;
- alias_info = parseAliasAnnotation();
L.expect(')');
+ alias_info = parseAliasAnnotation();
+
value = DictType::create(key_type, value_type);
} else {
auto value_alias = parseBaseType();
} break;
case prim::DictConstruct: {
auto dict_type = node->output()->type()->expect<DictType>();
- if (node->inputs().size() == 0 &&
- !dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
- !dict_type->getValueType()->isSubtypeOf(TensorType::get())) {
+ bool is_default_type =
+ dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
+ dict_type->getKeyType()->isSubtypeOf(TensorType::get());
+ if (node->inputs().size() == 0 && !is_default_type) {
stmt << "annotate(" << node->output()->type()->python_str()
<< ", {})";
} else {
return 0;
}
+int dictSetItem(Stack& stack) {
+ auto value = pop(stack);
+ auto idx = pop(stack);
+ auto& dict = pop(stack).toGenericDict()->elements();
+ dict[idx] = value;
+ push(stack, dict);
+ return 0;
+}
+
int dictLen(Stack& stack) {
auto dict = pop(stack).toGenericDictRef();
push(stack, int64_t(dict.size()));
// NOTE: this must be after the other list specializations so that operator
// resolution doesn't pick this up first
CREATE_MUTABLE_LIST_OPS("t", GenericList),
+#undef CREATE_IMMUTABLE_LIST_OPS
+#undef CREATE_MUTABLE_LIST_OPS
#define CREATE_LIST_OPS(decl_type, c_type) \
Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
#define CREATE_DICT_OPS(key_type) \
Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \
Operator( \
- "aten::keys(Dict(" key_type ", t) self) -> " key_type "[]", \
+ "aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)", \
dictKeys), \
- Operator("aten::values(Dict(" key_type ", t) self) -> t[]", dictValues), \
+ Operator("aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues),\
Operator( \
"prim::DictIndex(Dict(" key_type ", t) self, " key_type \
- " key) -> t", \
- dictIndex)
+ " key) -> t(*)", \
+ dictIndex), \
+ Operator( \
+ "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \
+ " idx, t v) -> ()", \
+ dictSetItem)
CREATE_DICT_OPS("str"),
CREATE_DICT_OPS("int"),