From 2fa3c8327c807680c3be8ff90d7dada0ef2056be Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Mon, 26 Nov 2018 12:02:09 -0800 Subject: [PATCH] fix tensor advanced indexing with assignment (#14311) Summary: Fix a mishandling of `foo[a] = b` when `a` was a tensor. We were assigning to a copy of `foo`, not a view of it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14311 Differential Revision: D13196109 Pulled By: suo fbshipit-source-id: c929401fda7c4a27622d3fe2b11278b08a7f17f1 --- aten/src/ATen/core/interned_strings.h | 1 + test/test_jit.py | 16 ++++ torch/csrc/jit/script/compiler.cpp | 148 ++++++++++++++++++---------------- 3 files changed, 96 insertions(+), 69 deletions(-) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index ad0c5e8..28b78dc 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -79,6 +79,7 @@ namespace c10 { _(aten, __isnot__) \ _(aten, copy_) \ _(aten, _set_item) \ + _(aten, index_put_) \ FORALL_ATEN_BASE_SYMBOLS(_) \ _(onnx, Add) \ _(onnx, Concat) \ diff --git a/test/test_jit.py b/test/test_jit.py index bea26a5..d06e5e7 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -8547,6 +8547,22 @@ a") return a self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) + def test_lhs_advanced_indexing_assignment(self): + def foo(x, y): + a = torch.exp(x) + b = x == 1 + a[b] = y[b] + return a + self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3))) + + def test_lhs_advanced_indexing_augmented_assignment(self): + def foo(x, y): + a = torch.exp(x) + b = x == 1 + a[b] += y[b] + return a + self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3))) + def test_lhs_indexing_list(self): def foo(a, b): ls = [a] diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 0fe0963..7ae65a8 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -1533,21 +1533,47 @@ private: const auto sliceable = emitExpr(lhs.value()); if (sliceable->type()->isSubtypeOf(DynamicType::get())) { - // If it's a tensor, just fully evaluate the subscript operation and emit - // an in-place assignment - const auto lhsValue = - emitSubscript(lhs.range(), sliceable, lhs.subscript_exprs()); - + // If it's a tensor, just fully evaluate the subscript operation and emit + // an in-place assignment + std::vector tensorIndices; + Value* sliced; + std::tie(sliced, tensorIndices) = emitIntAndSliceIndexing( + lhs.range(), sliceable, lhs.subscript_exprs()); + + const auto slicedArg = NamedValue(stmt.lhs().range(), "self", sliced); const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs())); - const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue); - emitBuiltinCall( - stmt.range(), - *method.graph(), - getAugOp(stmt, /*isTensor=*/true), - self, - {rhs}, - {}, - /*required=*/true); + if (tensorIndices.size() == 0) { + // Common case: we only tried to index with int and slices. Emit the + // correct augmented assignment op to the sliced value + emitBuiltinCall( + stmt.range(), + *method.graph(), + getAugOp(stmt, /*isTensor=*/true), + slicedArg, + {rhs}, + {}, + /*required=*/true); + } else { + // Special case: we tried to do "advanced indexing". Lower this expr + // into `index` and `index_put_` ops + const auto indices = graph->insertNode( + graph->createList(DynamicType::get(), tensorIndices))->output(); + const auto indexed = + graph->insert(aten::index, {slicedArg, indices}, {}, stmt.range()); + const auto augmented = emitBuiltinCall( + stmt.range(), + *method.graph(), + getAugOp(stmt, /*isTensor=*/true), + indexed, + {rhs}, + {}, + /*required=*/true); + graph->insert( + aten::index_put_, + {slicedArg, indices, augmented}, + {}, + stmt.range()); + } } else { // Otherwise, it should be a list. Lower this expression into: // list.set_item(get_item(idx).add_(value)) @@ -1573,32 +1599,12 @@ private: const auto valueArg = NamedValue(stmt.rhs().range(), "value", emitExpr(stmt.rhs())); - const auto getItem = emitBuiltinCall( - stmt.range(), - *method.graph(), - aten::select, - c10::nullopt, - {listArg, idxArg}, - {}, - /*required=*/true); - - const auto augmentedItem = emitBuiltinCall( - stmt.range(), - *method.graph(), - getAugOp(stmt, isTensorList), - {}, - {getItem, valueArg}, - {}, - /*required=*/true); - - emitBuiltinCall( - stmt.range(), - *method.graph(), - aten::_set_item, - c10::nullopt, - {listArg, idxArg, augmentedItem}, - {}, - /*required=*/true); + const auto getItem = + graph->insert(aten::select, {listArg, idxArg}, {}, stmt.range()); + const auto augmentedItem = graph->insert( + getAugOp(stmt, isTensorList), {getItem, valueArg}, {}, stmt.range()); + graph->insert( + aten::_set_item, {listArg, idxArg, augmentedItem}, {}, stmt.range()); } } @@ -1620,20 +1626,30 @@ private: // If it's a tensor, copy the RHS data into it if (sliceable->type()->isSubtypeOf(DynamicType::get())) { - std::vector args; - // Obtain the sliced value - auto lhsValue = - emitSubscript(lhs.range(), sliceable, lhs.subscript_exprs()); - args.emplace_back(lhs.range(), "t", lhsValue); - args.emplace_back(rhs.loc(), "other", rhs.value(*graph)); - emitBuiltinCall( - stmtRange, - *method.graph(), - aten::copy_, - c10::nullopt, - args, - {}, - true); + std::vector tensorIndices; + Value* sliced; + // Handle multi-dimensional slicing: first emit int/slice indexing + // TODO: the Python equivalent code has special-cased copy_to + // broadcasting to match NumPy semantics (see PR#4853). We can't + // replicate that without knowing the size of the Tensor; so really that + // code should be moved into the aten function + std::tie(sliced, tensorIndices) = emitIntAndSliceIndexing( + lhs.range(), sliceable, lhs.subscript_exprs()); + + const auto slicedArg = NamedValue(lhs.range(), sliced); + if (tensorIndices.size() == 0) { + // Common case: we only tried to index with int and slices. Copy the + // RHS into the resulting tensor. + graph->insert(aten::copy_, {slicedArg, rhs}, {}, stmtRange); + } else { + // Special case: we tried to do "advanced indexing" with a tensor. + // Dispatch to `aten::index_put_`. + const auto indices = graph->insertNode( + graph->createList(DynamicType::get(), tensorIndices))->output(); + + graph->insert( + aten::index_put_, {slicedArg, indices, rhs}, {}, stmtRange); + } // Otherwise, this is a list. Dispatch to aten::_set_item to both select and // assign @@ -1651,14 +1667,8 @@ private: args.emplace_back( lhs.subscript_exprs().range(), "idx", emitExpr(subscript[0])); args.push_back(rhs); - emitBuiltinCall( - stmtRange, - *method.graph(), - aten::_set_item, - c10::nullopt, - args, - {}, - true); + + graph->insert(aten::_set_item, args, {}, stmtRange); } } @@ -2200,6 +2210,13 @@ private: << "Unsupported operation: indexing tensor with unsupported index type " << index->type()->str() << ". Only ints, slices, and tensors are supported."; } + // at::index takes in a TensorList where some tensors can be undefined. + // Convert NULL tensorIndices to undefined tensors to pass to at::index. + for (auto& index : tensor_indices) { + if (index == nullptr) { + index = graph->insertNode(graph->createUndefined())->output(); + } + } return std::make_pair(sliceable, tensor_indices); } @@ -2240,13 +2257,6 @@ private: return sliceable; } - // at::index takes in a TensorList where some tensors can be undefined. - // Convert NULL tensor_indices to undefined tensors to pass to at::index. - for (auto& index : tensor_indices) { - if (index == nullptr) { - index = graph->insertNode(graph->createUndefined())->output(); - } - } return emitIndex(loc, sliceable, tensor_indices); } -- 2.7.4