Support for basic list comprehensions (#17267)
authorNikolay Korovaiko <korovaikon@gmail.com>
Fri, 22 Mar 2019 22:22:23 +0000 (15:22 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Mar 2019 22:25:13 +0000 (15:25 -0700)
Summary:
Supports the following syntax:
```
        torch.jit.script
        def comp(l):
            # type: (List[float]) -> List[float]

            n = [x * 3 for x in l]
            return n
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17267

Differential Revision: D14581119

Pulled By: Krovatkin

fbshipit-source-id: 6fd091a8a9ab607386ac58fda6ad88bf8aea380e

test/test_jit.py
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/lexer.cpp
torch/csrc/jit/script/lexer.h
torch/csrc/jit/script/parser.cpp
torch/csrc/jit/script/python_tree_views.cpp
torch/csrc/jit/script/tree_views.h
torch/jit/frontend.py

index 20673f6..65f991b 100644 (file)
@@ -4098,6 +4098,47 @@ a")
             return a == [0, 1, 2, 3]
         self.checkScript(test_append, ())
 
+    def test_comprehensions_basic(self):
+        def comp(l):
+            # type: (List[int]) -> List[int]
+
+            n = [x * 3 for x in l]
+            return n
+
+        comp([1, 2, 3])
+        self.checkScript(comp, ([1, 2, 3],))
+
+    def test_comprehensions_basic_float(self):
+        def comp(l):
+            # type: (List[float]) -> List[float]
+
+            n = [x * 3 for x in l]
+            return n
+
+        self.checkScript(comp, ([1.0, 2.0, 3.0],))
+
+    def test_comprehensions_two_comps(self):
+        @torch.jit.script
+        def comp(l1, l2):
+            # type: (List[int], List[int]) -> List[int]
+
+            n = [x * 3 for x in l1]
+            n2 = [x + 2 for x in l2]
+            return n + n2
+
+        self.assertEqual(comp([1, 2, 3], [4, 5]), [3, 6, 9, 6, 7])
+
+    def test_comprehensions_wrong_expr_type(self):
+        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
+            @torch.jit.script
+            def comp(l):
+                # type: (List[int]) -> List[float]
+
+                n = [float(x) for x in l]
+                return n
+
+            comp([1, 2, 3])
+
     def test_mutable_list_append_2(self):
         def test_append_2():
             a = [0, 1]
index 699eccf..0b1074f 100644 (file)
@@ -17,6 +17,7 @@
 
 #include <c10/util/Optional.h>
 
+#include <atomic>
 #include <climits>
 #include <set>
 
@@ -968,6 +969,50 @@ struct to_ir {
     return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
   }
 
+  Value* emitListComprehension(const ListComp& lc) {
+    // this avoids a race condition where we would re-use the same temp name
+    static std::atomic<size_t> tmp_count{0};
+    const auto tmp_name =
+        std::string("___list_acc") + std::to_string(tmp_count++);
+    const auto list_value = emitExpr(lc.iter());
+    if (list_value->type()->kind() != TypeKind::ListType) {
+      // TODO: constraining iterators to be simple lists for now
+      // as it makes easy to get list's element type.
+      throw ErrorReport(lc.range())
+          << "iterator expression is expected to be a list";
+    }
+    auto elem_types = list_value->type()->containedTypes();
+    // TODO: users can easily change the type to (x,1) or float(x)
+    // as in `float(x) for x in my_list_of_ints`
+    // eventually, we would probably want to temporarily inject x
+    // so we can evaluate the generator expression (e.g. `float(x)`) depending
+    // on x
+
+    // given `[x*2 for x in my_list]` this generates the following AST:
+    // __list_acc = []
+    // for x in my_list:
+    //  __list_acc.append(x*2)
+    const auto n = graph->insertNode(
+        graph->createList(elem_types.at(0), at::ArrayRef<Value*>{}));
+    environment_stack->setVar(lc.range(), tmp_name, n->output());
+    const auto tmp_list_ident = Ident::create(lc.range(), tmp_name);
+    const auto tmp_list_var = Var::create(lc.range(), tmp_list_ident);
+    const auto append_ident = Ident::create(lc.range(), "append");
+    const auto dot_op = Select::create(lc.range(), tmp_list_var, append_ident);
+    const auto append_args_list = List<Expr>::create(lc.range(), {lc.elt()});
+    const auto append_attrs = List<Attribute>::create(lc.range(), {});
+    const auto apply_append =
+        Apply::create(lc.range(), dot_op, append_args_list, append_attrs);
+    const auto expr_stmt = ExprStmt::create(lc.range(), apply_append);
+    const auto stmt_list = List<Stmt>::create(lc.range(), {expr_stmt});
+    const auto iters_list = List<Expr>::create(lc.range(), {lc.iter()});
+    const auto targets_list = List<Expr>::create(lc.range(), {lc.target()});
+    const auto for_loop =
+        For::create(lc.range(), targets_list, iters_list, stmt_list);
+    emitFor(for_loop);
+    return n->output();
+  }
+
   // Insert subtyping refinements
   void insertRefinements(const Refinements& ref) {
     for (const auto& name_mappings : ref.mappings_) {
@@ -2340,6 +2385,10 @@ struct to_ir {
             ->insertNode(graph->createDict(key_type, value_type, keys, values))
             ->output();
       } break;
+      case TK_LIST_COMP: {
+        auto lc = ListComp(tree);
+        return emitListComprehension(lc);
+      } break;
       default:
         throw ErrorReport(tree) << "Cannot emit expr for: " << tree;
         break;
index fe86e4b..eaf94d6 100644 (file)
@@ -12,6 +12,7 @@ namespace script {
 
 static const std::unordered_map<int, int> binary_prec = {
     {TK_IF, 1},
+    {TK_FOR, 1},
     {TK_AND, 2},
     {TK_OR, 2},
     // reserve a level for unary not
index aaaa2e8..13bc587 100644 (file)
@@ -95,6 +95,7 @@ namespace script {
   _(TK_RAISE, "raise", "raise")                  \
   _(TK_ASSERT, "assert", "assert")               \
   _(TK_DOTS, "dots", "...")                      \
+  _(TK_LIST_COMP, "list comprehension", "")      \
   _(TK_PASS, "pass", "pass")                     \
   _(TK_CLASS_DEF, "class", "class")
 
index c0db32a..0a3d61e 100644 (file)
@@ -127,7 +127,19 @@ struct ParserImpl {
       } break;
       case '[': {
         auto list = parseList('[', ',', ']', &ParserImpl::parseExp);
-        prefix = ListLiteral::create(list.range(), List<Expr>(list));
+
+        if (list.size() == 1 && (*list.begin()).kind() == TK_LIST_COMP) {
+          prefix = *list.begin();
+        } else {
+          for (auto se : list) {
+            if (se.kind() == TK_LIST_COMP) {
+              throw ErrorReport(list.range())
+                  << " expected a single list comprehension within '[' , ']'";
+            }
+          }
+          prefix = ListLiteral::create(list.range(), List<Expr>(list));
+        }
+
       } break;
       case '{': {
         L.next();
@@ -239,6 +251,14 @@ struct ParserImpl {
         continue;
       }
 
+      if (kind == TK_FOR) {
+        auto target = parseExp();
+        L.expect(TK_IN);
+        auto iter = parseExp();
+        prefix = ListComp::create(pos, Expr(prefix), target, iter);
+        continue;
+      }
+
       prefix = c(kind, pos, {prefix, parseExp(binary_prec)});
     }
     return Expr(prefix);
@@ -331,6 +351,7 @@ struct ParserImpl {
 
     auto subscript_exprs =
         parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);
+
     return Subscript::create(range, Expr(value), subscript_exprs);
   }
 
index ea1ddb0..3937a52 100644 (file)
@@ -237,6 +237,11 @@ void initTreeViewBindings(PyObject* module) {
           [](const Expr& cond, const Expr& true_expr, const Expr& false_expr) {
             return TernaryIf::create(cond.range(), cond, true_expr, false_expr);
           }));
+  py::class_<ListComp, Expr>(m, "ListComp")
+      .def(py::init(
+          [](const SourceRange& range, const Expr& elt, const Expr& target, const Expr& iter) {
+            return ListComp::create(range, elt, target, iter);
+          }));
   py::class_<ListLiteral, Expr>(m, "ListLiteral")
       .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
         return ListLiteral::create(range, wrap_list(range, std::move(args)));
index b3a5b60..5736e45 100644 (file)
@@ -294,6 +294,7 @@ struct Expr : public TreeView {
       case '&':
       case '^':
       case '|':
+      case TK_LIST_COMP:
         return;
       default:
         throw ErrorReport(tree)
@@ -498,6 +499,30 @@ struct For : public Stmt {
   }
 };
 
+//TODO: supports only single comprehension for now
+struct ListComp : public Expr {
+  explicit ListComp(const TreeRef& tree) : Expr(tree) {
+    tree->match(TK_LIST_COMP);
+  }
+  Expr elt() const {
+    return Expr(subtree(0));
+  }
+  Expr target() const {
+    return Expr(subtree(1));
+  }
+  Expr iter() const {
+    return Expr(subtree(2));
+  }
+  // TODO: no ifs for now
+  static ListComp create(
+      const SourceRange& range,
+      const Expr& elt,
+      const Expr& target,
+      const Expr& iter) {
+    return ListComp(Compound::create(TK_LIST_COMP, range, {elt, target, iter}));
+  }
+};
+
 struct Global : public Stmt {
   explicit Global(const TreeRef& tree) : Stmt(tree) {
     tree_->match(TK_GLOBAL);
index 70089dd..539d883 100644 (file)
@@ -584,6 +584,20 @@ class ExprBuilder(Builder):
         return StringLiteral(r, value)
 
     @staticmethod
+    def build_ListComp(ctx, stmt):
+        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
+        if (len(stmt.generators) > 1):
+            raise NotSupportedError(r, "multiple comprehension generators not supported yet")
+
+        if (len(stmt.generators[0].ifs) != 0):
+            raise NotSupportedError(r, "comprehension ifs not supported yet")
+
+        elt_expr = build_expr(ctx, stmt.elt)
+        target_expr = build_expr(ctx, stmt.generators[0].target)
+        iter_expr = build_expr(ctx, stmt.generators[0].iter)
+        return ListComp(r, elt_expr, target_expr, iter_expr)
+
+    @staticmethod
     def build_Starred(ctx, expr):
         r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
         return Starred(r, build_expr(ctx, expr.value))