From 172ec4ace520a72729191b95b9f21c651aa5e245 Mon Sep 17 00:00:00 2001 From: Duc Ngo Date: Fri, 22 Mar 2019 11:14:40 -0700 Subject: [PATCH] caffe2 - Util to cleanup external inputs and outputs from a NetDef (#18194) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18194 Add a util method to cleanup external inputs and outputs from a NetDef The following conditions will be met after the modification - No duplicate external inputs - No duplicate external outputs - Going through list of ops in order, all op inputs must be outputs from other ops, or registered as external inputs. - All external outputs must be outputs of some operators. Reviewed By: ZolotukhinM Differential Revision: D14528589 fbshipit-source-id: c8d82fda1946aa3696abcbec869a4a8bb22f09b6 --- caffe2/utils/proto_utils.cc | 54 +++++++++++++++++++++++++++++++++++++++- caffe2/utils/proto_utils.h | 8 ++++++ caffe2/utils/proto_utils_test.cc | 35 ++++++++++++++++++++++++-- 3 files changed, 94 insertions(+), 3 deletions(-) diff --git a/caffe2/utils/proto_utils.cc b/caffe2/utils/proto_utils.cc index 213feb4..40cc1b8 100644 --- a/caffe2/utils/proto_utils.cc +++ b/caffe2/utils/proto_utils.cc @@ -558,4 +558,56 @@ C10_EXPORT Argument* GetMutableArgument( } } -} // namespace caffe2 +C10_EXPORT void cleanupExternalInputsAndOutputs(NetDef* net) { + std::vector oldExternalInputs; + for (const auto& input : net->external_input()) { + oldExternalInputs.emplace_back(input); + } + std::vector oldExternalOutputs; + for (const auto& output : net->external_output()) { + oldExternalOutputs.emplace_back(output); + } + + net->clear_external_input(); + net->clear_external_output(); + + std::set inputSet; + for (const auto& input : oldExternalInputs) { + if (inputSet.count(input)) { + // Prevent duplicate external inputs. + continue; + } + inputSet.insert(input); + net->add_external_input(input); + } + + // Set of blobs that are external inputs or outputs of some operators. + std::set allOutputs(inputSet.begin(), inputSet.end()); + for (const auto& op : net->op()) { + for (const auto& input : op.input()) { + if (inputSet.count(input) || allOutputs.count(input)) { + continue; + } + // Add missing external inputs. + inputSet.insert(input); + net->add_external_input(input); + } + for (const auto& output : op.output()) { + allOutputs.insert(output); + } + } + + std::set outputSet; + for (const auto& output : oldExternalOutputs) { + if (!allOutputs.count(output)) { + continue; + } + if (outputSet.count(output)) { + continue; + } + outputSet.insert(output); + net->add_external_output(output); + } +} + +} // namespace caffe2 diff --git a/caffe2/utils/proto_utils.h b/caffe2/utils/proto_utils.h index 9637836..22ccc63 100644 --- a/caffe2/utils/proto_utils.h +++ b/caffe2/utils/proto_utils.h @@ -329,6 +329,14 @@ bool inline operator==(const DeviceOption& dl, const DeviceOption& dr) { return IsSameDevice(dl, dr); } +// Given a net, modify the external inputs/outputs if necessary so that +// the following conditions are met +// - No duplicate external inputs +// - No duplicate external outputs +// - Going through list of ops in order, all op inputs must be outputs +// from other ops, or registered as external inputs. +// - All external outputs must be outputs of some operators. +CAFFE2_API void cleanupExternalInputsAndOutputs(NetDef* net); } // namespace caffe2 diff --git a/caffe2/utils/proto_utils_test.cc b/caffe2/utils/proto_utils_test.cc index 5d8fb86..1a68769 100644 --- a/caffe2/utils/proto_utils_test.cc +++ b/caffe2/utils/proto_utils_test.cc @@ -1,6 +1,8 @@ -#include "caffe2/utils/proto_utils.h" #include +#include "caffe2/core/test_utils.h" +#include "caffe2/utils/proto_utils.h" + namespace caffe2 { TEST(ProtoUtilsTest, IsSameDevice) { @@ -29,4 +31,33 @@ TEST(ProtoUtilsTest, SimpleReadWrite) { EXPECT_EQ(content, read_back); } -} // namespace caffe2 +TEST(ProtoUtilsTest, CleanupExternalInputsAndOutputs) { + caffe2::NetDef net; + caffe2::testing::NetMutator(&net) + .newOp("op1", {"X1", "X2"}, {"Y"}) + .newOp("op2", {"W", "Y"}, {"Z1", "Z2"}) + .newOp("op3", {"Z2", "W"}, {"O"}) + .externalInputs({"X1", "X3", "X1", "W"}) + .externalOutputs({"O", "Z2", "Z3", "O", "X3"}); + cleanupExternalInputsAndOutputs(&net); + + std::vector externalInputs; + for (const auto& inputName : net.external_input()) { + externalInputs.emplace_back(inputName); + } + // The 2nd X1 is removed because of duplication. + // X2 is added because it should be a missing external input. + std::vector expectedExternalInputs{"X1", "X3", "W", "X2"}; + EXPECT_EQ(externalInputs, expectedExternalInputs); + + std::vector externalOutputs; + for (const auto& outputName : net.external_output()) { + externalOutputs.emplace_back(outputName); + } + // Z3 is removed because it's not an output of any operator in the net. + // The 2nd O is removed because of duplication. + std::vector expectedexternalOutputs{"O", "Z2", "X3"}; + EXPECT_EQ(externalOutputs, expectedexternalOutputs); +} + +} // namespace caffe2 -- 2.7.4