Keep weights name unchanged during SsaRewrite (#16932)
authorKimish Patel <kimishpatel@fb.com>
Mon, 11 Feb 2019 22:32:30 +0000 (14:32 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 11 Feb 2019 22:55:31 +0000 (14:55 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16932

During onnxifi transformation net ssa is rewritten. At the last step the weight
names are changed back to what they were before. The diff keeps the weight
names unchanged thru the process.

Reviewed By: yinghai

Differential Revision: D13972597

fbshipit-source-id: 7c29857f788a674edf625c073b345f2b44267b33

caffe2/onnx/onnx_exporter.cc
caffe2/onnx/onnx_exporter.h
caffe2/opt/onnxifi_transformer.cc
caffe2/opt/onnxifi_transformer.h
caffe2/python/pybind_state.cc

index 962c985..a0d3d38 100644 (file)
@@ -128,12 +128,16 @@ NodeProto AddShapeNode(const std::string& input, const std::string& output) {
 
 std::unordered_map<std::string, std::string> SsaRewrite(
     caffe2::NetDef* init_net,
-    caffe2::NetDef* pred_net) {
+    caffe2::NetDef* pred_net,
+    const std::unordered_set<string>& exceptions) {
   std::unordered_map<std::string, std::string> input_mapping;
   std::unordered_map<std::string, int> blob_versions;
 
 #define REWRITE_EXTERNAL_IO(net, name)                 \
   for (auto& name : *net->mutable_external_##name()) { \
+    if (exceptions.count(name)) {                      \
+      continue;                                        \
+    }                                                  \
     auto version = blob_versions.at(name);             \
     auto new_##name = SsaName(name, version);          \
     name##_mapping.emplace(new_##name, name);          \
@@ -149,9 +153,15 @@ std::unordered_map<std::string, std::string> SsaRewrite(
       op.set_output(0, SsaName(output, 0));
     }
     for (const auto& input : init_net->external_input()) {
+      if (exceptions.count(input)) {
+        continue;
+      }
       blob_versions.emplace(input, 0);
     }
     for (const auto& output : init_net->external_output()) {
+      if (exceptions.count(output)) {
+        continue;
+      }
       blob_versions.emplace(output, 0);
     }
     REWRITE_EXTERNAL_IO(init_net, input);
@@ -160,11 +170,17 @@ std::unordered_map<std::string, std::string> SsaRewrite(
 
   if (pred_net) {
     for (const auto& input : pred_net->external_input()) {
+      if (exceptions.count(input)) {
+        continue;
+      }
       blob_versions.emplace(input, 0);
     }
     REWRITE_EXTERNAL_IO(pred_net, input);
     for (auto& op : *pred_net->mutable_op()) {
       for (auto& input : *op.mutable_input()) {
+        if (exceptions.count(input)) {
+          continue;
+        }
         const auto it = blob_versions.find(input);
         if (it != blob_versions.end()) {
           input = SsaName(input, it->second);
@@ -174,6 +190,9 @@ std::unordered_map<std::string, std::string> SsaRewrite(
         }
       }
       for (auto& output : *op.mutable_output()) {
+        if (exceptions.count(output)) {
+          continue;
+        }
         auto it = blob_versions.find(output);
         if (it != blob_versions.end()) {
           it->second += 1;
@@ -192,6 +211,9 @@ std::unordered_map<std::string, std::string> SsaRewrite(
     }
     for (auto& op : *pred_net->mutable_op()) {
       for (auto& output : *op.mutable_output()) {
+        if (exceptions.count(output)) {
+          continue;
+        }
         auto pos = output.find_last_of('_');
         CAFFE_ENFORCE_NE(pos, 0);
         auto basename = output.substr(0, pos);
index 0c7e65d..7ad8f7c 100644 (file)
@@ -28,7 +28,9 @@ using ConvertedResult =
 // output names for predict net.
 CAFFE2_API std::unordered_map<std::string, std::string> SsaRewrite(
     caffe2::NetDef* init_net,
-    caffe2::NetDef* pred_net);
+    caffe2::NetDef* pred_net,
+    const std::unordered_set<std::string>& exceptions =
+        std::unordered_set<std::string>());
 
 ::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
     caffe2::TensorProto::DataType t);
index 607d6f0..99f9b39 100644 (file)
@@ -715,11 +715,23 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx(
 CaffeMap<std::string, TensorShape> OnnxifiTransformer::SsaRewriteAndMapNames(
     Workspace* ws,
     NetDef* pred_net,
+    const std::unordered_set<std::string>& weights,
     const std::unordered_map<std::string, TensorShape>& input_shape_hints) {
-  input_mapping_ = onnx::SsaRewrite(nullptr, pred_net);
+  // Make sure weights do not contain output of any op.
+  for (const auto& op : pred_net->op()) {
+    for (const auto& output : op.output()) {
+      CAFFE_ENFORCE_EQ(weights.count(output), 0);
+    }
+  }
+  input_mapping_ = onnx::SsaRewrite(nullptr, pred_net, weights);
   // Annote the ops with net position
   AnnotateOpIndex(pred_net);
   std::vector<std::string> external_inputs;
+  // Need to add mapping for weights. This will be used to create new workspace
+  // with mapped weights.
+  for (const auto& w : weights) {
+    input_mapping_.emplace(w, w);
+  }
   for (const auto kv : input_mapping_) {
     reverse_input_mapping_.emplace(kv.second, kv.first);
     if (!ws->HasBlob(kv.second)) {
@@ -966,6 +978,7 @@ void OnnxifiTransformer::Transform(
     Workspace* ws,
     NetDef* pred_net,
     const std::vector<std::string>& external_inputs,
+    const std::vector<std::string>& weight_names,
     const std::unordered_map<std::string, TensorShape>& input_shape_hints,
     const std::unordered_set<int>& blacklisted_ops) {
   CAFFE_ENFORCE(ws);
@@ -975,9 +988,12 @@ void OnnxifiTransformer::Transform(
   model_id_ = GetModelId(*pred_net);
   onnxifi_op_id_ = 0;
 
+  std::unordered_set<std::string> weights(
+      weight_names.begin(), weight_names.end());
+
   // SSA Rewrite the net
   auto shape_hints_ordered =
-      SsaRewriteAndMapNames(ws, pred_net, input_shape_hints);
+      SsaRewriteAndMapNames(ws, pred_net, weights, input_shape_hints);
 
   // Populate shape info
   Workspace mapped_ws(ws, input_mapping_);
@@ -988,22 +1004,6 @@ void OnnxifiTransformer::Transform(
       opts_.infer_shapes,
       opts_.bound_shape_spec);
 
-  // Figure out what are the weights
-  std::unordered_set<std::string> weights;
-  std::unordered_set<std::string> input_set;
-  for (const auto& i : external_inputs) {
-    const auto it = reverse_input_mapping_.find(i);
-    if (it != reverse_input_mapping_.end()) {
-      input_set.emplace(it->second);
-    }
-  }
-  const std::vector<string>& ws_blobs = mapped_ws.Blobs();
-  for (const auto& s : ws_blobs) {
-    if (!input_set.count(s)) {
-      weights.emplace(s);
-    }
-  }
-
   // Transform the net
   NetDef net_opt = opts_.use_onnx
       ? TransformViaOnnx(ws, pred_net, weights, blacklisted_ops, &shape_hints)
index a7ba90a..a1a8cd8 100644 (file)
@@ -42,6 +42,7 @@ class CAFFE2_API OnnxifiTransformer final {
       Workspace* ws,
       NetDef* pred_net,
       const std::vector<std::string>& external_inputs,
+      const std::vector<std::string>& weight_names,
       const std::unordered_map<std::string, TensorShape>& shape_hints,
       const std::unordered_set<int>& blacklisted_ops);
 
@@ -85,6 +86,7 @@ class CAFFE2_API OnnxifiTransformer final {
   CaffeMap<std::string, TensorShape> SsaRewriteAndMapNames(
       Workspace* ws,
       NetDef* pred_net,
+      const std::unordered_set<std::string>& weights,
       const std::unordered_map<std::string, TensorShape>& input_shape_hints);
 
   // Transform by passing C2 proto to backend
index 55b583a..d4f54d8 100644 (file)
@@ -1628,10 +1628,13 @@ void addGlobalMethods(py::module& m) {
         opts.debug = debug_builder;
         opts.use_onnx = use_onnx;
         OnnxifiTransformer ts(opts);
+        Workspace* curr_ws = GetCurrentWorkspace();
+        auto weight_names = curr_ws->Blobs();
         ts.Transform(
-            GetCurrentWorkspace(),
+            curr_ws,
             &pred_net,
             external_inputs,
+            weight_names,
             tensor_shapes,
             std::unordered_set<int>());
         std::string pred_net_str2;