From ada10ad416f91602d46797b50b3faebcacc6e767 Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Mon, 15 Apr 2019 22:05:20 -0700 Subject: [PATCH] Ellipsis in subscript Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17763 Differential Revision: D14893533 Pulled By: Krovatkin fbshipit-source-id: c46b4e386d3aa30e6dc03e3052d2e5ff097fa74b --- test/test_jit.py | 30 ++++++++++++++++ torch/csrc/jit/script/compiler.cpp | 53 ++++++++++++++++++++--------- torch/csrc/jit/script/parser.cpp | 4 +++ torch/csrc/jit/script/python_tree_views.cpp | 2 ++ torch/csrc/jit/script/tree_views.h | 12 ++++++- torch/jit/frontend.py | 7 ++++ 6 files changed, 91 insertions(+), 17 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 47901d4..ff4a9b3 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -8124,6 +8124,36 @@ a") self.assertEqual(8, bar(torch.ones(1, 1))) + def test_ellipsis_mid(self): + def ellipsize(x): + # type: (Tensor) -> List[int] + return x[2, ..., 0:4, 4:8].size() + + dummy = torch.zeros(8, 8, 8, 8, 8) + self.checkScript(ellipsize, (dummy,), optimize=True) + + def test_ellipsis_mid_select(self): + def ellipsize(x): + # type: (Tensor) -> List[int] + return x[2, ..., 4, 4, 4:8, 2].size() + + dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8) + self.checkScript(ellipsize, (dummy,), optimize=True) + + def test_ellipsis_start(self): + def ellipsize(x): + # type: (Tensor) -> List[int] + return x[..., 0:4, 4:8].size() + dummy = torch.zeros(8, 8, 8, 8, 8) + self.checkScript(ellipsize, (dummy,), optimize=True) + + def test_ellipsis_end(self): + def ellipsize(x): + # type: (Tensor) -> List[int] + return x[0:4, 2, ...].size() + dummy = torch.zeros(8, 8, 8, 8, 8) + self.checkScript(ellipsize, (dummy,), optimize=True) + def test_tracing_slicing(self): @_trace(torch.zeros(10)) def foo_trace(x): diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 6048ed6..33c7168 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -1,5 +1,4 @@ #include - #include #include #include @@ -2433,16 +2432,10 @@ struct to_ir { Value* emitSelect( const SourceRange& loc, Value* input, - int64_t dim, + Value* dim, Value* index) { return emitBuiltinCall( - loc, - *graph, - aten::select, - c10::nullopt, - {input, graph->insertConstant(dim, nullptr, loc), index}, - {}, - true); + loc, *graph, aten::select, c10::nullopt, {input, dim, index}, {}, true); } // Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end, @@ -2450,7 +2443,7 @@ struct to_ir { Value* emitSlice( const SourceRange& loc, Value* input, - c10::optional dim, // Only used for tensor slicing + Value* dim, // Only used for tensor slicing const SliceExpr& slice) { std::vector args; args.reserve(4); @@ -2460,8 +2453,8 @@ struct to_ir { // aten::slice, we should separate it from this function. if (dim) { AT_ASSERT(input->type()->isSubtypeOf(TensorType::get())); - args.emplace_back( - loc, "dim", graph->insertConstant(dim.value(), nullptr, loc)); + + args.emplace_back(dim); } else { AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get())); } @@ -2530,15 +2523,43 @@ struct to_ir { dim++; }; + // before ellipsis, dimension index should be `dim` + // after ellipsis, dimension index should be `-offset` + int offset = 0; + size_t ellipsis_dim = 0; + auto insert_value_for_dim = [&](int64_t dim) { + return (offset == 0) + ? graph->insertConstant(dim, nullptr, loc) + : + // NB: offset is incremented to move to the next dimension index + graph->insertConstant(offset++, nullptr, loc); + }; + for (const auto& subscript_expr : subscript_exprs) { + // NB: ellipsis_dim is **always** incremented + // (comparing to dim) in order to compute + // the correct offsets for the remaining + // dimension indices following an ellipsis "..." + // token + ellipsis_dim++; + if (subscript_expr.kind() == TK_DOTS) { + offset = -(subscript_exprs.size() - ellipsis_dim); + ++dim; + continue; + } if (subscript_expr.kind() == TK_SLICE_EXPR) { - sliceable = emitSlice(loc, sliceable, dim, SliceExpr(subscript_expr)); + auto dim_val = insert_value_for_dim(dim); + sliceable = + emitSlice(loc, sliceable, dim_val, SliceExpr(subscript_expr)); ++dim; continue; } auto index = emitExpr(subscript_expr, OptionalType::ofTensor()); if (index->type() == IntType::get()) { - sliceable = emitSelect(loc, sliceable, dim, index); + // NB: note, select squeezes out a dimension, + // so dim is **not** incremented + auto dim_val = insert_value_for_dim(dim); + sliceable = emitSelect(loc, sliceable, dim_val, index); continue; } else if (index->type()->isSubtypeOf(NoneType::get())) { sliceable = emitUnsqueeze(loc, sliceable, dim); @@ -2614,10 +2635,10 @@ struct to_ir { AT_ASSERT(subscript_exprs.size() == 1); AT_ASSERT(subscript_exprs[0].kind() == TK_SLICE_EXPR); auto slice_exp = SliceExpr(subscript_exprs[0]); - c10::optional maybe_dim; + Value* maybe_dim = nullptr; if (sliceable->type()->isSubtypeOf(TensorType::get())) { // If the sliceable object is a tensor, specify a default dimension - maybe_dim = 0; + maybe_dim = graph->insertConstant(0, nullptr, loc); } return emitSlice(loc, sliceable, maybe_dim, slice_exp); } diff --git a/torch/csrc/jit/script/parser.cpp b/torch/csrc/jit/script/parser.cpp index 46a121e..ee6d385 100644 --- a/torch/csrc/jit/script/parser.cpp +++ b/torch/csrc/jit/script/parser.cpp @@ -162,6 +162,10 @@ struct ParserImpl { case TK_STRINGLITERAL: { prefix = parseConcatenatedStringLiterals(); } break; + case TK_DOTS: { + prefix = Dots::create(L.cur().range); + L.next(); + } break; default: { Ident name = parseIdent(); prefix = Var::create(name.range(), name); diff --git a/torch/csrc/jit/script/python_tree_views.cpp b/torch/csrc/jit/script/python_tree_views.cpp index 3937a52..4bd97db 100644 --- a/torch/csrc/jit/script/python_tree_views.cpp +++ b/torch/csrc/jit/script/python_tree_views.cpp @@ -160,6 +160,8 @@ void initTreeViewBindings(PyObject* module) { })); py::class_(m, "Pass").def( py::init([](const SourceRange& range) { return Pass::create(range); })); + py::class_(m, "Dots").def( + py::init([](const SourceRange& range) { return Dots::create(range); })); py::class_(m, "If").def( py::init([](const SourceRange& range, const Expr& cond, diff --git a/torch/csrc/jit/script/tree_views.h b/torch/csrc/jit/script/tree_views.h index 5736e45..67b73dd 100644 --- a/torch/csrc/jit/script/tree_views.h +++ b/torch/csrc/jit/script/tree_views.h @@ -295,6 +295,7 @@ struct Expr : public TreeView { case '^': case '|': case TK_LIST_COMP: + case TK_DOTS: return; default: throw ErrorReport(tree) @@ -499,7 +500,7 @@ struct For : public Stmt { } }; -//TODO: supports only single comprehension for now +// TODO: supports only single comprehension for now struct ListComp : public Expr { explicit ListComp(const TreeRef& tree) : Expr(tree) { tree->match(TK_LIST_COMP); @@ -642,6 +643,15 @@ struct Pass : public Stmt { } }; +struct Dots : public Expr { + explicit Dots(const TreeRef& tree) : Expr(tree) { + tree_->match(TK_DOTS); + } + static Dots create(const SourceRange& range) { + return Dots(Compound::create(TK_DOTS, range, {})); + } +}; + struct ExprStmt : public Stmt { explicit ExprStmt(const TreeRef& tree) : Stmt(tree) { tree_->match(TK_EXPR_STMT); diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 435e593..ba167ad 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -414,6 +414,11 @@ class ExprBuilder(Builder): return Apply(func, args, kwargs) @staticmethod + def build_Ellipsis(ctx, expr): + r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 3) # len("...") == 3 + return Dots(r) + + @staticmethod def build_Name(ctx, expr): r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id)) if expr.id.startswith(_reserved_prefix): @@ -531,6 +536,8 @@ class ExprBuilder(Builder): sub_exprs.append(build_Index(ctx, base, expr)) elif sub_type is ast.Slice: sub_exprs.append(build_SliceExpr(ctx, base, expr)) + elif sub_type is ast.Ellipsis: + sub_exprs.append(Dots(base.range())) else: raise NotSupportedError(base.range(), "slicing multiple dimensions with " -- 2.7.4