From: David Riazati Date: Fri, 22 Feb 2019 00:09:43 +0000 (-0800) Subject: Dict mutability (#16884) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1149 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ac00a0cd479173f7fa800283a4ca899aa6328778;p=platform%2Fupstream%2Fpytorch.git Dict mutability (#16884) Summary: Adds `aten::_set_item` for `dict[key]` calls Pull Request resolved: https://github.com/pytorch/pytorch/pull/16884 Differential Revision: D14000488 Pulled By: driazati fbshipit-source-id: ea1b46e0a736d095053effb4bc52753f696617b2 --- diff --git a/test/test_jit.py b/test/test_jit.py index 4e56c9f..d154ffa 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -41,7 +41,7 @@ from common_methods_invocations import create_input, unpack_variables, \ exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL from copy import deepcopy import random -from typing import List, Optional +from typing import List, Dict, Optional from torch.jit.frontend import NotSupportedError from torch.jit import BatchTensor @@ -10328,6 +10328,16 @@ a") self.checkScript(list_of_dicts, ()) + def test_dict_mutability(self): + @torch.jit.script + def fn(): + # type: () -> Dict[str, int] + a = torch.jit.annotate(Dict[str, int], {}) + a['ok'] = 10 + return a + + self.assertEqual(fn(), {'ok': 10}) + def dict_to_python(self): def python_lookup(my_dict, keys): # type: (Dict[str, int], List[str]) -> List[int] diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index 919560e..ff2d404 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -167,8 +167,9 @@ struct SchemaParser { auto key_type = parseType().first; L.expect(','); auto value_type = parseType().first; - alias_info = parseAliasAnnotation(); L.expect(')'); + alias_info = parseAliasAnnotation(); + value = DictType::create(key_type, value_type); } else { auto value_alias = parseBaseType(); diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index 4126bcc..92a9d70 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -801,9 +801,10 @@ struct PythonPrintPass { } break; case prim::DictConstruct: { auto dict_type = node->output()->type()->expect(); - if (node->inputs().size() == 0 && - !dict_type->getKeyType()->isSubtypeOf(StringType::get()) && - !dict_type->getValueType()->isSubtypeOf(TensorType::get())) { + bool is_default_type = + dict_type->getKeyType()->isSubtypeOf(StringType::get()) && + dict_type->getKeyType()->isSubtypeOf(TensorType::get()); + if (node->inputs().size() == 0 && !is_default_type) { stmt << "annotate(" << node->output()->type()->python_str() << ", {})"; } else { diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 9729b97..dad49b7 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -1179,6 +1179,15 @@ int listSetItem, bool>(Stack& stack) { return 0; } +int dictSetItem(Stack& stack) { + auto value = pop(stack); + auto idx = pop(stack); + auto& dict = pop(stack).toGenericDict()->elements(); + dict[idx] = value; + push(stack, dict); + return 0; +} + int dictLen(Stack& stack) { auto dict = pop(stack).toGenericDictRef(); push(stack, int64_t(dict.size())); @@ -1316,6 +1325,8 @@ RegisterOperators reg2({ // NOTE: this must be after the other list specializations so that operator // resolution doesn't pick this up first CREATE_MUTABLE_LIST_OPS("t", GenericList), +#undef CREATE_IMMUTABLE_LIST_OPS +#undef CREATE_MUTABLE_LIST_OPS #define CREATE_LIST_OPS(decl_type, c_type) \ Operator("aten::len(" decl_type "[] a) -> int", listLen>), \ @@ -1506,13 +1517,17 @@ RegisterOperators reg2({ #define CREATE_DICT_OPS(key_type) \ Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \ Operator( \ - "aten::keys(Dict(" key_type ", t) self) -> " key_type "[]", \ + "aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)", \ dictKeys), \ - Operator("aten::values(Dict(" key_type ", t) self) -> t[]", dictValues), \ + Operator("aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues),\ Operator( \ "prim::DictIndex(Dict(" key_type ", t) self, " key_type \ - " key) -> t", \ - dictIndex) + " key) -> t(*)", \ + dictIndex), \ + Operator( \ + "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \ + " idx, t v) -> ()", \ + dictSetItem) CREATE_DICT_OPS("str"), CREATE_DICT_OPS("int"),