AT_ASSERT(net2.external_input().Get(2) == "c");
AT_ASSERT(net2.external_output().Get(0) == "x");
}
+
+ {
+ // Test that prefix is removed when converting from NetDef to IR and back.
+ caffe2::NetDef net;
+ *net.add_op() = createOperator("MatMul", {"a", "b"}, {"c"});
+ net.add_external_input("a");
+ net.add_external_input("b");
+ net.add_external_output("c");
+ Graph graph;
+ std::unordered_map<std::string, Value*> vmap;
+ convertNetDefToIR(net, &graph, &vmap, "caffe2::");
+
+ caffe2::NetDef net2;
+ convertIRToNetDef(&net2, graph, "caffe2::");
+ // The conversion should remove the prefix if it maches.
+ AT_ASSERT(net2.op(0).type() == "MatMul");
+
+ caffe2::NetDef net3;
+ convertIRToNetDef(&net3, graph, "foo::");
+ // The conversion should still work if the prefix does not match.
+ AT_ASSERT(net3.op(0).type() == "caffe2::MatMul");
+
+ // Prefix shouldn't affect blob names.
+ AT_ASSERT(net2.op(0).input(0) == "a");
+ AT_ASSERT(net2.external_input(0) == "a");
+ AT_ASSERT(net2.external_output(0) == "c");
+ AT_ASSERT(net3.external_input(0) == "a");
+ }
}
+
} // namespace jit
} // namespace torch
}
}
-static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net) {
+static std::string removePrefixIfNeeded(const std::string& name,
+ const std::string& prefix) {
+ if (!name.compare(0, prefix.size(), prefix)) {
+ return name.substr(prefix.size());
+ } else {
+ return name;
+ }
+}
+
+static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net,
+ const std::string& prefix = "") {
caffe2::OperatorDef op;
- op.set_type(node->kind().toQualString());
+ op.set_type(removePrefixIfNeeded(node->kind().toQualString(), prefix));
for (const Value* input : node->inputs()) {
op.add_input(input->uniqueName());
}
*net->add_op() = op;
}
-void convertIRToNetDef(caffe2::NetDef* net, const Graph& g) {
+void convertIRToNetDef(caffe2::NetDef* net, const Graph& g,
+ const std::string& prefix) {
net->mutable_op()->Clear();
for (const Value* value : g.inputs()) {
}
for (const Node* node : g.nodes()) {
- convertNodeToCaffe2Op(node, net);
+ convertNodeToCaffe2Op(node, net, prefix);
}
for (const Value* value : g.outputs()) {
* both formats will converge to PyTorch IR, so for now we try to keep as close
* to it as possible. For short-term applications we might add a separate pass
* that would fold such const-nodes into their users.
+ * \p If Prefix is specified, the prefix will be removed from operator name when
+ * converting from IR to NetDef.
*
* TODO: We might need to do a better job at preserving names of the variables,
* especially external_inputs/external_outputs.
*/
-void convertIRToNetDef(caffe2::NetDef* net, const Graph& graph);
+void convertIRToNetDef(
+ caffe2::NetDef* net,
+ const Graph& graph,
+ const std::string& prefix = "");
} // namespace jit
} // namespace torch