From 5119cc7cdf01093747713410dc27c3ab22b28ba3 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Tue, 23 Apr 2019 16:30:49 -0700 Subject: [PATCH] builtin ivalues sort (#19572) Summary: Add sorting to all the lists which we specialize on (Tensor, int, float, bool). First part of https://github.com/pytorch/pytorch/issues/19372 Pull Request resolved: https://github.com/pytorch/pytorch/pull/19572 Differential Revision: D15052677 Pulled By: eellison fbshipit-source-id: 301e8e0e3e29e04aca1311410db0a474fd833cff --- test/test_jit.py | 37 ++++++++++++++++++++++++++++++++++++ torch/csrc/jit/register_prim_ops.cpp | 31 ++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/test/test_jit.py b/test/test_jit.py index 7e9d384..1cb922b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4118,6 +4118,43 @@ a") RuntimeError, "bool value of Tensor") + def test_list_sort(self): + template = dedent(''' + def func(): + li = {list_create} + li.sort() + return li + ''') + + lists = ["[]", "[1, 3, 2]", "[True, False, True]", "[1.2, .2, 3.2]", + "[torch.tensor(1.0), torch.tensor(0.2), torch.tensor(0.5)]", + "[torch.tensor(5), torch.tensor(-2), torch.tensor(4)]"] + for li in lists: + code = template.format(list_create=li) + scope = {} + exec(code, globals(), scope) + cu = torch.jit.CompilationUnit(code) + t1 = cu.func() + t2 = scope['func']() + self.assertEqual(t1, t2) + + def test_fail(x): + # type: (List[Tensor]) -> List[Tensor] + x.sort() + return x + + self.checkScriptRaisesRegex(test_fail, (([torch.zeros([2]), torch.zeros([2])],)), Exception, + "bool value of Tensor with more than one value") + + @torch.jit.script + def test_mutation(): + a = [1, 2, 3] + a.sort() + return a + + test_mutation() + FileCheck().check("aten::sort").run(test_mutation.graph_for()) + def test_list_slice(self): def test_regular_slice(): a = [0, 1, 2, 3, 4] diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 238c942..e15e58b 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -1490,6 +1490,29 @@ int listSlice(Stack& stack) { return 0; } +template +int listSort(Stack& stack) { + TList list; + + pop(stack, list); + std::sort(list->elements().begin(), list->elements().end()); + return 0; +} + +// Specialization for at::Tensor +template <> +int listSort>(Stack& stack) { + Shared list; + pop(stack, list); + std::sort( + list->elements().begin(), + list->elements().end(), + [](const at::Tensor& a, const at::Tensor& b) { + return a.lt(b).is_nonzero(); + }); + return 0; +} + template int listSetItem(Stack& stack) { TList list; @@ -1796,6 +1819,14 @@ RegisterOperators reg2({ CREATE_LIST_OPS("Tensor", TensorList), CREATE_LIST_OPS("t", GenericList), #undef CREATE_LIST_OPS + Operator("aten::sort(int[](a!) self) -> ()", listSort>), + Operator( + "aten::sort(float[](a!) self) -> ()", + listSort>), + Operator( + "aten::sort(Tensor[](a!) self) -> ()", + listSort>), + Operator("aten::sort(bool[](a!) self) -> ()", listSort>), Operator("aten::eq(int[] a, int[] b) -> bool", listEq>), Operator( -- 2.7.4