self.assertEqual(8, bar(torch.ones(1, 1)))
def test_ellipsis_mid(self):
    """Script an index expression with `...` between concrete indices
    (select before, slices after) and check it against eager mode."""
    def ellipsize(x):
        # type: (Tensor) -> List[int]
        return x[2, ..., 0:4, 4:8].size()

    # 5-d input: `...` absorbs the two middle dimensions.
    dummy = torch.zeros(8, 8, 8, 8, 8)
    self.checkScript(ellipsize, (dummy,), optimize=True)
+
def test_ellipsis_mid_select(self):
    """Script an index expression with `...` followed by several selects
    and a slice, and check it against eager mode."""
    def ellipsize(x):
        # type: (Tensor) -> List[int]
        return x[2, ..., 4, 4, 4:8, 2].size()

    # 7-d input: `...` absorbs the dimensions not covered by the
    # explicit indices around it.
    dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8)
    self.checkScript(ellipsize, (dummy,), optimize=True)
+
def test_ellipsis_start(self):
    """Script an index expression that begins with `...` and check it
    against eager mode."""
    def ellipsize(x):
        # type: (Tensor) -> List[int]
        return x[..., 0:4, 4:8].size()

    dummy = torch.zeros(8, 8, 8, 8, 8)
    self.checkScript(ellipsize, (dummy,), optimize=True)
+
def test_ellipsis_end(self):
    """Script an index expression that ends with `...` and check it
    against eager mode."""
    def ellipsize(x):
        # type: (Tensor) -> List[int]
        return x[0:4, 2, ...].size()

    dummy = torch.zeros(8, 8, 8, 8, 8)
    self.checkScript(ellipsize, (dummy,), optimize=True)
