caffe2 - Util to cleanup external inputs and outputs from a NetDef (#18194)
authorDuc Ngo <duc@fb.com>
Fri, 22 Mar 2019 18:14:40 +0000 (11:14 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Mar 2019 18:23:03 +0000 (11:23 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18194

Add a util method to cleanup external inputs and outputs from a NetDef

The following conditions will be met after the modification
- No duplicate external inputs
- No duplicate external outputs
- Going through list of ops in order, all op inputs must be outputs
from other ops, or registered as external inputs.
- All external outputs must be outputs of some operators.

Reviewed By: ZolotukhinM

Differential Revision: D14528589

fbshipit-source-id: c8d82fda1946aa3696abcbec869a4a8bb22f09b6

caffe2/utils/proto_utils.cc
caffe2/utils/proto_utils.h
caffe2/utils/proto_utils_test.cc

index 213feb4..40cc1b8 100644 (file)
@@ -558,4 +558,56 @@ C10_EXPORT Argument* GetMutableArgument(
   }
 }
 
-}  // namespace caffe2
+C10_EXPORT void cleanupExternalInputsAndOutputs(NetDef* net) {
+  std::vector<std::string> oldExternalInputs;
+  for (const auto& input : net->external_input()) {
+    oldExternalInputs.emplace_back(input);
+  }
+  std::vector<std::string> oldExternalOutputs;
+  for (const auto& output : net->external_output()) {
+    oldExternalOutputs.emplace_back(output);
+  }
+
+  net->clear_external_input();
+  net->clear_external_output();
+
+  std::set<std::string> inputSet;
+  for (const auto& input : oldExternalInputs) {
+    if (inputSet.count(input)) {
+      // Prevent duplicate external inputs.
+      continue;
+    }
+    inputSet.insert(input);
+    net->add_external_input(input);
+  }
+
+  // Set of blobs that are external inputs or outputs of some operators.
+  std::set<std::string> allOutputs(inputSet.begin(), inputSet.end());
+  for (const auto& op : net->op()) {
+    for (const auto& input : op.input()) {
+      if (inputSet.count(input) || allOutputs.count(input)) {
+        continue;
+      }
+      // Add missing external inputs.
+      inputSet.insert(input);
+      net->add_external_input(input);
+    }
+    for (const auto& output : op.output()) {
+      allOutputs.insert(output);
+    }
+  }
+
+  std::set<std::string> outputSet;
+  for (const auto& output : oldExternalOutputs) {
+    if (!allOutputs.count(output)) {
+      continue;
+    }
+    if (outputSet.count(output)) {
+      continue;
+    }
+    outputSet.insert(output);
+    net->add_external_output(output);
+  }
+}
+
+} // namespace caffe2
index 9637836..22ccc63 100644 (file)
@@ -329,6 +329,14 @@ bool inline operator==(const DeviceOption& dl, const DeviceOption& dr) {
   return IsSameDevice(dl, dr);
 }
 
+// Given a net, modify the external inputs/outputs if necessary so that
+// the following conditions are met
+// - No duplicate external inputs
+// - No duplicate external outputs
+// - Going through list of ops in order, all op inputs must be outputs
+// from other ops, or registered as external inputs.
+// - All external outputs must be outputs of some operators.
+CAFFE2_API void cleanupExternalInputsAndOutputs(NetDef* net);
 
 } // namespace caffe2
 
index 5d8fb86..1a68769 100644 (file)
@@ -1,6 +1,8 @@
-#include "caffe2/utils/proto_utils.h"
 #include <gtest/gtest.h>
 
+#include "caffe2/core/test_utils.h"
+#include "caffe2/utils/proto_utils.h"
+
 namespace caffe2 {
 
 TEST(ProtoUtilsTest, IsSameDevice) {
@@ -29,4 +31,33 @@ TEST(ProtoUtilsTest, SimpleReadWrite) {
   EXPECT_EQ(content, read_back);
 }
 
-}  // namespace caffe2
+TEST(ProtoUtilsTest, CleanupExternalInputsAndOutputs) {
+  caffe2::NetDef net;
+  caffe2::testing::NetMutator(&net)
+      .newOp("op1", {"X1", "X2"}, {"Y"})
+      .newOp("op2", {"W", "Y"}, {"Z1", "Z2"})
+      .newOp("op3", {"Z2", "W"}, {"O"})
+      .externalInputs({"X1", "X3", "X1", "W"})
+      .externalOutputs({"O", "Z2", "Z3", "O", "X3"});
+  cleanupExternalInputsAndOutputs(&net);
+
+  std::vector<std::string> externalInputs;
+  for (const auto& inputName : net.external_input()) {
+    externalInputs.emplace_back(inputName);
+  }
+  // The 2nd X1 is removed because of duplication.
+  // X2 is added because it should be a missing external input.
+  std::vector<std::string> expectedExternalInputs{"X1", "X3", "W", "X2"};
+  EXPECT_EQ(externalInputs, expectedExternalInputs);
+
+  std::vector<std::string> externalOutputs;
+  for (const auto& outputName : net.external_output()) {
+    externalOutputs.emplace_back(outputName);
+  }
+  // Z3 is removed because it's not an output of any operator in the net.
+  // The 2nd O is removed because of duplication.
+  std::vector<std::string> expectedexternalOutputs{"O", "Z2", "X3"};
+  EXPECT_EQ(externalOutputs, expectedexternalOutputs);
+}
+
+} // namespace caffe2