#include <test/cpp/jit/test_alias_analysis.h>
#include <test/cpp/jit/test_misc.h>
-#include <test/cpp/jit/test_netdef_converter.h>
using namespace torch;
using namespace torch::jit;
JIT_TEST(AliasAnalysis)
JIT_TEST(AliasTracker)
-JIT_TEST(NetDefConverter)
-
JIT_TEST(THNNConv)
JIT_TEST(ATenNativeBatchNorm)
#include <test/cpp/jit/test_alias_analysis.h>
#include <test/cpp/jit/test_misc.h>
-#include <test/cpp/jit/test_netdef_converter.h>
#include <sstream>
#include <string>
testRegisterFusionCachesKernel();
testAliasAnalysis();
testAliasTracker();
- testNetDefConverter(out);
return out.str();
}
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/netdef_converter.h>
-#include "test/cpp/jit/test_base.h"
-
-#include <sstream>
-#include <string>
-
-namespace torch {
-namespace jit {
-
-void testNetDefConverter(std::ostream& out = std::cout) {
- {
- // Check a simple net conversion back and forth.
-
- // Create a simple graph:
- // graph(%0 : Tensor
- // %1 : Tensor) {
- // %2 : Tensor = aten::mul(%0, %1)
- // %3 : int = prim::Constant[value=1]()
- // %4 : Tensor = aten::add(%0, %2, %3)
- // return (%2, %4);
- // }
- auto graph = std::make_shared<Graph>();
- auto a = graph->addInput();
- auto b = graph->addInput();
- auto c = graph->insert(aten::mul, {a, b});
- auto d = graph->insert(aten::add, {a, c});
- graph->registerOutput(c);
- graph->registerOutput(d);
-
- // Convert it to netdef and check the result
- caffe2::NetDef net;
- convertIRToNetDef(&net, *graph);
- AT_ASSERT(net.op().size() == 3);
- AT_ASSERT(net.external_input().size() == 2);
- AT_ASSERT(net.external_output().size() == 2);
-
- const caffe2::OperatorDef& MulOp = net.op().Get(0);
- AT_ASSERT(MulOp.input().size() == 2);
- AT_ASSERT(MulOp.input().Get(0) == net.external_input().Get(0));
- AT_ASSERT(MulOp.input().Get(1) == net.external_input().Get(1));
- AT_ASSERT(MulOp.output().size() == 1);
-
- const caffe2::OperatorDef& ConstNode = net.op().Get(1);
- AT_ASSERT(ConstNode.input().size() == 0);
- AT_ASSERT(ConstNode.output().size() == 1);
- AT_ASSERT(ConstNode.arg().size() == 1);
- AT_ASSERT(ConstNode.arg().Get(0).name() == "value");
- AT_ASSERT(ConstNode.arg().Get(0).i() == 1);
-
- const caffe2::OperatorDef& AddOp = net.op().Get(2);
- AT_ASSERT(AddOp.input().size() == 3);
- AT_ASSERT(AddOp.input().Get(0) == net.external_input().Get(0));
- AT_ASSERT(AddOp.input().Get(1) == MulOp.output().Get(0));
- AT_ASSERT(AddOp.input().Get(2) == ConstNode.output().Get(0));
-
- AT_ASSERT(net.external_output().Get(0) == MulOp.output().Get(0));
- AT_ASSERT(net.external_output().Get(1) == AddOp.output().Get(0));
-
- // Convert NetDef back to IR and check if we get the original.
- Graph graph2;
- std::unordered_map<std::string, Value*> vmap;
- convertNetDefToIR(net, &graph2, &vmap);
-
- Node* mul = graph2.outputs()[0]->node();
- Node* add = graph2.outputs()[1]->node();
- AT_ASSERT(mul->kind() == c->node()->kind());
- AT_ASSERT(add->kind() == d->node()->kind());
- AT_ASSERT(mul->inputs()[0] == graph2.inputs()[0]);
- AT_ASSERT(mul->inputs()[1] == graph2.inputs()[1]);
- AT_ASSERT(add->inputs()[0] == graph2.inputs()[0]);
- AT_ASSERT(add->inputs()[1] == graph2.outputs()[0]);
- }
- {
- // Check attributes conversion
- auto graph = std::make_shared<Graph>();
- auto a = graph->addInput();
- auto b = graph->addInput();
- Node* node =
- graph->create(Symbol::fromQualString("test::some_op"), {a, b}, 2);
- graph->insertNode(node);
-
- node->i_(Symbol::fromQualString("attr::i_attr"), 42);
- node->f_(Symbol::fromQualString("attr::f_attr"), 3.0);
- node->s_(Symbol::fromQualString("attr::s_attr"), "Hello!");
-
- node->is_(Symbol::fromQualString("attr::is_attr"), {14, 18, 7});
- node->fs_(Symbol::fromQualString("attr::fs_attr"), {2.72, 3.14});
- node->ss_(Symbol::fromQualString("attr::ss_attr"), {"Winter", "Summer"});
-
- graph->registerOutput(node->outputs()[0]);
- graph->registerOutput(node->outputs()[1]);
-
- // Convert it to netdef and check the result
- caffe2::NetDef net;
- convertIRToNetDef(&net, *graph);
- const caffe2::OperatorDef& Op = net.op().Get(0);
- AT_ASSERT(Op.arg().Get(0).name() == "i_attr");
- AT_ASSERT(Op.arg().Get(0).i() == 42);
- AT_ASSERT(Op.arg().Get(1).name() == "f_attr");
- AT_ASSERT(Op.arg().Get(1).f() == 3.0);
- AT_ASSERT(Op.arg().Get(2).name() == "s_attr");
- AT_ASSERT(Op.arg().Get(2).s() == "Hello!");
-
- AT_ASSERT(Op.arg().Get(3).name() == "is_attr");
- AT_ASSERT(Op.arg().Get(3).ints().size() == 3);
- AT_ASSERT(Op.arg().Get(3).ints().Get(0) == 14);
- AT_ASSERT(Op.arg().Get(3).ints().Get(1) == 18);
- AT_ASSERT(Op.arg().Get(3).ints().Get(2) == 7);
-
- AT_ASSERT(Op.arg().Get(4).name() == "fs_attr");
- AT_ASSERT(Op.arg().Get(4).floats().size() == 2);
- AT_ASSERT(fabs(Op.arg().Get(4).floats().Get(0) - 2.72) < 0.001);
-
- AT_ASSERT(Op.arg().Get(5).name() == "ss_attr");
- AT_ASSERT(Op.arg().Get(5).strings().size() == 2);
- AT_ASSERT(Op.arg().Get(5).strings().Get(1) == "Summer");
-
- AT_ASSERT(net.external_output().Get(0) == Op.output().Get(0));
- AT_ASSERT(net.external_output().Get(1) == Op.output().Get(1));
-
- // Convert NetDef back to IR and check if we get the original.
- Graph graph2;
- std::unordered_map<std::string, Value*> vmap;
- convertNetDefToIR(net, &graph2, &vmap);
-
- AT_ASSERT(graph2.outputs()[0]->node() == graph2.outputs()[0]->node());
- Node* n = graph2.outputs()[0]->node();
- AT_ASSERT(n->i(Symbol::fromQualString("attr::i_attr")) == 42);
- AT_ASSERT(n->f(Symbol::fromQualString("attr::f_attr")) == 3.0);
- AT_ASSERT(n->s(Symbol::fromQualString("attr::s_attr")) == "Hello!");
- AT_ASSERT(
- n->is(Symbol::fromQualString("attr::is_attr")) ==
- std::vector<long>({14, 18, 7}));
- AT_ASSERT(
- fabs(n->fs(Symbol::fromQualString("attr::fs_attr"))[0] - 2.72) < 0.001);
- AT_ASSERT(
- fabs(n->fs(Symbol::fromQualString("attr::fs_attr"))[1] - 3.14) < 0.001);
- AT_ASSERT(
- n->ss(Symbol::fromQualString("attr::ss_attr")) ==
- std::vector<std::string>({"Winter", "Summer"}));
- }
-}
-} // namespace jit
-} // namespace torch
${TORCH_SRC_DIR}/csrc/jit/constants.cpp
${TORCH_SRC_DIR}/csrc/jit/node_hashing.cpp
${TORCH_SRC_DIR}/csrc/jit/ir.cpp
- ${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp
${TORCH_SRC_DIR}/csrc/jit/operator.cpp
${TORCH_SRC_DIR}/csrc/jit/caffe2_operator.cpp
${TORCH_SRC_DIR}/csrc/jit/register_c10_ops.cpp
+++ /dev/null
-#include <torch/csrc/jit/netdef_converter.h>
-
-namespace torch {
-namespace jit {
-
-static AttributeKind getArgKind(const caffe2::Argument& arg) {
- if (arg.has_i()) {
- return AttributeKind::i;
- } else if (arg.has_f()) {
- return AttributeKind::f;
- } else if (arg.has_s()) {
- return AttributeKind::s;
- } else if (arg.has_t()) {
- return AttributeKind::t;
- } else if (arg.has_n()) {
- return AttributeKind::g;
- } else if (arg.ints().size()) {
- return AttributeKind::is;
- } else if (arg.floats().size()) {
- return AttributeKind::fs;
- } else if (arg.strings().size()) {
- return AttributeKind::ss;
- } else if (arg.tensors().size()) {
- return AttributeKind::ts;
- } else if (arg.nets().size()) {
- return AttributeKind::gs;
- }
- // Unknown type.
- abort();
-}
-
-static void convertArg(const caffe2::Argument& arg, Node* node) {
- std::string attrName = "attr::" + arg.name();
- auto attrSymbol = Symbol::fromQualString(attrName);
- AttributeKind kind = getArgKind(arg);
- switch (kind) {
- case AttributeKind::i: {
- node->i_(attrSymbol, (long)arg.i());
- break;
- }
- case AttributeKind::f: {
- node->f_(attrSymbol, arg.f());
- break;
- }
- case AttributeKind::s: {
- node->s_(attrSymbol, arg.s());
- break;
- }
- case AttributeKind::is: {
- std::vector<long> is(arg.ints().begin(), arg.ints().end());
- node->is_(attrSymbol, is);
- break;
- }
- case AttributeKind::fs: {
- std::vector<double> fs(arg.floats().begin(), arg.floats().end());
- node->fs_(attrSymbol, fs);
- break;
- }
- case AttributeKind::ss: {
- std::vector<std::string> ss(arg.strings().begin(), arg.strings().end());
- node->ss_(attrSymbol, ss);
- break;
- }
- default: {
- std::cout << "Unsupported type '" << toString(kind) << "' of attribute '"
- << attrName << "'"
- << " in node:" << std::endl;
- node->dump();
- abort();
- }
- }
-}
-
-void convertNetDefToIR(
- const caffe2::NetDef& net,
- Graph* g,
- std::unordered_map<std::string, Value*>* valueMapPtr,
- const std::string& prefix) {
- std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr;
- valueMap.clear();
-
- for (const auto& inputName : net.external_input()) {
- AT_ASSERT(!valueMap.count(inputName));
- valueMap[inputName] = g->addInput();
- }
-
- for (const auto& op : net.op()) {
- std::string name = prefix + op.type();
- Node* node =
- g->create(Symbol::fromQualString(name), {}, op.output().size());
- g->insertNode(node);
-
- for (const auto& input : op.input()) {
- AT_ASSERT(valueMap.count(input));
- node->addInput(valueMap[input]);
- }
- int idx = 0;
- for (const auto& output : op.output()) {
- // If output already exists in valueMap, overwrite it. This way we will
- // have the last definition of a value named 'output' in valueMap.
- valueMap[output] = node->outputs()[idx++];
- }
- for (const auto& arg : op.arg()) {
- convertArg(arg, node);
- }
- }
-
- for (const auto& outputName : net.external_output()) {
- AT_ASSERT(valueMap.count(outputName));
- g->registerOutput(valueMap.at(outputName));
- }
-}
-
-static void convertAttrToCaffe2Arg(
- const Node* node,
- const Symbol& name,
- caffe2::Argument* arg) {
- arg->set_name(name.toUnqualString());
- switch (node->kindOf(name)) {
- case AttributeKind::i: {
- arg->set_i(node->i(name));
- break;
- }
- case AttributeKind::f: {
- arg->set_f(node->f(name));
- break;
- }
- case AttributeKind::s: {
- arg->set_s(node->s(name));
- break;
- }
- case AttributeKind::is: {
- for (long i : node->is(name)) {
- arg->add_ints(i);
- }
- break;
- }
- case AttributeKind::fs: {
- for (double f : node->fs(name)) {
- arg->add_floats(f);
- }
- break;
- }
- case AttributeKind::ss: {
- for (const std::string& s : node->ss(name)) {
- arg->add_strings(s);
- }
- break;
- }
- default: {
- std::cout << "Unsupported type '" << toString(node->kindOf(name))
- << "' of attribute '" << name.toUnqualString() << "'"
- << " in node:" << std::endl;
- node->dump();
- abort();
- }
- }
-}
-
-static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net) {
- caffe2::OperatorDef op;
- op.set_type(node->kind().toQualString());
- for (const Value* input : node->inputs()) {
- op.add_input(input->uniqueName());
- }
- for (const Value* output : node->outputs()) {
- op.add_output(output->uniqueName());
- }
- std::vector<Symbol> names = node->attributeNames();
- for (const Symbol& name : names) {
- caffe2::Argument* arg = op.add_arg();
- convertAttrToCaffe2Arg(node, name, arg);
- }
- *net->add_op() = op;
-}
-
-void convertIRToNetDef(caffe2::NetDef* net, const Graph& g) {
- net->mutable_op()->Clear();
-
- for (const Value* value : g.inputs()) {
- net->add_external_input(value->uniqueName());
- }
-
- for (const Node* node : g.nodes()) {
- convertNodeToCaffe2Op(node, net);
- }
-
- for (const Value* value : g.outputs()) {
- net->add_external_output(value->uniqueName());
- }
-}
-
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-#include <caffe2/proto/caffe2_pb.h>
-#include <torch/csrc/jit/ir.h>
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-
-/** \brief Convert a caffe2 NetDef to PyTorch IR.
- *
- * The NetDef \p net is converted and the result is stored in the
- * torch::jit::Graph \p graph. The function also records name->value map in \p
- * valueMapPtr. If the original net had several values with the same name, the
- * map will contain the value for the last definition.
- * \p Prefix can be used for appending some string to every operator name (e.g.
- * we can add "caffe2::").
- */
-void convertNetDefToIR(
- const caffe2::NetDef& net,
- Graph* graph,
- std::unordered_map<std::string, Value*>* valueMapPtr,
- const std::string& prefix = "");
-
-/** \brief Convert PyTorch IR \p graph to Caffe2 NetDef \p net.
- *
- * Note: for constant nodes (prim::Const) we generate a separate op in the net,
- * which might or might not be what we want. The idea here is that eventually
- * both formats will converge to PyTorch IR, so for now we try to keep as close
- * to it as possible. For short-term applications we might add a separate pass
- * that would fold such const-nodes into their users.
- *
- * TODO: We might need to do a better job at preserving names of the variables,
- * especially external_inputs/external_outputs.
- */
-void convertIRToNetDef(caffe2::NetDef* net, const Graph& graph);
-
-} // namespace jit
-} // namespace torch