From 552f903c63355906ce46ca6cdba91d70589db501 Mon Sep 17 00:00:00 2001 From: Duc Ngo Date: Tue, 12 Mar 2019 16:54:47 -0700 Subject: [PATCH] JIT IR - Add option to remove prefix string when converting from JIT IR to NetDef (#17931) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17931 When converting from NetDef to IR and back, the prefix string should be removed so the operator types are preserved in caffe2. Reviewed By: ZolotukhinM Differential Revision: D14425954 fbshipit-source-id: 2807e7337b0f804f126970768b1250a4a8c5f35c --- test/cpp/jit/test_netdef_converter.h | 29 +++++++++++++++++++++++++++++ torch/csrc/jit/netdef_converter.cpp | 19 +++++++++++++++---- torch/csrc/jit/netdef_converter.h | 7 ++++++- 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/test/cpp/jit/test_netdef_converter.h b/test/cpp/jit/test_netdef_converter.h index 0ac7d92..9bec094 100644 --- a/test/cpp/jit/test_netdef_converter.h +++ b/test/cpp/jit/test_netdef_converter.h @@ -219,6 +219,35 @@ void testNetDefConverter() { AT_ASSERT(net2.external_input().Get(2) == "c"); AT_ASSERT(net2.external_output().Get(0) == "x"); } + + { + // Test that prefix is removed when converting from NetDef to IR and back. + caffe2::NetDef net; + *net.add_op() = createOperator("MatMul", {"a", "b"}, {"c"}); + net.add_external_input("a"); + net.add_external_input("b"); + net.add_external_output("c"); + Graph graph; + std::unordered_map vmap; + convertNetDefToIR(net, &graph, &vmap, "caffe2::"); + + caffe2::NetDef net2; + convertIRToNetDef(&net2, graph, "caffe2::"); + // The conversion should remove the prefix if it maches. + AT_ASSERT(net2.op(0).type() == "MatMul"); + + caffe2::NetDef net3; + convertIRToNetDef(&net3, graph, "foo::"); + // The conversion should still work if the prefix does not match. + AT_ASSERT(net3.op(0).type() == "caffe2::MatMul"); + + // Prefix shouldn't affect blob names. + AT_ASSERT(net2.op(0).input(0) == "a"); + AT_ASSERT(net2.external_input(0) == "a"); + AT_ASSERT(net2.external_output(0) == "c"); + AT_ASSERT(net3.external_input(0) == "a"); + } } + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/netdef_converter.cpp b/torch/csrc/jit/netdef_converter.cpp index 8f77165..2e5dce3 100644 --- a/torch/csrc/jit/netdef_converter.cpp +++ b/torch/csrc/jit/netdef_converter.cpp @@ -190,9 +190,19 @@ static void convertAttrToCaffe2Arg( } } -static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net) { +static std::string removePrefixIfNeeded(const std::string& name, + const std::string& prefix) { + if (!name.compare(0, prefix.size(), prefix)) { + return name.substr(prefix.size()); + } else { + return name; + } +} + +static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net, + const std::string& prefix = "") { caffe2::OperatorDef op; - op.set_type(node->kind().toQualString()); + op.set_type(removePrefixIfNeeded(node->kind().toQualString(), prefix)); for (const Value* input : node->inputs()) { op.add_input(input->uniqueName()); } @@ -207,7 +217,8 @@ static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net) { *net->add_op() = op; } -void convertIRToNetDef(caffe2::NetDef* net, const Graph& g) { +void convertIRToNetDef(caffe2::NetDef* net, const Graph& g, + const std::string& prefix) { net->mutable_op()->Clear(); for (const Value* value : g.inputs()) { @@ -215,7 +226,7 @@ void convertIRToNetDef(caffe2::NetDef* net, const Graph& g) { } for (const Node* node : g.nodes()) { - convertNodeToCaffe2Op(node, net); + convertNodeToCaffe2Op(node, net, prefix); } for (const Value* value : g.outputs()) { diff --git a/torch/csrc/jit/netdef_converter.h b/torch/csrc/jit/netdef_converter.h index d2effb2..03fd014 100644 --- a/torch/csrc/jit/netdef_converter.h +++ b/torch/csrc/jit/netdef_converter.h @@ -28,11 +28,16 @@ void convertNetDefToIR( * both formats will converge to PyTorch IR, so for now we try to keep as close * to it as possible. For short-term applications we might add a separate pass * that would fold such const-nodes into their users. + * \p If Prefix is specified, the prefix will be removed from operator name when + * converting from IR to NetDef. * * TODO: We might need to do a better job at preserving names of the variables, * especially external_inputs/external_outputs. */ -void convertIRToNetDef(caffe2::NetDef* net, const Graph& graph); +void convertIRToNetDef( + caffe2::NetDef* net, + const Graph& graph, + const std::string& prefix = ""); } // namespace jit } // namespace torch -- 2.7.4