return (d, e, f)
)JIT";
+const auto reshape_inplace_script_1 = R"JIT(
+ def forward(self, inp: Tensor, shape: List[int], flag: bool):
+ if flag:
+ a = inp + inp
+ b = a.reshape(shape)
+ c = b.sigmoid()
+ else:
+ a = inp * inp
+ b = a.sigmoid_()
+ c = b.reshape(shape)
+ d = c + c
+ e = a + a
+ f = b + b
+ return (d, e, f)
+)JIT";
+
const auto sigmoid_inplace_script = R"JIT(
def forward(self, inp: Tensor):
a = torch.sigmoid(inp, out=inp).clone()
TEST(StaticRuntime, InPlace) {
EXPECT_TRUE(testHasInplaceOp(reshape_inplace_script));
+ EXPECT_TRUE(testHasInplaceOp(reshape_inplace_script_1));
EXPECT_TRUE(testHasInplaceOp(sigmoid_inplace_script));
EXPECT_FALSE(testHasInplaceOp(sigmoid_out_script));
}
bool HasInplaceOp(Block* block, const AliasDb& alias_db) {
for (auto* node : block->nodes()) {
for (Block* sub_block : node->blocks()) {
- return HasInplaceOp(sub_block, alias_db);
+ if (HasInplaceOp(sub_block, alias_db)) {
+ return true;
+ }
}
auto inputs = node->inputs();
// check if node modifies inputs (both inplace ops and certain out variants