// Checks that verify_no_memory_overlap() rejects an output that aliases an
// input when the op has no out= variant (inputs are immutable and must stay
// intact).
// NOTE(review): `sigmoid_script`, `sigmoid_node`, `ivalue_inputs`, `a` and `b`
// are fixtures defined earlier in this file (not visible in this chunk) —
// confirm against the full source.
TEST(
    ProcessedNode,
    VerifyNoMemoryOverlapWithImmutableInputsWithImmutableArguments) {
  script::Module module("module");
  // Not using out= variant.
  module.define(sigmoid_script);
  ProcessedNode pnode(sigmoid_node, std::move(ivalue_inputs), true);
  // Output distinct from the inputs: no overlap expected.
  pnode.Output(0) = b;
  EXPECT_TRUE(pnode.verify_no_memory_overlap());
  // Output aliases input `a`: the overlap must be detected.
  pnode.Output(0) = a;
  EXPECT_FALSE(pnode.verify_no_memory_overlap());
}
// Checks that verify_no_memory_overlap() tolerates an output aliasing an
// input when the schema is mutable (out= variant): such aliasing is legal, so
// both assignments below pass the check.
// NOTE(review): `sigmoid_inplace_script`, `sigmoid_node`, `ivalue_inputs`,
// `a` and `b` are fixtures defined earlier in this file (not visible in this
// chunk) — confirm against the full source.
TEST(
    ProcessedNode,
    VerifyNoMemoryOverlapWithImmutableInputsWithMutableArguments) {
  script::Module module("module");
  // Using out= variant.
  module.define(sigmoid_inplace_script);
  ProcessedNode pnode(sigmoid_node, std::move(ivalue_inputs), true);
  pnode.Output(0) = b;
  EXPECT_TRUE(pnode.verify_no_memory_overlap());
  // Aliasing an input is permitted for mutable schemas, so this still passes.
  pnode.Output(0) = a;
  EXPECT_TRUE(pnode.verify_no_memory_overlap());
}

+TEST(ProcessedNode, VerifyNoMemoryOverlapWithOverlappingOutputs) {
+ auto g = std::make_shared<torch::jit::Graph>();
+ torch::jit::parseIR(
+ R"IR(
+ graph(%0):
+ %1 : Tensor, %2 : Tensor = prim::ListUnpack(%0)
+ return (%1, %2))IR",
+ g.get());
+ torch::jit::StaticModule smodule(g);
+ Node* list_unpack_node = getNodeWithKind(smodule, "prim::ListUnpack");
+ {
+ auto a = at::randn({2, 3});
+ IValue ivalue(a);
+ std::vector<const IValue*> inputs{&ivalue};
+ ProcessedNode list_unpack_pnode(list_unpack_node, std::move(inputs), /*enable_out_variant=*/true);
+ ASSERT_EQ(list_unpack_pnode.outputs().size(), 2);
+ EXPECT_TRUE(list_unpack_pnode.verify_no_memory_overlap());
+ }
+ {
+ auto a = at::randn({2, 3});
+ IValue ivalue(a);
+ std::vector<const IValue*> inputs{&ivalue};
+ ProcessedNode list_unpack_pnode(list_unpack_node, std::move(inputs), /*enable_out_variant=*/true);
+ auto b = at::randn({2, 3});
+ list_unpack_pnode.Output(0) = b;
+ list_unpack_pnode.Output(1) = b;
+ EXPECT_FALSE(list_unpack_pnode.verify_no_memory_overlap());
+ }
}
// NOTE(review): the body of this test appears elided in this chunk of the
// paste — the isinstance coverage presumably lives in the full file; confirm
// against the original source before treating this as an intentionally empty
// test.
TEST(StaticRuntime, IndividualOps_isinstance) {
}
// Executes this node via its out-variant implementation (fn_) when one was
// selected, otherwise falls through to the native implementation branch.
void ProcessedNode::run() {
  // Debug-only guard: aliasing between outputs, or between an output and an
  // immutable input, would silently corrupt memory reused by the static
  // runtime, so fail fast before dispatching.
  DCHECK(verify_no_memory_overlap());
  if (fn_) {
    fn_(this);
  } else if (native_fn_) {
    // NOTE(review): this branch's body (and any interpreter fallback after
    // it) appears elided in this chunk — confirm against the full source.
  }
}
-bool ProcessedNode::verify_outputs_not_overlapping_with_immutable_inputs()
- const {
+static bool checkNoMemoryOverlap(const at::Tensor& a, const at::Tensor& b) {
+ at::MemOverlapStatus status = at::get_overlap_status(a, b);
+ if (status == at::MemOverlapStatus::FULL ||
+ status == at::MemOverlapStatus::PARTIAL) {
+ return false;
+ }
+ if (status == at::MemOverlapStatus::TOO_HARD) {
+ LOG(WARNING) << "Detected TOO_HARD memory overlap status";
+ }
+ return true;
+}
+
// Verifies that no tensor output of this node overlaps another output, and
// (for nodes with an immutable schema) that outputs do not overlap inputs.
// Called from a DCHECK in ProcessedNode::run().
+bool ProcessedNode::verify_no_memory_overlap() const {
  // Pairwise check: no two tensor outputs may share memory.
+  for (size_t i = 0; i < outputs_.size(); ++i) {
+    if (!outputs_[i].isTensor()) {
+      continue;
+    }
+    const auto& out0_t = outputs_[i].toTensor();
+    for (size_t j = i + 1; j < outputs_.size(); ++j) {
+      if (!outputs_[j].isTensor()) {
+        continue;
+      }
+      const auto& out1_t = outputs_[j].toTensor();
+      if (!checkNoMemoryOverlap(out0_t, out1_t)) {
+        return false;
+      }
+    }
+  }
+
  // Mutable schemas (out= variants) may legitimately alias inputs, so the
  // input-vs-output check below is skipped for them.
  auto schema = node()->maybeSchema();
  if (!schema || schema->is_mutable()) {
    return true;
    // NOTE(review): the `continue` below, and the definitions of `out` and
    // `in_t` used after it, belong to input/output loops elided from this
    // chunk of the paste — reconstruct them from the full source; do not
    // compile this fragment as-is.
    continue;
  }
  const auto& out_t = out.toTensor();
-  at::MemOverlapStatus status = at::get_overlap_status(in_t, out_t);
-  if (status != at::MemOverlapStatus::NO) {
+  if (!checkNoMemoryOverlap(in_t, out_t)) {
    return false;
  }
}