From: James Sun Date: Mon, 17 Dec 2018 23:36:28 +0000 (-0800) Subject: Allow future type parsing X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2201 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c66adfc16b8c3e486c2fbcc18de2f0cffd73b208;p=platform%2Fupstream%2Fpytorch.git Allow future type parsing Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14887 Differential Revision: D13490984 Pulled By: highker fbshipit-source-id: 165fe995867be273793f983154aa6cbce13e4396 --- diff --git a/test/test_jit.py b/test/test_jit.py index 72bdf53..3e072de 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -11064,6 +11064,31 @@ class TestAsync(JitTestCase): y = torch.jit._wait(fut) # assert nothing; only to make sure the fake python path works + def test_async_parsing(self): + @torch.jit.script + def foo(x): + # type: (Tensor) -> List[Tensor] + return [torch.neg(x), x.t()] + + @torch.jit.script + def bar(x): + futures = torch.jit.annotate(List[Future[List[Tensor]]], []) + for _ in range(3): + future = torch.jit.annotate( + Future[List[Tensor]], + torch.jit._fork(foo, x) + ) + futures.append(future) + + output = torch.jit.annotate(List[List[Tensor]], []) + for i in range(3): + output.append(torch.jit._wait(futures[i])) + return output + + x = torch.rand(3, 3) + result = bar(x) + self.assertEqual(len(result), 3) + def test_async_script(self): @torch.jit.script def foo(x): diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 88d0844..5b018a5 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -2747,6 +2747,13 @@ const std::unordered_map> &subscr auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin()); return OptionalType::create(elem_type); }}, + {"Future", [](Subscript subscript) -> TypePtr { + if (subscript.subscript_exprs().size() != 1) { + throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size(); + } + auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin()); + return FutureType::create(elem_type); + }}, }; return map; }