Add output info when doing onnxGetBackendCompatibility (#14784)
authorYinghai Lu <yinghai@fb.com>
Wed, 5 Dec 2018 05:50:41 +0000 (21:50 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 05:53:32 +0000 (21:53 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14784

TSIA. To give more complete info to `onnxGetBackendCompatibility`.

Reviewed By: bertmaher, rdzhabarov

Differential Revision: D13331989

fbshipit-source-id: 1064b93f7f474788f736e6f0c893dae915c6fb99

caffe2/opt/onnxifi_transformer.cc

index 6338b68..b52833f 100644 (file)
@@ -419,15 +419,23 @@ void OnnxifiTransformer::Transform(
         onnx_model.mutable_graph()->add_node()->CopyFrom(n);
       }
 
-      // Add input shape info
-      std::vector<std::string> input_tmp;
+      // Add input/output shape info
+      std::vector<std::string> io_tmp;
       for (const auto& op_input : op.input()) {
-        input_tmp.emplace_back(op_input);
+        io_tmp.emplace_back(op_input);
       }
-      auto io_vec = ConvertToValueInfo(input_tmp, shape_hints);
+      auto io_vec = ConvertToValueInfo(io_tmp, shape_hints);
       for (const auto& i : io_vec) {
         onnx_model.mutable_graph()->add_input()->CopyFrom(i);
       }
+      io_tmp.clear();
+      for (const auto& op_output : op.output()) {
+        io_tmp.emplace_back(op_output);
+      }
+      io_vec = ConvertToValueInfo(io_tmp, shape_hints);
+      for (const auto& i : io_vec) {
+        onnx_model.mutable_graph()->add_output()->CopyFrom(i);
+      }
 
       std::string onnx_model_str;
       onnx_model.SerializeToString(&onnx_model_str);