Allow future type parsing
authorJames Sun <jamessun@fb.com>
Mon, 17 Dec 2018 23:36:28 +0000 (15:36 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 17 Dec 2018 23:39:52 +0000 (15:39 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14887

Differential Revision: D13490984

Pulled By: highker

fbshipit-source-id: 165fe995867be273793f983154aa6cbce13e4396

test/test_jit.py
torch/csrc/jit/script/compiler.cpp

index 72bdf53..3e072de 100644 (file)
@@ -11064,6 +11064,31 @@ class TestAsync(JitTestCase):
         y = torch.jit._wait(fut)
         # assert nothing; only to make sure the fake python path works
 
+    def test_async_parsing(self):
+        @torch.jit.script
+        def foo(x):
+            # type: (Tensor) -> List[Tensor]
+            return [torch.neg(x), x.t()]
+
+        @torch.jit.script
+        def bar(x):
+            futures = torch.jit.annotate(List[Future[List[Tensor]]], [])
+            for _ in range(3):
+                future = torch.jit.annotate(
+                    Future[List[Tensor]],
+                    torch.jit._fork(foo, x)
+                )
+                futures.append(future)
+
+            output = torch.jit.annotate(List[List[Tensor]], [])
+            for i in range(3):
+                output.append(torch.jit._wait(futures[i]))
+            return output
+
+        x = torch.rand(3, 3)
+        result = bar(x)
+        self.assertEqual(len(result), 3)
+
     def test_async_script(self):
         @torch.jit.script
         def foo(x):
index 88d0844..5b018a5 100644 (file)
@@ -2747,6 +2747,13 @@ const std::unordered_map<std::string, std::function<TypePtr(Subscript)>> &subscr
       auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
       return OptionalType::create(elem_type);
     }},
+    {"Future", [](Subscript subscript) -> TypePtr {
+      if (subscript.subscript_exprs().size() != 1) {
+        throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
+      }
+      auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
+      return FutureType::create(elem_type);
+    }},
   };
   return map;
 }