Ellipsis in subscript
authorNikolay Korovaiko <korovaikon@gmail.com>
Tue, 16 Apr 2019 05:05:20 +0000 (22:05 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 16 Apr 2019 05:10:44 +0000 (22:10 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17763

Differential Revision: D14893533

Pulled By: Krovatkin

fbshipit-source-id: c46b4e386d3aa30e6dc03e3052d2e5ff097fa74b

test/test_jit.py
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/parser.cpp
torch/csrc/jit/script/python_tree_views.cpp
torch/csrc/jit/script/tree_views.h
torch/jit/frontend.py

index 47901d4..ff4a9b3 100644 (file)
@@ -8124,6 +8124,36 @@ a")
 
         self.assertEqual(8, bar(torch.ones(1, 1)))
 
+    def test_ellipsis_mid(self):
+        def ellipsize(x):
+            # type: (Tensor) -> List[int]
+            return x[2, ..., 0:4, 4:8].size()
+
+        dummy = torch.zeros(8, 8, 8, 8, 8)
+        self.checkScript(ellipsize, (dummy,), optimize=True)
+
+    def test_ellipsis_mid_select(self):
+        def ellipsize(x):
+            # type: (Tensor) -> List[int]
+            return x[2, ..., 4, 4, 4:8, 2].size()
+
+        dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8)
+        self.checkScript(ellipsize, (dummy,), optimize=True)
+
+    def test_ellipsis_start(self):
+        def ellipsize(x):
+            # type: (Tensor) -> List[int]
+            return x[..., 0:4, 4:8].size()
+        dummy = torch.zeros(8, 8, 8, 8, 8)
+        self.checkScript(ellipsize, (dummy,), optimize=True)
+
+    def test_ellipsis_end(self):
+        def ellipsize(x):
+            # type: (Tensor) -> List[int]
+            return x[0:4, 2, ...].size()
+        dummy = torch.zeros(8, 8, 8, 8, 8)
+        self.checkScript(ellipsize, (dummy,), optimize=True)
+
     def test_tracing_slicing(self):
         @_trace(torch.zeros(10))
         def foo_trace(x):
index 6048ed6..33c7168 100644 (file)
@@ -1,5 +1,4 @@
 #include <torch/csrc/jit/script/compiler.h>
-
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/hooks_for_testing.h>
 #include <torch/csrc/jit/interpreter.h>
@@ -2433,16 +2432,10 @@ struct to_ir {
   Value* emitSelect(
       const SourceRange& loc,
       Value* input,
-      int64_t dim,
+      Value* dim,
       Value* index) {
     return emitBuiltinCall(
-        loc,
-        *graph,
-        aten::select,
-        c10::nullopt,
-        {input, graph->insertConstant(dim, nullptr, loc), index},
-        {},
-        true);
+        loc, *graph, aten::select, c10::nullopt, {input, dim, index}, {}, true);
   }
 
   // Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end,
@@ -2450,7 +2443,7 @@ struct to_ir {
   Value* emitSlice(
       const SourceRange& loc,
       Value* input,
-      c10::optional<int64_t> dim, // Only used for tensor slicing
+      Value* dim, // Only used for tensor slicing
       const SliceExpr& slice) {
     std::vector<NamedValue> args;
     args.reserve(4);
@@ -2460,8 +2453,8 @@ struct to_ir {
     // aten::slice, we should separate it from this function.
     if (dim) {
       AT_ASSERT(input->type()->isSubtypeOf(TensorType::get()));
-      args.emplace_back(
-          loc, "dim", graph->insertConstant(dim.value(), nullptr, loc));
+
+      args.emplace_back(dim);
     } else {
       AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get()));
     }
@@ -2530,15 +2523,43 @@ struct to_ir {
       dim++;
     };
 
+    // before ellipsis, dimension index should be `dim`
+    // after ellipsis, dimension index should be `-offset`
+    int offset = 0;
+    size_t ellipsis_dim = 0;
+    auto insert_value_for_dim = [&](int64_t dim) {
+      return (offset == 0)
+          ? graph->insertConstant(dim, nullptr, loc)
+          :
+          // NB: offset is incremented to move to the next dimension index
+          graph->insertConstant(offset++, nullptr, loc);
+    };
+
     for (const auto& subscript_expr : subscript_exprs) {
+      // NB: ellipsis_dim is **always** incremented
+      // (comparing to dim) in order to compute
+      // the correct offsets for the remaining
+      // dimension indices following an ellipsis "..."
+      // token
+      ellipsis_dim++;
+      if (subscript_expr.kind() == TK_DOTS) {
+        offset = -(subscript_exprs.size() - ellipsis_dim);
+        ++dim;
+        continue;
+      }
       if (subscript_expr.kind() == TK_SLICE_EXPR) {
-        sliceable = emitSlice(loc, sliceable, dim, SliceExpr(subscript_expr));
+        auto dim_val = insert_value_for_dim(dim);
+        sliceable =
+            emitSlice(loc, sliceable, dim_val, SliceExpr(subscript_expr));
         ++dim;
         continue;
       }
       auto index = emitExpr(subscript_expr, OptionalType::ofTensor());
       if (index->type() == IntType::get()) {
-        sliceable = emitSelect(loc, sliceable, dim, index);
+        // NB: note, select squeezes out a dimension,
+        // so dim is **not** incremented
+        auto dim_val = insert_value_for_dim(dim);
+        sliceable = emitSelect(loc, sliceable, dim_val, index);
         continue;
       } else if (index->type()->isSubtypeOf(NoneType::get())) {
         sliceable = emitUnsqueeze(loc, sliceable, dim);
@@ -2614,10 +2635,10 @@ struct to_ir {
     AT_ASSERT(subscript_exprs.size() == 1);
     AT_ASSERT(subscript_exprs[0].kind() == TK_SLICE_EXPR);
     auto slice_exp = SliceExpr(subscript_exprs[0]);
-    c10::optional<int64_t> maybe_dim;
+    Value* maybe_dim = nullptr;
     if (sliceable->type()->isSubtypeOf(TensorType::get())) {
       // If the sliceable object is a tensor, specify a default dimension
-      maybe_dim = 0;
+      maybe_dim = graph->insertConstant(0, nullptr, loc);
     }
     return emitSlice(loc, sliceable, maybe_dim, slice_exp);
   }
index 46a121e..ee6d385 100644 (file)
@@ -162,6 +162,10 @@ struct ParserImpl {
       case TK_STRINGLITERAL: {
         prefix = parseConcatenatedStringLiterals();
       } break;
+      case TK_DOTS: {
+        prefix = Dots::create(L.cur().range);
+        L.next();
+      } break;
       default: {
         Ident name = parseIdent();
         prefix = Var::create(name.range(), name);
index 3937a52..4bd97db 100644 (file)
@@ -160,6 +160,8 @@ void initTreeViewBindings(PyObject* module) {
       }));
   py::class_<Pass, Stmt>(m, "Pass").def(
       py::init([](const SourceRange& range) { return Pass::create(range); }));
+      py::class_<Dots, Expr>(m, "Dots").def(
+          py::init([](const SourceRange& range) { return Dots::create(range); }));
   py::class_<If, Stmt>(m, "If").def(
       py::init([](const SourceRange& range,
                   const Expr& cond,
index 5736e45..67b73dd 100644 (file)
@@ -295,6 +295,7 @@ struct Expr : public TreeView {
       case '^':
       case '|':
       case TK_LIST_COMP:
+      case TK_DOTS:
         return;
       default:
         throw ErrorReport(tree)
@@ -499,7 +500,7 @@ struct For : public Stmt {
   }
 };
 
-//TODO: supports only single comprehension for now
+// TODO: supports only single comprehension for now
 struct ListComp : public Expr {
   explicit ListComp(const TreeRef& tree) : Expr(tree) {
     tree->match(TK_LIST_COMP);
@@ -642,6 +643,15 @@ struct Pass : public Stmt {
   }
 };
 
+struct Dots : public Expr {
+  explicit Dots(const TreeRef& tree) : Expr(tree) {
+    tree_->match(TK_DOTS);
+  }
+  static Dots create(const SourceRange& range) {
+    return Dots(Compound::create(TK_DOTS, range, {}));
+  }
+};
+
 struct ExprStmt : public Stmt {
   explicit ExprStmt(const TreeRef& tree) : Stmt(tree) {
     tree_->match(TK_EXPR_STMT);
index 435e593..ba167ad 100644 (file)
@@ -414,6 +414,11 @@ class ExprBuilder(Builder):
         return Apply(func, args, kwargs)
 
     @staticmethod
+    def build_Ellipsis(ctx, expr):
+        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 3)  # len("...") == 3
+        return Dots(r)
+
+    @staticmethod
     def build_Name(ctx, expr):
         r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
         if expr.id.startswith(_reserved_prefix):
@@ -531,6 +536,8 @@ class ExprBuilder(Builder):
                     sub_exprs.append(build_Index(ctx, base, expr))
                 elif sub_type is ast.Slice:
                     sub_exprs.append(build_SliceExpr(ctx, base, expr))
+                elif sub_type is ast.Ellipsis:
+                    sub_exprs.append(Dots(base.range()))
                 else:
                     raise NotSupportedError(base.range(),
                                             "slicing multiple dimensions with "