#include <test/cpp/jit/test_alias_analysis.h>
#include <test/cpp/jit/test_misc.h>
+#include <test/cpp/jit/test_netdef_converter.h>
using namespace torch;
using namespace torch::jit;
JIT_TEST(AliasAnalysis)
JIT_TEST(AliasTracker)
+JIT_TEST(NetDefConverter)
+
JIT_TEST(THNNConv)
JIT_TEST(ATenNativeBatchNorm)
#include <test/cpp/jit/test_alias_analysis.h>
#include <test/cpp/jit/test_misc.h>
+#include <test/cpp/jit/test_netdef_converter.h>
#include <sstream>
#include <string>
testRegisterFusionCachesKernel();
testAliasAnalysis();
testAliasTracker();
+ testNetDefConverter(out);
return out.str();
}
--- /dev/null
+#pragma once
+
+#include <torch/csrc/jit/netdef_converter.h>
+#include "test/cpp/jit/test_base.h"
+
+#include <sstream>
+#include <string>
+
+namespace torch {
+namespace jit {
+
+void testNetDefConverter(std::ostream& out = std::cout) {
+ {
+ // Check a simple net conversion back and forth.
+
+ // Create a simple graph:
+ // graph(%0 : Tensor
+ // %1 : Tensor) {
+ // %2 : Tensor = aten::mul(%0, %1)
+ // %3 : int = prim::Constant[value=1]()
+ // %4 : Tensor = aten::add(%0, %2, %3)
+ // return (%2, %4);
+ // }
+ auto graph = std::make_shared<Graph>();
+ auto a = graph->addInput();
+ auto b = graph->addInput();
+ auto c = graph->insert(aten::mul, {a, b});
+ auto d = graph->insert(aten::add, {a, c});
+ graph->registerOutput(c);
+ graph->registerOutput(d);
+
+ // Convert it to netdef and check the result
+ caffe2::NetDef net;
+ convertIRToNetDef(&net, *graph);
+ AT_ASSERT(net.op().size() == 3);
+ AT_ASSERT(net.external_input().size() == 2);
+ AT_ASSERT(net.external_output().size() == 2);
+
+ const caffe2::OperatorDef& MulOp = net.op().Get(0);
+ AT_ASSERT(MulOp.input().size() == 2);
+ AT_ASSERT(MulOp.input().Get(0) == net.external_input().Get(0));
+ AT_ASSERT(MulOp.input().Get(1) == net.external_input().Get(1));
+ AT_ASSERT(MulOp.output().size() == 1);
+
+ const caffe2::OperatorDef& ConstNode = net.op().Get(1);
+ AT_ASSERT(ConstNode.input().size() == 0);
+ AT_ASSERT(ConstNode.output().size() == 1);
+ AT_ASSERT(ConstNode.arg().size() == 1);
+ AT_ASSERT(ConstNode.arg().Get(0).name() == "value");
+ AT_ASSERT(ConstNode.arg().Get(0).i() == 1);
+
+ const caffe2::OperatorDef& AddOp = net.op().Get(2);
+ AT_ASSERT(AddOp.input().size() == 3);
+ AT_ASSERT(AddOp.input().Get(0) == net.external_input().Get(0));
+ AT_ASSERT(AddOp.input().Get(1) == MulOp.output().Get(0));
+ AT_ASSERT(AddOp.input().Get(2) == ConstNode.output().Get(0));
+
+ AT_ASSERT(net.external_output().Get(0) == MulOp.output().Get(0));
+ AT_ASSERT(net.external_output().Get(1) == AddOp.output().Get(0));
+
+ // Convert NetDef back to IR and check if we get the original.
+ Graph graph2;
+ std::unordered_map<std::string, Value*> vmap;
+ convertNetDefToIR(net, &graph2, &vmap);
+
+ Node* mul = graph2.outputs()[0]->node();
+ Node* add = graph2.outputs()[1]->node();
+ AT_ASSERT(mul->kind() == c->node()->kind());
+ AT_ASSERT(add->kind() == d->node()->kind());
+ AT_ASSERT(mul->inputs()[0] == graph2.inputs()[0]);
+ AT_ASSERT(mul->inputs()[1] == graph2.inputs()[1]);
+ AT_ASSERT(add->inputs()[0] == graph2.inputs()[0]);
+ AT_ASSERT(add->inputs()[1] == graph2.outputs()[0]);
+ }
+ {
+ // Check attributes conversion
+ auto graph = std::make_shared<Graph>();
+ auto a = graph->addInput();
+ auto b = graph->addInput();
+ Node* node =
+ graph->create(Symbol::fromQualString("test::some_op"), {a, b}, 2);
+ graph->insertNode(node);
+
+ node->i_(Symbol::fromQualString("attr::i_attr"), 42);
+ node->f_(Symbol::fromQualString("attr::f_attr"), 3.0);
+ node->s_(Symbol::fromQualString("attr::s_attr"), "Hello!");
+
+ node->is_(Symbol::fromQualString("attr::is_attr"), {14, 18, 7});
+ node->fs_(Symbol::fromQualString("attr::fs_attr"), {2.72, 3.14});
+ node->ss_(Symbol::fromQualString("attr::ss_attr"), {"Winter", "Summer"});
+
+ graph->registerOutput(node->outputs()[0]);
+ graph->registerOutput(node->outputs()[1]);
+
+ // Convert it to netdef and check the result
+ caffe2::NetDef net;
+ convertIRToNetDef(&net, *graph);
+ const caffe2::OperatorDef& Op = net.op().Get(0);
+ AT_ASSERT(Op.arg().Get(0).name() == "i_attr");
+ AT_ASSERT(Op.arg().Get(0).i() == 42);
+ AT_ASSERT(Op.arg().Get(1).name() == "f_attr");
+ AT_ASSERT(Op.arg().Get(1).f() == 3.0);
+ AT_ASSERT(Op.arg().Get(2).name() == "s_attr");
+ AT_ASSERT(Op.arg().Get(2).s() == "Hello!");
+
+ AT_ASSERT(Op.arg().Get(3).name() == "is_attr");
+ AT_ASSERT(Op.arg().Get(3).ints().size() == 3);
+ AT_ASSERT(Op.arg().Get(3).ints().Get(0) == 14);
+ AT_ASSERT(Op.arg().Get(3).ints().Get(1) == 18);
+ AT_ASSERT(Op.arg().Get(3).ints().Get(2) == 7);
+
+ AT_ASSERT(Op.arg().Get(4).name() == "fs_attr");
+ AT_ASSERT(Op.arg().Get(4).floats().size() == 2);
+ AT_ASSERT(fabs(Op.arg().Get(4).floats().Get(0) - 2.72) < 0.001);
+
+ AT_ASSERT(Op.arg().Get(5).name() == "ss_attr");
+ AT_ASSERT(Op.arg().Get(5).strings().size() == 2);
+ AT_ASSERT(Op.arg().Get(5).strings().Get(1) == "Summer");
+
+ AT_ASSERT(net.external_output().Get(0) == Op.output().Get(0));
+ AT_ASSERT(net.external_output().Get(1) == Op.output().Get(1));
+
+ // Convert NetDef back to IR and check if we get the original.
+ Graph graph2;
+ std::unordered_map<std::string, Value*> vmap;
+ convertNetDefToIR(net, &graph2, &vmap);
+
+ AT_ASSERT(graph2.outputs()[0]->node() == graph2.outputs()[0]->node());
+ Node* n = graph2.outputs()[0]->node();
+ AT_ASSERT(n->i(Symbol::fromQualString("attr::i_attr")) == 42);
+ AT_ASSERT(n->f(Symbol::fromQualString("attr::f_attr")) == 3.0);
+ AT_ASSERT(n->s(Symbol::fromQualString("attr::s_attr")) == "Hello!");
+ AT_ASSERT(
+ n->is(Symbol::fromQualString("attr::is_attr")) ==
+ std::vector<long>({14, 18, 7}));
+ AT_ASSERT(
+ fabs(n->fs(Symbol::fromQualString("attr::fs_attr"))[0] - 2.72) < 0.001);
+ AT_ASSERT(
+ fabs(n->fs(Symbol::fromQualString("attr::fs_attr"))[1] - 3.14) < 0.001);
+ AT_ASSERT(
+ n->ss(Symbol::fromQualString("attr::ss_attr")) ==
+ std::vector<std::string>({"Winter", "Summer"}));
+ }
+}
+} // namespace jit
+} // namespace torch
${TORCH_SRC_DIR}/csrc/jit/constants.cpp
${TORCH_SRC_DIR}/csrc/jit/node_hashing.cpp
${TORCH_SRC_DIR}/csrc/jit/ir.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp
${TORCH_SRC_DIR}/csrc/jit/operator.cpp
${TORCH_SRC_DIR}/csrc/jit/caffe2_operator.cpp
${TORCH_SRC_DIR}/csrc/jit/register_c10_ops.cpp
--- /dev/null
+#include <torch/csrc/jit/netdef_converter.h>
+
+namespace torch {
+namespace jit {
+
+static AttributeKind getArgKind(const caffe2::Argument& arg) {
+ if (arg.has_i()) {
+ return AttributeKind::i;
+ } else if (arg.has_f()) {
+ return AttributeKind::f;
+ } else if (arg.has_s()) {
+ return AttributeKind::s;
+ } else if (arg.has_t()) {
+ return AttributeKind::t;
+ } else if (arg.has_n()) {
+ return AttributeKind::g;
+ } else if (arg.ints().size()) {
+ return AttributeKind::is;
+ } else if (arg.floats().size()) {
+ return AttributeKind::fs;
+ } else if (arg.strings().size()) {
+ return AttributeKind::ss;
+ } else if (arg.tensors().size()) {
+ return AttributeKind::ts;
+ } else if (arg.nets().size()) {
+ return AttributeKind::gs;
+ }
+ // Unknown type.
+ abort();
+}
+
+static void convertArg(const caffe2::Argument& arg, Node* node) {
+ std::string attrName = "attr::" + arg.name();
+ auto attrSymbol = Symbol::fromQualString(attrName);
+ AttributeKind kind = getArgKind(arg);
+ switch (kind) {
+ case AttributeKind::i: {
+ node->i_(attrSymbol, (long)arg.i());
+ break;
+ }
+ case AttributeKind::f: {
+ node->f_(attrSymbol, arg.f());
+ break;
+ }
+ case AttributeKind::s: {
+ node->s_(attrSymbol, arg.s());
+ break;
+ }
+ case AttributeKind::is: {
+ std::vector<long> is(arg.ints().begin(), arg.ints().end());
+ node->is_(attrSymbol, is);
+ break;
+ }
+ case AttributeKind::fs: {
+ std::vector<double> fs(arg.floats().begin(), arg.floats().end());
+ node->fs_(attrSymbol, fs);
+ break;
+ }
+ case AttributeKind::ss: {
+ std::vector<std::string> ss(arg.strings().begin(), arg.strings().end());
+ node->ss_(attrSymbol, ss);
+ break;
+ }
+ default: {
+ std::cout << "Unsupported type '" << toString(kind) << "' of attribute '"
+ << attrName << "'"
+ << " in node:" << std::endl;
+ node->dump();
+ abort();
+ }
+ }
+}
+
+void convertNetDefToIR(
+ const caffe2::NetDef& net,
+ Graph* g,
+ std::unordered_map<std::string, Value*>* valueMapPtr,
+ const std::string& prefix) {
+ std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr;
+ valueMap.clear();
+
+ for (const auto& inputName : net.external_input()) {
+ AT_ASSERT(!valueMap.count(inputName));
+ valueMap[inputName] = g->addInput();
+ }
+
+ for (const auto& op : net.op()) {
+ std::string name = prefix + op.type();
+ Node* node =
+ g->create(Symbol::fromQualString(name), {}, op.output().size());
+ g->insertNode(node);
+
+ for (const auto& input : op.input()) {
+ AT_ASSERT(valueMap.count(input));
+ node->addInput(valueMap[input]);
+ }
+ int idx = 0;
+ for (const auto& output : op.output()) {
+ // If output already exists in valueMap, overwrite it. This way we will
+ // have the last definition of a value named 'output' in valueMap.
+ valueMap[output] = node->outputs()[idx++];
+ }
+ for (const auto& arg : op.arg()) {
+ convertArg(arg, node);
+ }
+ }
+
+ for (const auto& outputName : net.external_output()) {
+ AT_ASSERT(valueMap.count(outputName));
+ g->registerOutput(valueMap.at(outputName));
+ }
+}
+
+static void convertAttrToCaffe2Arg(
+ const Node* node,
+ const Symbol& name,
+ caffe2::Argument* arg) {
+ arg->set_name(name.toUnqualString());
+ switch (node->kindOf(name)) {
+ case AttributeKind::i: {
+ arg->set_i(node->i(name));
+ break;
+ }
+ case AttributeKind::f: {
+ arg->set_f(node->f(name));
+ break;
+ }
+ case AttributeKind::s: {
+ arg->set_s(node->s(name));
+ break;
+ }
+ case AttributeKind::is: {
+ for (long i : node->is(name)) {
+ arg->add_ints(i);
+ }
+ break;
+ }
+ case AttributeKind::fs: {
+ for (double f : node->fs(name)) {
+ arg->add_floats(f);
+ }
+ break;
+ }
+ case AttributeKind::ss: {
+ for (const std::string& s : node->ss(name)) {
+ arg->add_strings(s);
+ }
+ break;
+ }
+ default: {
+ std::cout << "Unsupported type '" << toString(node->kindOf(name))
+ << "' of attribute '" << name.toUnqualString() << "'"
+ << " in node:" << std::endl;
+ node->dump();
+ abort();
+ }
+ }
+}
+
+static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net) {
+ caffe2::OperatorDef op;
+ op.set_type(node->kind().toQualString());
+ for (const Value* input : node->inputs()) {
+ op.add_input(input->uniqueName());
+ }
+ for (const Value* output : node->outputs()) {
+ op.add_output(output->uniqueName());
+ }
+ std::vector<Symbol> names = node->attributeNames();
+ for (const Symbol& name : names) {
+ caffe2::Argument* arg = op.add_arg();
+ convertAttrToCaffe2Arg(node, name, arg);
+ }
+ *net->add_op() = op;
+}
+
+void convertIRToNetDef(caffe2::NetDef* net, const Graph& g) {
+ net->mutable_op()->Clear();
+
+ for (const Value* value : g.inputs()) {
+ net->add_external_input(value->uniqueName());
+ }
+
+ for (const Node* node : g.nodes()) {
+ convertNodeToCaffe2Op(node, net);
+ }
+
+ for (const Value* value : g.outputs()) {
+ net->add_external_output(value->uniqueName());
+ }
+}
+
+} // namespace jit
+} // namespace torch
--- /dev/null
+#pragma once
+#include <caffe2/proto/caffe2_pb.h>
+#include <torch/csrc/jit/ir.h>
+#include <unordered_map>
+
+namespace torch {
+namespace jit {
+
+/** \brief Convert a caffe2 NetDef to PyTorch IR.
+ *
+ * The NetDef \p net is converted and the result is stored in the
+ * torch::jit::Graph \p graph. The function also records name->value map in \p
+ * valueMapPtr. If the original net had several values with the same name, the
+ * map will contain the value for the last definition.
+ * \p Prefix can be used for appending some string to every operator name (e.g.
+ * we can add "caffe2::").
+ */
+void convertNetDefToIR(
+ const caffe2::NetDef& net,
+ Graph* graph,
+ std::unordered_map<std::string, Value*>* valueMapPtr,
+ const std::string& prefix = "");
+
+/** \brief Convert PyTorch IR \p graph to Caffe2 NetDef \p net.
+ *
+ * Note: for constant nodes (prim::Const) we generate a separate op in the net,
+ * which might or might not be what we want. The idea here is that eventually
+ * both formats will converge to PyTorch IR, so for now we try to keep as close
+ * to it as possible. For short-term applications we might add a separate pass
+ * that would fold such const-nodes into their users.
+ *
+ * TODO: We might need to do a better job at preserving names of the variables,
+ * especially external_inputs/external_outputs.
+ */
+void convertIRToNetDef(caffe2::NetDef* net, const Graph& graph);
+
+} // namespace jit
+} // namespace torch