std::unordered_set<std::string> used_outputs;
std::vector<std::string> boundary_inputs;
std::vector<std::string> boundary_outputs;
+ std::unordered_set<std::string> reshape_info;
// nodes are in topological order, so we just need to iterate
for (const auto& n : results.first) {
onnx_model.mutable_graph()->add_node()->CopyFrom(n);
for (const auto& o : n.output()) {
used_outputs.emplace(o);
}
+
+ // For reshape node, if it has more than 1 inputs, we need to feed the
+ // second input which contains the shape info
+ if (n.op_type() == "Reshape" && n.input_size() > 1) {
+ reshape_info.emplace(n.input(1));
+ }
}
// Second iteration to account all the boundary outputs, which is a newly
// seen output and is not referred as input before
extra_shape_hints;
for (const auto& t : results.second) {
extra_shape_hints.emplace(t.name(), onnx::ExtraTypeProto(t));
+ if (reshape_info.count(t.name())) {
+ onnx_model.mutable_graph()->add_initializer()->CopyFrom(t);
+ }
}
// Add input/output shape info