From 717496e6c19b3c5f263f5b7836de3892d9802274 Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Fri, 14 Dec 2018 16:34:11 -0800 Subject: [PATCH] Supply static shape info to Reshape when doing onnxGetCompatibility (#15242) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15242 Newer version ONNX Reshape gets shape info from a tensor. Hence for static backend, we need to provide this info to it when doing `onnxGetCompatibility` too. Reviewed By: jackm321 Differential Revision: D13471959 fbshipit-source-id: 8a58e28edd900b6ad54a1dbd63ff2579fbe0e820 --- caffe2/opt/onnxifi_transformer.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 94d2014..25bb5c6 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -432,6 +432,7 @@ void OnnxifiTransformer::Transform( std::unordered_set used_outputs; std::vector boundary_inputs; std::vector boundary_outputs; + std::unordered_set reshape_info; // nodes are in topological order, so we just need to iterate for (const auto& n : results.first) { onnx_model.mutable_graph()->add_node()->CopyFrom(n); @@ -446,6 +447,12 @@ void OnnxifiTransformer::Transform( for (const auto& o : n.output()) { used_outputs.emplace(o); } + + // For reshape node, if it has more than 1 inputs, we need to feed the + // second input which contains the shape info + if (n.op_type() == "Reshape" && n.input_size() > 1) { + reshape_info.emplace(n.input(1)); + } } // Second iteration to account all the boundary outputs, which is a newly // seen output and is not referred as input before @@ -462,6 +469,9 @@ void OnnxifiTransformer::Transform( extra_shape_hints; for (const auto& t : results.second) { extra_shape_hints.emplace(t.name(), onnx::ExtraTypeProto(t)); + if (reshape_info.count(t.name())) { + onnx_model.mutable_graph()->add_initializer()->CopyFrom(t); + } } // Add input/output shape info -- 2.7.4