+
def test_tracing_slicing(self):
@_trace(torch.zeros(10))
def foo_trace(x):
#include <torch/csrc/jit/script/compiler.h>
-
#include <c10/util/Exception.h>
#include <torch/csrc/jit/hooks_for_testing.h>
#include <torch/csrc/jit/interpreter.h>
// Desugars select indexing: tensor[i] -> tensor.select(dim, i).
//
// `dim` is taken as a Value* (not int64_t) so callers can pass a
// dimension index that was materialized as a graph constant elsewhere —
// e.g. the negative offsets computed for indices that follow an
// ellipsis "..." token.
Value* emitSelect(
    const SourceRange& loc,
    Value* input,
    Value* dim,
    Value* index) {
  return emitBuiltinCall(
      loc, *graph, aten::select, c10::nullopt, {input, dim, index}, {}, true);
}
// Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end,
Value* emitSlice(
const SourceRange& loc,
Value* input,
- c10::optional<int64_t> dim, // Only used for tensor slicing
+ Value* dim, // Only used for tensor slicing
const SliceExpr& slice) {
std::vector<NamedValue> args;
args.reserve(4);
// aten::slice, we should separate it from this function.
if (dim) {
AT_ASSERT(input->type()->isSubtypeOf(TensorType::get()));
- args.emplace_back(
- loc, "dim", graph->insertConstant(dim.value(), nullptr, loc));
+
+ args.emplace_back(dim);
} else {
AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get()));
}
dim++;
};
+ // before ellipsis, dimension index should be `dim`
+ // after ellipsis, dimension index should be `-offset`
+ int offset = 0;
+ size_t ellipsis_dim = 0;
+ auto insert_value_for_dim = [&](int64_t dim) {
+ return (offset == 0)
+ ? graph->insertConstant(dim, nullptr, loc)
+ :
+ // NB: offset is incremented to move to the next dimension index
+ graph->insertConstant(offset++, nullptr, loc);
+ };
+
for (const auto& subscript_expr : subscript_exprs) {
+ // NB: ellipsis_dim is **always** incremented
+ // (comparing to dim) in order to compute
+ // the correct offsets for the remaining
+ // dimension indices following an ellipsis "..."
+ // token
+ ellipsis_dim++;
+ if (subscript_expr.kind() == TK_DOTS) {
+ offset = -(subscript_exprs.size() - ellipsis_dim);
+ ++dim;
+ continue;
+ }
if (subscript_expr.kind() == TK_SLICE_EXPR) {
- sliceable = emitSlice(loc, sliceable, dim, SliceExpr(subscript_expr));
+ auto dim_val = insert_value_for_dim(dim);
+ sliceable =
+ emitSlice(loc, sliceable, dim_val, SliceExpr(subscript_expr));
++dim;
continue;
}
auto index = emitExpr(subscript_expr, OptionalType::ofTensor());
if (index->type() == IntType::get()) {
- sliceable = emitSelect(loc, sliceable, dim, index);
+ // NB: note, select squeezes out a dimension,
+ // so dim is **not** incremented
+ auto dim_val = insert_value_for_dim(dim);
+ sliceable = emitSelect(loc, sliceable, dim_val, index);
continue;
} else if (index->type()->isSubtypeOf(NoneType::get())) {
sliceable = emitUnsqueeze(loc, sliceable, dim);
AT_ASSERT(subscript_exprs.size() == 1);
AT_ASSERT(subscript_exprs[0].kind() == TK_SLICE_EXPR);
auto slice_exp = SliceExpr(subscript_exprs[0]);
- c10::optional<int64_t> maybe_dim;
+ Value* maybe_dim = nullptr;
if (sliceable->type()->isSubtypeOf(TensorType::get())) {
// If the sliceable object is a tensor, specify a default dimension
- maybe_dim = 0;
+ maybe_dim = graph->insertConstant(0, nullptr, loc);
}
return emitSlice(loc, sliceable, maybe_dim, slice_exp);
}
case TK_STRINGLITERAL: {
prefix = parseConcatenatedStringLiterals();
} break;
+ case TK_DOTS: {
+ prefix = Dots::create(L.cur().range);
+ L.next();
+ } break;
default: {
Ident name = parseIdent();
prefix = Var::create(name.range(), name);
}));
py::class_<Pass, Stmt>(m, "Pass").def(
py::init([](const SourceRange& range) { return Pass::create(range); }));
+ py::class_<Dots, Expr>(m, "Dots").def(
+ py::init([](const SourceRange& range) { return Dots::create(range); }));
py::class_<If, Stmt>(m, "If").def(
py::init([](const SourceRange& range,
const Expr& cond,
case '^':
case '|':
case TK_LIST_COMP:
+ case TK_DOTS:
return;
default:
throw ErrorReport(tree)
}
};
-//TODO: supports only single comprehension for now
+// TODO: supports only single comprehension for now
struct ListComp : public Expr {
explicit ListComp(const TreeRef& tree) : Expr(tree) {
tree->match(TK_LIST_COMP);
}
};
+struct Dots : public Expr {
+ explicit Dots(const TreeRef& tree) : Expr(tree) {
+ tree_->match(TK_DOTS);
+ }
+ static Dots create(const SourceRange& range) {
+ return Dots(Compound::create(TK_DOTS, range, {}));
+ }
+};
+
struct ExprStmt : public Stmt {
explicit ExprStmt(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_EXPR_STMT);
return Apply(func, args, kwargs)
@staticmethod
def build_Ellipsis(ctx, expr):
    """Translate a Python ``...`` (Ellipsis) AST node into a Dots tree node.

    An Ellipsis node has no textual payload of its own, so its source
    range is derived from its position plus the fixed token width
    (len("...") == 3).
    """
    r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 3)
    return Dots(r)
+
+ @staticmethod
def build_Name(ctx, expr):
r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
if expr.id.startswith(_reserved_prefix):
sub_exprs.append(build_Index(ctx, base, expr))
elif sub_type is ast.Slice:
sub_exprs.append(build_SliceExpr(ctx, base, expr))
+ elif sub_type is ast.Ellipsis:
+ sub_exprs.append(Dots(base.range()))
else:
raise NotSupportedError(base.range(),
"slicing multiple dimensions with "