}
return shape_map;
}
+
+void BackendTransformerBase::dumpNet(
+ const NetDef& pred_net,
+ const ShapeInfoMap& shape_hints,
+ const std::string& fname) const {
+ NetDef shape_net(pred_net);
+ auto* shape_arg = shape_net.add_arg();
+ auto* qshape_arg = shape_net.add_arg();
+ shape_arg->set_name("shape_info");
+ qshape_arg->set_name("qshape_info");
+ for (const auto& kv : shape_hints) {
+ if (!kv.second.is_quantized) {
+ auto t = wrapShapeInfoIntoTensorProto(kv.first, kv.second);
+ t.add_int32_data(static_cast<int32_t>(kv.second.dim_type));
+ shape_arg->mutable_tensors()->Add()->CopyFrom(t);
+ } else {
+ auto t = wrapShapeInfoIntoQTensorProto(kv.first, kv.second);
+ t.add_data(static_cast<int32_t>(kv.second.dim_type));
+ qshape_arg->mutable_qtensors()->Add()->CopyFrom(t);
+ }
+ }
+ WriteProtoToTextFile(shape_net, "debug_ssa_net.pb_txt");
+}
} // namespace caffe2
const std::unordered_set<int>& blacklisted_ops) = 0;
protected:
- // get model ID from the NetDef
+ // Get model ID from the NetDef
std::string getModelId(const NetDef& net);
+ // Dump the net with shape info
+ void dumpNet(
+ const NetDef& pred_net,
+ const ShapeInfoMap& map,
+ const std::string& fname) const;
+
// SSA rewrite the net and return name mapping
std::unordered_map<std::string, TensorShape> ssaRewriteAndMapNames(
Workspace* ws,
}
if (opts_.debug) {
- NetDef shape_net(*pred_net);
- auto* shape_arg = shape_net.add_arg();
- auto* qshape_arg = shape_net.add_arg();
- shape_arg->set_name("shape_info");
- qshape_arg->set_name("qshape_info");
- for (const auto& kv : shape_hints) {
- if (!kv.second.is_quantized) {
- auto t = wrapShapeInfoIntoTensorProto(kv.first, kv.second);
- t.add_int32_data(static_cast<int32_t>(kv.second.dim_type));
- shape_arg->mutable_tensors()->Add()->CopyFrom(t);
- } else {
- auto t = wrapShapeInfoIntoQTensorProto(kv.first, kv.second);
- t.add_data(static_cast<int32_t>(kv.second.dim_type));
- qshape_arg->mutable_qtensors()->Add()->CopyFrom(t);
- }
- }
- WriteProtoToTextFile(shape_net, "debug_ssa_net.pb_txt");
+ dumpNet(*pred_net, shape_hints, "debug_ssa_net.pb_txt");
}
// Get backend id
net_opt.mutable_device_option()->CopyFrom(pred_net->device_option());
if (opts_.debug) {
- WriteProtoToTextFile(net_opt, "debug_full_opt_net.pb_txt");
+ dumpNet(*pred_net, shape_hints, "debug_full_opt_net.pb_txt");
}
pred_net->Swap(&net_opt);
}