Add more debugging helper to net transformer (#19176)
authorYinghai Lu <yinghai@fb.com>
Fri, 12 Apr 2019 21:23:06 +0000 (14:23 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 12 Apr 2019 21:28:37 +0000 (14:28 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19176

Add some amenities for debugging.

Reviewed By: llyfacebook

Differential Revision: D14901740

fbshipit-source-id: 2c4018fdbf7e3aba2a754b6b4103a72893c229c2

caffe2/opt/backend_transformer_base.cc
caffe2/opt/backend_transformer_base.h
caffe2/opt/onnxifi_transformer.cc

index 7c33fa6..932f35b 100644 (file)
@@ -135,4 +135,27 @@ ShapeInfoMap BackendTransformerBase::inferShapes(
   }
   return shape_map;
 }
+
+void BackendTransformerBase::dumpNet(
+    const NetDef& pred_net,
+    const ShapeInfoMap& shape_hints,
+    const std::string& fname) const {
+  NetDef shape_net(pred_net);
+  auto* shape_arg = shape_net.add_arg();
+  auto* qshape_arg = shape_net.add_arg();
+  shape_arg->set_name("shape_info");
+  qshape_arg->set_name("qshape_info");
+  for (const auto& kv : shape_hints) {
+    if (!kv.second.is_quantized) {
+      auto t = wrapShapeInfoIntoTensorProto(kv.first, kv.second);
+      t.add_int32_data(static_cast<int32_t>(kv.second.dim_type));
+      shape_arg->mutable_tensors()->Add()->CopyFrom(t);
+    } else {
+      auto t = wrapShapeInfoIntoQTensorProto(kv.first, kv.second);
+      t.add_data(static_cast<int32_t>(kv.second.dim_type));
+      qshape_arg->mutable_qtensors()->Add()->CopyFrom(t);
+    }
+  }
+  WriteProtoToTextFile(shape_net, "debug_ssa_net.pb_txt");
+}
 } // namespace caffe2
index e845445..3ae3be1 100644 (file)
@@ -39,9 +39,15 @@ class BackendTransformerBase {
       const std::unordered_set<int>& blacklisted_ops) = 0;
 
  protected:
-  // get model ID from the NetDef
+  // Get model ID from the NetDef
   std::string getModelId(const NetDef& net);
 
+  // Dump the net with shape info
+  void dumpNet(
+      const NetDef& pred_net,
+      const ShapeInfoMap& map,
+      const std::string& fname) const;
+
   // SSA rewrite the net and return name mapping
   std::unordered_map<std::string, TensorShape> ssaRewriteAndMapNames(
       Workspace* ws,
index 8ec572b..7a1ff8b 100644 (file)
@@ -1027,23 +1027,7 @@ void OnnxifiTransformer::transform(
   }
 
   if (opts_.debug) {
-    NetDef shape_net(*pred_net);
-    auto* shape_arg = shape_net.add_arg();
-    auto* qshape_arg = shape_net.add_arg();
-    shape_arg->set_name("shape_info");
-    qshape_arg->set_name("qshape_info");
-    for (const auto& kv : shape_hints) {
-      if (!kv.second.is_quantized) {
-        auto t = wrapShapeInfoIntoTensorProto(kv.first, kv.second);
-        t.add_int32_data(static_cast<int32_t>(kv.second.dim_type));
-        shape_arg->mutable_tensors()->Add()->CopyFrom(t);
-      } else {
-        auto t = wrapShapeInfoIntoQTensorProto(kv.first, kv.second);
-        t.add_data(static_cast<int32_t>(kv.second.dim_type));
-        qshape_arg->mutable_qtensors()->Add()->CopyFrom(t);
-      }
-    }
-    WriteProtoToTextFile(shape_net, "debug_ssa_net.pb_txt");
+    dumpNet(*pred_net, shape_hints, "debug_ssa_net.pb_txt");
   }
 
   // Get backend id
@@ -1064,7 +1048,7 @@ void OnnxifiTransformer::transform(
   net_opt.mutable_device_option()->CopyFrom(pred_net->device_option());
 
   if (opts_.debug) {
-    WriteProtoToTextFile(net_opt, "debug_full_opt_net.pb_txt");
+    dumpNet(*pred_net, shape_hints, "debug_full_opt_net.pb_txt");
   }
   pred_net->Swap(&net_opt);
 }