Dict mutability (#16884)
authorDavid Riazati <davidriazati@fb.com>
Fri, 22 Feb 2019 00:09:43 +0000 (16:09 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Feb 2019 00:24:17 +0000 (16:24 -0800)
Summary:
Adds `aten::_set_item` for `dict[key]` calls
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16884

Differential Revision: D14000488

Pulled By: driazati

fbshipit-source-id: ea1b46e0a736d095053effb4bc52753f696617b2

test/test_jit.py
torch/csrc/jit/operator.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/register_prim_ops.cpp

index 4e56c9f..d154ffa 100644 (file)
@@ -41,7 +41,7 @@ from common_methods_invocations import create_input, unpack_variables, \
     exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
 from copy import deepcopy
 import random
-from typing import List, Optional
+from typing import List, Dict, Optional
 from torch.jit.frontend import NotSupportedError
 from torch.jit import BatchTensor
 
@@ -10328,6 +10328,16 @@ a")
 
         self.checkScript(list_of_dicts, ())
 
+    def test_dict_mutability(self):
+        @torch.jit.script
+        def fn():
+            # type: () -> Dict[str, int]
+            a = torch.jit.annotate(Dict[str, int], {})
+            a['ok'] = 10
+            return a
+
+        self.assertEqual(fn(), {'ok': 10})
+
     def dict_to_python(self):
         def python_lookup(my_dict, keys):
             # type: (Dict[str, int], List[str]) -> List[int]
index 919560e..ff2d404 100644 (file)
@@ -167,8 +167,9 @@ struct SchemaParser {
       auto key_type = parseType().first;
       L.expect(',');
       auto value_type = parseType().first;
-      alias_info = parseAliasAnnotation();
       L.expect(')');
+      alias_info = parseAliasAnnotation();
+
       value = DictType::create(key_type, value_type);
     } else {
       auto value_alias = parseBaseType();
index 4126bcc..92a9d70 100644 (file)
@@ -801,9 +801,10 @@ struct PythonPrintPass {
       } break;
       case prim::DictConstruct: {
         auto dict_type = node->output()->type()->expect<DictType>();
-        if (node->inputs().size() == 0 &&
-            !dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
-            !dict_type->getValueType()->isSubtypeOf(TensorType::get())) {
+        bool is_default_type =
+            dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
+            dict_type->getKeyType()->isSubtypeOf(TensorType::get());
+        if (node->inputs().size() == 0 && !is_default_type) {
           stmt << "annotate(" << node->output()->type()->python_str()
                << ", {})";
         } else {
index 9729b97..dad49b7 100644 (file)
@@ -1179,6 +1179,15 @@ int listSetItem<Shared<BoolList>, bool>(Stack& stack) {
   return 0;
 }
 
+int dictSetItem(Stack& stack) {
+  auto value = pop(stack);
+  auto idx = pop(stack);
+  auto& dict = pop(stack).toGenericDict()->elements();
+  dict[idx] = value;
+  push(stack, dict);
+  return 0;
+}
+
 int dictLen(Stack& stack) {
   auto dict = pop(stack).toGenericDictRef();
   push(stack, int64_t(dict.size()));
@@ -1316,6 +1325,8 @@ RegisterOperators reg2({
     // NOTE: this must be after the other list specializations so that operator
     // resolution doesn't pick this up first
     CREATE_MUTABLE_LIST_OPS("t", GenericList),
+#undef CREATE_IMMUTABLE_LIST_OPS
+#undef CREATE_MUTABLE_LIST_OPS
 
 #define CREATE_LIST_OPS(decl_type, c_type)                                          \
   Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>),         \
@@ -1506,13 +1517,17 @@ RegisterOperators reg2({
 #define CREATE_DICT_OPS(key_type)                                              \
   Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen),           \
       Operator(                                                                \
-          "aten::keys(Dict(" key_type ", t) self) -> " key_type "[]",          \
+          "aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)",       \
           dictKeys),                                                           \
-      Operator("aten::values(Dict(" key_type ", t) self) -> t[]", dictValues), \
+      Operator("aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues),\
       Operator(                                                                \
           "prim::DictIndex(Dict(" key_type ", t) self, " key_type              \
-          " key) -> t",                                                        \
-          dictIndex)
+          " key) -> t(*)",                                                     \
+          dictIndex),                                                          \
+      Operator(                                                                \
+          "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type             \
+          " idx, t v) -> ()",                                                  \
+          dictSetItem)
 
     CREATE_DICT_OPS("str"),
     CREATE_DICT_OPS("int"),