Supply static shape info to Reshape when doing onnxGetCompatibility (#15242)
authorYinghai Lu <yinghai@fb.com>
Sat, 15 Dec 2018 00:34:11 +0000 (16:34 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 15 Dec 2018 00:37:39 +0000 (16:37 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15242

Newer version ONNX Reshape gets shape info from a tensor. Hence for static backend, we need to provide this info to it when doing `onnxGetCompatibility` too.

Reviewed By: jackm321

Differential Revision: D13471959

fbshipit-source-id: 8a58e28edd900b6ad54a1dbd63ff2579fbe0e820

caffe2/opt/onnxifi_transformer.cc

index 94d2014..25bb5c6 100644 (file)
@@ -432,6 +432,7 @@ void OnnxifiTransformer::Transform(
       std::unordered_set<std::string> used_outputs;
       std::vector<std::string> boundary_inputs;
       std::vector<std::string> boundary_outputs;
+      std::unordered_set<std::string> reshape_info;
       // nodes are in topological order, so we just need to iterate
       for (const auto& n : results.first) {
         onnx_model.mutable_graph()->add_node()->CopyFrom(n);
@@ -446,6 +447,12 @@ void OnnxifiTransformer::Transform(
         for (const auto& o : n.output()) {
           used_outputs.emplace(o);
         }
+
+        // For reshape node, if it has more than 1 inputs, we need to feed the
+        // second input which contains the shape info
+        if (n.op_type() == "Reshape" && n.input_size() > 1) {
+          reshape_info.emplace(n.input(1));
+        }
       }
       // Second iteration to account all the boundary outputs, which is a newly
       // seen output and is not referred as input before
@@ -462,6 +469,9 @@ void OnnxifiTransformer::Transform(
           extra_shape_hints;
       for (const auto& t : results.second) {
         extra_shape_hints.emplace(t.name(), onnx::ExtraTypeProto(t));
+        if (reshape_info.count(t.name())) {
+          onnx_model.mutable_graph()->add_initializer()->CopyFrom(t);
+        }
       }
 
       // Add input/output shape info