Do not rename net boundary inputs/outputs during ssaRewrite. (#17545)
authorYinghai Lu <yinghai@fb.com>
Wed, 6 Mar 2019 22:24:02 +0000 (14:24 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 6 Mar 2019 22:26:58 +0000 (14:26 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17545

This diff avoids renaming boundary inputs of net during onnxifi transform.
It also removes adding mappings for the initializer during onnxifi op creation.
Thus gets read of the mapped ws creation during onnxifi op creation.

Reviewed By: zrphercule

Differential Revision: D14243161

fbshipit-source-id: 6eafa920c45f6a6bfacbbb443e8e84cf9778644c

caffe2/onnx/onnx_exporter.cc
caffe2/onnx/onnx_exporter.h
caffe2/onnx/ssa_test.cc
caffe2/operators/onnxifi_op.h
caffe2/opt/backend_transformer_base.cc
caffe2/opt/backend_transformer_base.h
caffe2/opt/onnxifi_transformer.cc

index 9e6a3c2..7a8f584 100644 (file)
@@ -129,71 +129,46 @@ NodeProto AddShapeNode(const std::string& input, const std::string& output) {
 
 std::unordered_map<std::string, std::string> SsaRewrite(
     caffe2::NetDef* init_net,
-    caffe2::NetDef* pred_net,
-    const std::unordered_set<string>& exceptions) {
+    caffe2::NetDef* pred_net) {
   std::unordered_map<std::string, std::string> input_mapping;
   std::unordered_map<std::string, int> blob_versions;
 
-#define REWRITE_EXTERNAL_IO(net, name)                 \
-  for (auto& name : *net->mutable_external_##name()) { \
-    if (exceptions.count(name)) {                      \
-      continue;                                        \
-    }                                                  \
-    auto version = blob_versions.at(name);             \
-    auto new_##name = SsaName(name, version);          \
-    name##_mapping.emplace(new_##name, name);          \
-    name = new_##name;                                 \
-  }
-
   if (init_net) {
-    for (auto& op : *init_net->mutable_op()) {
-      CAFFE_ENFORCE_EQ(op.type().find("GivenTensor"), 0);
-      CAFFE_ENFORCE_EQ(op.type().rfind("Fill"), op.type().size() - 4);
-      CAFFE_ENFORCE_EQ(op.output_size(), 1);
-      const auto& output = op.output(0);
-      op.set_output(0, SsaName(output, 0));
-    }
-    for (const auto& input : init_net->external_input()) {
-      if (exceptions.count(input)) {
-        continue;
-      }
-      blob_versions.emplace(input, 0);
+    // No ssa rewrite is done for init net. The reason being that the output
+    // blobs of init net are what becomes the input blobs of pred_net. Since
+    // inputs of pred_net are not renamed we are not renaming the output of
+    // init_net. Furthermore, the assumption made is that init_net is simple net
+    // with each operator producing the one output and thus not renaming
+    // translates to not renaming the outputs of the init_net. Create identical
+    // mapping for now. This shall be removed eventually.
+    for (const auto& name : init_net->external_input()) {
+      input_mapping.emplace(name, name);
     }
-    for (const auto& output : init_net->external_output()) {
-      if (exceptions.count(output)) {
-        continue;
-      }
-      blob_versions.emplace(output, 0);
-    }
-    REWRITE_EXTERNAL_IO(init_net, input);
     blob_versions.clear();
   }
 
   if (pred_net) {
+    std::unordered_set<std::string> external_outputs;
     for (const auto& input : pred_net->external_input()) {
-      if (exceptions.count(input)) {
-        continue;
-      }
-      blob_versions.emplace(input, 0);
+      // Create identical mapping for now. This shall be removed eventually.
+      input_mapping.emplace(input, input);
+    }
+    for (const auto& output : pred_net->external_output()) {
+      external_outputs.emplace(output);
     }
-    REWRITE_EXTERNAL_IO(pred_net, input);
     for (auto& op : *pred_net->mutable_op()) {
       for (auto& input : *op.mutable_input()) {
-        if (exceptions.count(input)) {
-          continue;
-        }
         const auto it = blob_versions.find(input);
         if (it != blob_versions.end()) {
           input = SsaName(input, it->second);
         } else {
-          blob_versions.emplace(input, 0);
-          input = SsaName(input, 0);
+          // Input blob is not versioned yet.
+          // If it is not versioned yet, it is assumed to be primary input,
+          // Thus skip renaming it.
+          continue;
         }
       }
       for (auto& output : *op.mutable_output()) {
-        if (exceptions.count(output)) {
-          continue;
-        }
         auto it = blob_versions.find(output);
         if (it != blob_versions.end()) {
           it->second += 1;
@@ -205,31 +180,34 @@ std::unordered_map<std::string, std::string> SsaRewrite(
       }
     }
 
-    // Fix the external output name back to original
-    std::unordered_set<std::string> external_outputs;
-    for (const auto& output : pred_net->external_output()) {
-      external_outputs.emplace(output);
+    // For all the renamed blobs find if the blob is one of the external
+    // output. If so add a mapping from it's latest renamed version to its
+    // original name.
+    std::unordered_map<std::string, std::string> renamed_external_outputs;
+    for (const auto it : blob_versions) {
+      if (external_outputs.count(it.first)) {
+        renamed_external_outputs.emplace(
+            SsaName(it.first, it.second), it.first);
+      }
     }
+
+    // Use the mapping to find if the input or output of an op was a renamed
+    // external output. If so replace it with its original name.
     for (auto& op : *pred_net->mutable_op()) {
-      for (auto& output : *op.mutable_output()) {
-        if (exceptions.count(output)) {
-          continue;
-        }
-        auto pos = output.find_last_of('_');
-        CAFFE_ENFORCE_NE(pos, 0);
-        auto basename = output.substr(0, pos);
-        if (!external_outputs.count(basename)) {
-          continue;
+      for (auto& input : *op.mutable_input()) {
+        const auto it = renamed_external_outputs.find(input);
+        if (it != renamed_external_outputs.end()) {
+          input = it->second;
         }
-        auto it = blob_versions.find(basename);
-        if (it != blob_versions.end() &&
-            SsaName(basename, it->second) == output) {
-          output = basename;
+      }
+      for (auto& output : *op.mutable_output()) {
+        const auto it = renamed_external_outputs.find(output);
+        if (it != renamed_external_outputs.end()) {
+          output = it->second;
         }
       }
     }
   }
-#undef REWRITE_EXTERNAL_IO
 
   return input_mapping;
 }
@@ -1176,7 +1154,7 @@ ConvertedResult OnnxExporter::CreateGemmNodes(
                 std::vector<AttributeProto>{
                     MakeAttribute("axis", static_cast<int64_t>(0)),
                 }));
+
     nodes.emplace_back(MakeNode("Reshape",
                 { gemm_y_output, y_shape },
                 { y }));
index 4e3a286..f922c0d 100644 (file)
@@ -28,9 +28,7 @@ using ConvertedResult =
 // output names for predict net.
 CAFFE2_API std::unordered_map<std::string, std::string> SsaRewrite(
     caffe2::NetDef* init_net,
-    caffe2::NetDef* pred_net,
-    const std::unordered_set<std::string>& exceptions =
-        std::unordered_set<std::string>());
+    caffe2::NetDef* pred_net);
 
 ::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
     caffe2::TensorProto::DataType t);
index 9f83b8f..5e6553e 100644 (file)
@@ -24,12 +24,12 @@ TEST(SsaTest, ConvReluInplace) {
 
   std::unordered_map<std::string, std::string> input_mapping =
       caffe2::onnx::SsaRewrite(nullptr, &net);
-  for (const auto& op : net.op()) {
+  for (const auto& net_op : net.op()) {
     std::unordered_set<std::string> inputs;
-    for (const auto& i : op.input()) {
+    for (const auto& i : net_op.input()) {
       inputs.emplace(i);
     }
-    for (const auto& o : op.output()) {
+    for (const auto& o : net_op.output()) {
       EXPECT_TRUE(inputs.count(o) == 0);
     }
   }
@@ -37,3 +37,46 @@ TEST(SsaTest, ConvReluInplace) {
   EXPECT_EQ("X", input_mapping.at(net.external_input(0)));
   EXPECT_EQ("Y", net.external_output(0));
 }
+
+TEST(SsaTest, FC_FC_FC_InPlace_Output) {
+  caffe2::NetDef net;
+  auto* op = net.add_op();
+  op->set_type("FC");
+  op->add_input("X");
+  op->add_input("W0");
+  op->add_input("b0");
+  op->add_output("Y");
+  op = net.add_op();
+  op->set_type("FC");
+  op->add_input("Y");
+  op->add_input("W1");
+  op->add_input("b1");
+  op->add_output("Y");
+  op = net.add_op();
+  op->set_type("FC");
+  op->add_input("Y");
+  op->add_input("W2");
+  op->add_input("b2");
+  op->add_output("Z");
+  net.add_external_input("X");
+  net.add_external_output("Y");
+  net.add_external_output("Z");
+
+  std::unordered_map<std::string, std::string> input_mapping =
+      caffe2::onnx::SsaRewrite(nullptr, &net);
+  for (const auto& net_op : net.op()) {
+    std::unordered_set<std::string> inputs;
+    for (const auto& i : net_op.input()) {
+      inputs.emplace(i);
+    }
+    for (const auto& o : net_op.output()) {
+      EXPECT_TRUE(inputs.count(o) == 0);
+    }
+  }
+  EXPECT_EQ(net.op(0).output(0), net.op(1).input(0));
+  EXPECT_EQ("Y", net.op(2).input(0));
+  EXPECT_EQ("Y_0", net.op(1).input(0));
+  EXPECT_EQ("X", input_mapping.at(net.external_input(0)));
+  EXPECT_EQ("Y", net.external_output(0));
+  EXPECT_EQ("Z", net.external_output(1));
+}
index b81f5b9..e3f2ba5 100644 (file)
@@ -75,20 +75,15 @@ class OnnxifiOp final : public Operator<Context> {
     // map the weight names
     auto initializers =
         this->template GetRepeatedArgument<std::string>("initializers");
-    CAFFE_ENFORCE_EQ(
-        initializers.size() % 2, 0, "initializers should come in pairs");
     std::unordered_set<std::string> initializer_set;
-    std::unordered_map<std::string, std::string> input_mapping;
     for (auto it = initializers.begin(); it != initializers.end(); ++it) {
-      auto key = *it++;
-      input_mapping.emplace(key, *it);
+      auto key = *it;
       initializer_set.emplace(key);
     }
-    Workspace mapped_ws(ws, input_mapping);
     std::vector<std::string> weight_names;
     std::vector<std::vector<uint64_t>> weight_shapes;
     auto weight_descs = buildInitializationList(
-        &mapped_ws, &initializer_set, &weight_names, &weight_shapes);
+        ws, &initializer_set, &weight_names, &weight_shapes);
 
     BuildBackendAndGraph(property_pointers, onnx_model_str, weight_descs);
   }
index 0c116b4..f9e3ecb 100644 (file)
@@ -39,35 +39,16 @@ std::unordered_map<std::string, TensorShape>
 BackendTransformerBase::ssaRewriteAndMapNames(
     Workspace* ws,
     NetDef* pred_net,
-    const std::unordered_set<std::string>& weights,
     const std::unordered_map<std::string, TensorShape>& input_shape_hints) {
-  // Make sure weights do not contain output of any op.
-  for (const auto& op : pred_net->op()) {
-    for (const auto& output : op.output()) {
-      CAFFE_ENFORCE_EQ(
-          weights.count(output),
-          0,
-          "Weight ",
-          output,
-          " shouldn't appear in the output");
-    }
-  }
-  input_mapping_ = onnx::SsaRewrite(nullptr, pred_net, weights);
+  input_mapping_ = onnx::SsaRewrite(nullptr, pred_net);
   // Annote the ops with net position
   AnnotateOpIndex(pred_net);
 
-  // Need to add mapping for weights. This will be used to create new workspace
-  // with mapped weights.
-  for (const auto& w : weights) {
-    input_mapping_.emplace(w, w);
-  }
-
   // Since we are going to create a mapped workspace, we need to make sure that
   // the parent workspace has the mapped blob names. If the blobs don't exist
   // (usually such blobs are input tensor names), we exclude them from mapping.
   std::vector<std::string> exclude_mapping;
   for (const auto kv : input_mapping_) {
-    reverse_input_mapping_.emplace(kv.second, kv.first);
     if (!ws->HasBlob(kv.second)) {
       exclude_mapping.emplace_back(kv.first);
     }
@@ -75,14 +56,10 @@ BackendTransformerBase::ssaRewriteAndMapNames(
   for (const auto& i : exclude_mapping) {
     input_mapping_.erase(i);
   }
+
   std::unordered_map<std::string, TensorShape> shape_hints_mapped;
   for (const auto& kv : input_shape_hints) {
-    const auto it = reverse_input_mapping_.find(kv.first);
-    if (it != reverse_input_mapping_.end()) {
-      shape_hints_mapped.emplace(it->second, kv.second);
-    } else {
-      shape_hints_mapped.emplace(kv.first, kv.second);
-    }
+    shape_hints_mapped.emplace(kv.first, kv.second);
   }
   return shape_hints_mapped;
 }
index 1d3c59f..a7281a8 100644 (file)
@@ -46,7 +46,6 @@ class BackendTransformerBase {
   std::unordered_map<std::string, TensorShape> ssaRewriteAndMapNames(
       Workspace* ws,
       NetDef* pred_net,
-      const std::unordered_set<std::string>& weights,
       const std::unordered_map<std::string, TensorShape>& input_shape_hints);
 
   // Wrap TensorShape into TensorProto
index 771c152..0d17552 100644 (file)
@@ -357,7 +357,6 @@ OperatorDef OnnxifiTransformer::BuildOnnxifiOp(
   initializers_arg->set_name("initializers");
   for (const auto& s : initialization_list) {
     initializers_arg->add_strings(s);
-    initializers_arg->add_strings(input_mapping_.at(s));
   }
 
   // Add the input/output
@@ -576,9 +575,6 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx(
 
       // Add mappings
       extra_weights.emplace_back(t.name());
-      CAFFE_ENFORCE(
-          input_mapping_.emplace(t.name(), t.name()).second,
-          c10::str("Tensor ", t.name(), " already exists in the workspace"));
     }
   }
 
@@ -948,9 +944,15 @@ void OnnxifiTransformer::transform(
 
   // SSA Rewrite the net
   auto shape_hints_mapped =
-      ssaRewriteAndMapNames(ws, pred_net, weights, input_shape_hints);
+      ssaRewriteAndMapNames(ws, pred_net, input_shape_hints);
 
   // Populate shape info
+  // TODO(yingz): We should not need to create mapped_ws since we did not change
+  // any input mappings during ssarewrite. However this is here for the
+  // following reason: BlackBoxPredictor calls RunNetOnce before onnxifi to
+  // populate dimension info. However during this, it was observed, that new
+  // blob for output is created. This causes problem if inferShape uses original
+  // ws since it does not expect the output blob to be present.
   Workspace mapped_ws(ws, input_mapping_);
   ShapeInfoMap shape_hints = inferShapes(
       &mapped_ws, pred_net, shape_hints_mapped, opts_.bound_shape_spec);