add reverse to list (#17001)
authorMichael Kösel <thecodez@users.noreply.github.com>
Fri, 15 Mar 2019 18:43:33 +0000 (11:43 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Mar 2019 18:53:37 +0000 (11:53 -0700)
Summary:
Add reverse functionality to list. See https://github.com/pytorch/pytorch/issues/16662

```python
import torch

torch.jit.script
def foo():
    a = [1, 2, 3, 4]
a.reverse()

    return a
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17001

Reviewed By: eellison

Differential Revision: D14092019

Pulled By: driazati

fbshipit-source-id: b353c763677c22312b64dde0db268e2988610ba1

test/test_jit.py
torch/csrc/jit/register_prim_ops.cpp

index e470bd9..896cc4b 100644 (file)
@@ -4146,6 +4146,30 @@ a")
 
         self.assertEqual(foo(), [1, 2, 3, 4])
 
+    def test_mutable_list_reverse_empty(self):
+        def test_reverse_empty():
+            a = []
+            a.reverse()
+
+            return a == []
+        self.checkScript(test_reverse_empty, ())
+
+    def test_mutable_list_reverse(self):
+        def test_reverse():
+            a = [1, 2, 3, 4]
+            a.reverse()
+
+            return a == [4, 3, 2, 1]
+        self.checkScript(test_reverse, ())
+
+    def test_mutable_tensor_list_reverse(self):
+        def test_tensor_reverse():
+            a = [torch.tensor(1), torch.tensor(2)]
+            a.reverse()
+
+            return a == [torch.tensor(2), torch.tensor(1)]
+        self.checkScript(test_tensor_reverse, ())
+
     def test_mutable_list_pop_empty(self):
         @torch.jit.script
         def test_pop_empty():
index 5db825c..0075b08 100644 (file)
@@ -991,6 +991,17 @@ int listAppend(Stack& stack) {
 }
 
 template <typename TList>
+int listReverse(Stack& stack) {
+  TList a;
+  pop(stack, a);
+
+  auto& elements = a->elements();
+  std::reverse(elements.begin(), elements.end());
+
+  return 0;
+}
+
+template <typename TList>
 int listPop(Stack& stack) {
   TList list;
   int64_t idx;
@@ -1440,6 +1451,9 @@ RegisterOperators reg2({
           "(c) el) -> " decl_type "[](a!)",                                 \
           listAppend<Shared<c_type>, c_type::ElemType>),                    \
       Operator(                                                             \
+          "aten::reverse( " decl_type "[](a!) self) -> ()",                 \
+          listReverse<Shared<c_type>>),                                     \
+      Operator(                                                             \
           "aten::extend(" decl_type "[](a!) self, " decl_type               \
           " [] other) -> ()",                                               \
           listExtend<Shared<c_type>>),                                      \
@@ -1482,6 +1496,9 @@ RegisterOperators reg2({
           " el) -> " decl_type "[](a!)",                               \
           listAppend<Shared<c_type>, c_type::ElemType>),               \
       Operator(                                                        \
+          "aten::reverse(" decl_type "[](a!) self) -> ()",             \
+          listReverse<Shared<c_type>>),                                \
+      Operator(                                                        \
           "aten::extend(" decl_type "[](a!) self, " decl_type          \
           " [] other) -> ()",                                          \
           listExtend<Shared<c_type>>),                                 \
@@ -1524,7 +1541,7 @@ RegisterOperators reg2({
 #undef CREATE_MUTABLE_LIST_OPS
 
 #define CREATE_LIST_OPS(decl_type, c_type)                                          \
-  Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>),         \
+      Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>),     \
       Operator(                                                                     \
           "aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type           \
           "[]",                                                                     \