From 0e435afc3c6a6649510d32a80abce49d7538ecfa Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Fri, 12 Apr 2019 14:23:06 -0700 Subject: [PATCH] Add more debugging helper to net transformer (#19176) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19176 Add some amenities for debugging. Reviewed By: llyfacebook Differential Revision: D14901740 fbshipit-source-id: 2c4018fdbf7e3aba2a754b6b4103a72893c229c2 --- caffe2/opt/backend_transformer_base.cc | 23 +++++++++++++++++++++++ caffe2/opt/backend_transformer_base.h | 8 +++++++- caffe2/opt/onnxifi_transformer.cc | 20 ++------------------ 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/caffe2/opt/backend_transformer_base.cc b/caffe2/opt/backend_transformer_base.cc index 7c33fa6..932f35b 100644 --- a/caffe2/opt/backend_transformer_base.cc +++ b/caffe2/opt/backend_transformer_base.cc @@ -135,4 +135,27 @@ ShapeInfoMap BackendTransformerBase::inferShapes( } return shape_map; } + +void BackendTransformerBase::dumpNet( + const NetDef& pred_net, + const ShapeInfoMap& shape_hints, + const std::string& fname) const { + NetDef shape_net(pred_net); + auto* shape_arg = shape_net.add_arg(); + auto* qshape_arg = shape_net.add_arg(); + shape_arg->set_name("shape_info"); + qshape_arg->set_name("qshape_info"); + for (const auto& kv : shape_hints) { + if (!kv.second.is_quantized) { + auto t = wrapShapeInfoIntoTensorProto(kv.first, kv.second); + t.add_int32_data(static_cast(kv.second.dim_type)); + shape_arg->mutable_tensors()->Add()->CopyFrom(t); + } else { + auto t = wrapShapeInfoIntoQTensorProto(kv.first, kv.second); + t.add_data(static_cast(kv.second.dim_type)); + qshape_arg->mutable_qtensors()->Add()->CopyFrom(t); + } + } + WriteProtoToTextFile(shape_net, "debug_ssa_net.pb_txt"); +} } // namespace caffe2 diff --git a/caffe2/opt/backend_transformer_base.h b/caffe2/opt/backend_transformer_base.h index e845445..3ae3be1 100644 --- a/caffe2/opt/backend_transformer_base.h +++ b/caffe2/opt/backend_transformer_base.h @@ -39,9 +39,15 @@ class BackendTransformerBase { const std::unordered_set& blacklisted_ops) = 0; protected: - // get model ID from the NetDef + // Get model ID from the NetDef std::string getModelId(const NetDef& net); + // Dump the net with shape info + void dumpNet( + const NetDef& pred_net, + const ShapeInfoMap& map, + const std::string& fname) const; + // SSA rewrite the net and return name mapping std::unordered_map ssaRewriteAndMapNames( Workspace* ws, diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 8ec572b..7a1ff8b 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -1027,23 +1027,7 @@ void OnnxifiTransformer::transform( } if (opts_.debug) { - NetDef shape_net(*pred_net); - auto* shape_arg = shape_net.add_arg(); - auto* qshape_arg = shape_net.add_arg(); - shape_arg->set_name("shape_info"); - qshape_arg->set_name("qshape_info"); - for (const auto& kv : shape_hints) { - if (!kv.second.is_quantized) { - auto t = wrapShapeInfoIntoTensorProto(kv.first, kv.second); - t.add_int32_data(static_cast(kv.second.dim_type)); - shape_arg->mutable_tensors()->Add()->CopyFrom(t); - } else { - auto t = wrapShapeInfoIntoQTensorProto(kv.first, kv.second); - t.add_data(static_cast(kv.second.dim_type)); - qshape_arg->mutable_qtensors()->Add()->CopyFrom(t); - } - } - WriteProtoToTextFile(shape_net, "debug_ssa_net.pb_txt"); + dumpNet(*pred_net, shape_hints, "debug_ssa_net.pb_txt"); } // Get backend id @@ -1064,7 +1048,7 @@ void OnnxifiTransformer::transform( net_opt.mutable_device_option()->CopyFrom(pred_net->device_option()); if (opts_.debug) { - WriteProtoToTextFile(net_opt, "debug_full_opt_net.pb_txt"); + dumpNet(*pred_net, shape_hints, "debug_full_opt_net.pb_txt"); } pred_net->Swap(&net_opt); } -- 2.7.4