From ecc17fe3dd03962f36a989659336e42de86a38ca Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Tue, 4 Dec 2018 21:50:41 -0800 Subject: [PATCH] Add output info when doing onnxGetBackendCompatibility (#14784) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14784 TSIA. To give more complete info to `onnxGetBackendCompatibility`. Reviewed By: bertmaher, rdzhabarov Differential Revision: D13331989 fbshipit-source-id: 1064b93f7f474788f736e6f0c893dae915c6fb99 --- caffe2/opt/onnxifi_transformer.cc | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 6338b68..b52833f 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -419,15 +419,23 @@ void OnnxifiTransformer::Transform( onnx_model.mutable_graph()->add_node()->CopyFrom(n); } - // Add input shape info - std::vector input_tmp; + // Add input/output shape info + std::vector io_tmp; for (const auto& op_input : op.input()) { - input_tmp.emplace_back(op_input); + io_tmp.emplace_back(op_input); } - auto io_vec = ConvertToValueInfo(input_tmp, shape_hints); + auto io_vec = ConvertToValueInfo(io_tmp, shape_hints); for (const auto& i : io_vec) { onnx_model.mutable_graph()->add_input()->CopyFrom(i); } + io_tmp.clear(); + for (const auto& op_output : op.output()) { + io_tmp.emplace_back(op_output); + } + io_vec = ConvertToValueInfo(io_tmp, shape_hints); + for (const auto& i : io_vec) { + onnx_model.mutable_graph()->add_output()->CopyFrom(i); + } std::string onnx_model_str; onnx_model.SerializeToString(&onnx_model_str); -- 2.7.4