From 3a01a45f06db767cc36669de70af69b09dbafb0f Mon Sep 17 00:00:00 2001
From: Mikhail Zolotukhin
Date: Sat, 16 Feb 2019 20:18:54 -0800
Subject: [PATCH] Implement IRParser. (#16987)

Summary:
It might need some cleaning up and might be missing some features, but it
should already work for most cases. This PR is based on top of PR16986, so
please review only the last commit here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16987

Differential Revision: D14074577

Pulled By: ZolotukhinM

fbshipit-source-id: 712b598f423265655f574bb9903e2066628eaad3
---
 test/cpp/jit/gtest.cpp       |   2 +
 test/cpp/jit/no-gtest.cpp    |   2 +
 test/cpp/jit/test_irparser.h | 155 +++++++++++++++
 tools/build_variables.py     |   1 +
 torch/CMakeLists.txt         |   1 +
 torch/csrc/jit/irparser.cpp  | 443 +++++++++++++++++++++++++++++++++++++++++++
 torch/csrc/jit/irparser.h    |  15 ++
 7 files changed, 619 insertions(+)
 create mode 100644 test/cpp/jit/test_irparser.h
 create mode 100644 torch/csrc/jit/irparser.cpp
 create mode 100644 torch/csrc/jit/irparser.h

diff --git a/test/cpp/jit/gtest.cpp b/test/cpp/jit/gtest.cpp
index 8e9c094..a3aaa4a 100644
--- a/test/cpp/jit/gtest.cpp
+++ b/test/cpp/jit/gtest.cpp
@@ -1,6 +1,7 @@
 #include
 #include
+#include
 #include
 #include
@@ -34,6 +35,7 @@
 JIT_TEST(TopologicalMove)
 JIT_TEST(SubgraphUtils)
 JIT_TEST(AliasAnalysis)
 JIT_TEST(AliasTracker)
+JIT_TEST(IRParser)
 
 JIT_TEST(NetDefConverter)
diff --git a/test/cpp/jit/no-gtest.cpp b/test/cpp/jit/no-gtest.cpp
index 72307a6..654dd30 100644
--- a/test/cpp/jit/no-gtest.cpp
+++ b/test/cpp/jit/no-gtest.cpp
@@ -1,4 +1,5 @@
 #include
+#include
 #include
 #include
@@ -39,6 +40,7 @@
 std::string runJITCPPTests() {
   testAliasAnalysis();
   testAliasTracker();
   testNetDefConverter(out);
+  testIRParser(out);
   return out.str();
 }
diff --git a/test/cpp/jit/test_irparser.h b/test/cpp/jit/test_irparser.h
new file mode 100644
index 0000000..cdd3263
--- /dev/null
+++ b/test/cpp/jit/test_irparser.h
@@ -0,0 +1,155 @@
+#pragma once
+
+#include
+#include
+#include "test/cpp/jit/test_base.h"
+
+#include
+#include
+
+namespace torch {
+namespace jit {
+
+/** \brief Parse IR from \p S, print the parsed graph and verify that the output
+ * string matches the original string.
+ *
+ * The function is sensitive to value naming and whitespace, so it should be
+ * used with care. Nevertheless, it helps to keep tests more compact.
+ */
+static void checkRoundtrip(const std::string& s) {
+  auto graph = std::make_shared<Graph>();
+  script::parseIR(s, &*graph);
+  std::ostringstream ss;
+  ss << *graph;
+  std::string parsed = ss.str();
+
+  // Skip whitespace in the beginning of the input string.
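+  // The R"IR(...)" literals used in the tests below start with a newline,
+  // while the printed graph does not, so the leading whitespace must be
+  // trimmed before the two strings are compared.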
+  int i = 0;
+  for (char c : s) {
+    if (!isspace(c)) {
+      break;
+    }
+    i++;
+  }
+  std::string original = s.substr(i, s.size());
+  if (original != parsed) {
+    std::cerr << "Input:" << std::endl << original << std::endl;
+    std::cerr << "Parsed:" << std::endl << parsed << std::endl;
+  }
+  AT_ASSERT(original == parsed);
+}
+
+void testIRParser(std::ostream& out = std::cout) {
+  {
+    auto graph = std::make_shared<Graph>();
+    script::parseIR(
+        R"IR(
+graph(%0 : Tensor, %1 : Tensor):
+  %2 : Tensor = foo::add(%0, %1)
+  %res, %3 = foo::mul(%0, %2)
+  %x, %y = foo::combine(%res, %2, %3)
+  return (%x, %y, %res))IR",
+        &*graph);
+
+    AT_ASSERT(graph->inputs().size() == 2);
+    AT_ASSERT(graph->outputs().size() == 3);
+    Value* x = graph->outputs()[0];
+    Value* y = graph->outputs()[1];
+    Value* res = graph->outputs()[2];
+    Value* t0 = graph->inputs()[0];
+    Value* t1 = graph->inputs()[1];
+    AT_ASSERT(x->node() == y->node());
+    Node* comb = x->node();
+    Value* t2 = comb->inputs()[1];
+    Value* t3 = comb->inputs()[2];
+    AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine"));
+    AT_ASSERT(comb->outputs() == std::vector<Value*>({x, y}));
+    AT_ASSERT(comb->inputs() == std::vector<Value*>({res, t2, t3}));
+    Node* mul = res->node();
+    AT_ASSERT(mul->kind().toQualString() == std::string("foo::mul"));
+    AT_ASSERT(mul->inputs() == std::vector<Value*>({t0, t2}));
+    AT_ASSERT(mul->outputs() == std::vector<Value*>({res, t3}));
+    Node* add = t2->node();
+    AT_ASSERT(add->kind().toQualString() == std::string("foo::add"));
+    AT_ASSERT(add->inputs() == std::vector<Value*>({t0, t1}));
+    AT_ASSERT(add->outputs() == std::vector<Value*>({t2}));
+  }
+  {
+    checkRoundtrip(R"IR(
+graph():
+  %0 : Tensor = a::a()
+    block0():
+      %1 : Tensor = b::b()
+        block0():
+          %2 : Tensor = c::c()
+          -> ()
+      -> ()
+  %3 : Tensor = d::d()
+  return (%3)
+)IR");
+  }
+  {
+    checkRoundtrip(R"IR(
+graph(%0 : Tensor,
+      %1 : Tensor,
+      %2 : Tensor):
+  %3 : int = prim::Constant[value=1]()
+  %4 : Tensor = aten::add(%0, %1, %3)
+  %5 : Tensor = prim::If(%2)
+    block0():
+      %6 : int = prim::Constant[value=1]()
+      %7 : Tensor = aten::add(%1, %3, %6)
+      %8 : int = prim::Constant[value=1]()
+      %9 : Tensor = aten::add(%7, %3, %8)
+      -> (%9)
+  %10 : int = prim::Constant[value=1]()
+  %11 : Tensor = aten::add(%5, %3, %10)
+  return (%11)
+)IR");
+  }
+  {
+    auto graph = std::make_shared<Graph>();
+    script::parseIR(
+        R"IR(
+graph(%a):
+  return (%a))IR",
+        &*graph);
+    graph->inputs()[0]->type()->expect<TensorType>();
+  }
+  {
+    // Check that parser correctly handles values reusing the same name.
+    auto graph = std::make_shared<Graph>();
+    script::parseIR(
+        R"IR(
+graph(%x):
+  %x = a::a(%x)
+  %x = b::b(%x)
+  return (%x))IR",
+        &*graph);
+    Value* x0 = graph->inputs()[0];
+    Value* x2 = graph->outputs()[0];
+    Node* b = x2->node();
+    Value* x1 = b->inputs()[0];
+    Node* a = x1->node();
+    AT_ASSERT(a->inputs() == std::vector<Value*>({x0}));
+    AT_ASSERT(a->outputs() == std::vector<Value*>({x1}));
+    AT_ASSERT(b->inputs() == std::vector<Value*>({x1}));
+    AT_ASSERT(b->outputs() == std::vector<Value*>({x2}));
+  }
+  {
+    // Check that parser handles attributes and types.
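+    // The parser infers each attribute's kind from its literal: quoted strings
+    // become string attributes, numbers containing '.' or 'e' become floats,
+    // other numbers become ints, and [...] lists take the kind of their
+    // elements. The i_/f_/s_ prefixes below are only a naming convention.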
+    checkRoundtrip(
+        R"IR(
+graph(%0 : Tensor,
+      %1 : Tensor,
+      %2 : Tensor):
+  %3 : int, %4 : Tensor = qqq::qqq[i_asdf=2, f_asdf=3.14, s_asdf="hello", ss_asdf=["hello world", "bye bye"]](%0)
+  %5 : int, %6 : Tensor = ppp::ppp[i_asdf=2, f_asdf=3.14, s_asdf="\"\"\"\"\nhe\"llo", q=[3, 2, 4]](%0)
+  %7 : float = vvv::vvv[s_asdf="hello"](%0)
+  %8 : string = z::z()
+  return (%7)
+)IR");
+  }
+}
+} // namespace jit
+} // namespace torch
diff --git a/tools/build_variables.py b/tools/build_variables.py
index b053249..9323d55 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -57,6 +57,7 @@ libtorch_sources = [
     "torch/csrc/jit/import.cpp",
     "torch/csrc/jit/interpreter.cpp",
     "torch/csrc/jit/ir.cpp",
+    "torch/csrc/jit/irparser.cpp",
     "torch/csrc/jit/netdef_converter.cpp",
     "torch/csrc/jit/caffe2_operator.cpp",
     "torch/csrc/jit/register_caffe2_ops.cpp",
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 2d2bb4f..f8302ae 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -133,6 +133,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/constants.cpp
   ${TORCH_SRC_DIR}/csrc/jit/node_hashing.cpp
   ${TORCH_SRC_DIR}/csrc/jit/ir.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/irparser.cpp
  ${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp
   ${TORCH_SRC_DIR}/csrc/jit/operator.cpp
   ${TORCH_SRC_DIR}/csrc/jit/caffe2_operator.cpp
diff --git a/torch/csrc/jit/irparser.cpp b/torch/csrc/jit/irparser.cpp
new file mode 100644
index 0000000..f3341d9
--- /dev/null
+++ b/torch/csrc/jit/irparser.cpp
@@ -0,0 +1,443 @@
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace torch {
+namespace jit {
+namespace script {
+
+struct VarWithType;
+struct ParsedLiteral;
+
+class IRParser {
+  friend void parseIR(const std::string& str, torch::jit::Graph* graph);
+  IRParser(const std::string& str, torch::jit::Graph* graph)
+      : L(str), g(graph) {}
+
+  std::string parseVar();
+  VarWithType parseVarWithType();
+  ParsedLiteral parseScalarLiteral(Node* n);
+
+  void parse();
+  void parseGraphInputs();
+  void parseReturnOperator();
+
+  void parseBlocks(Node* parentNode);
+  void parseBlock(Node* parentNode);
+  void parseBlockInputs(Block* b);
+  void parseBlockOutputs(Block* b);
+
+  void parseOperatorsList(Block* b);
+  void parseOperator(Block* b);
+  void parseOperatorOutputs(std::vector<VarWithType>* outs);
+  std::string parseOperatorName();
+  void parseOperatorInputs(Node* n);
+  void parseAttrs(Node* n);
+  void parseAttr(Node* n);
+
+  void parseList(
+      int begin,
+      int sep,
+      int end,
+      const std::function<void()>& callback);
+
+  torch::jit::script::Lexer L;
+  torch::jit::Graph* g = nullptr;
+  std::unordered_map<std::string, Value*> vmap;
+};
+
+struct ParsedLiteral {
+  ParsedLiteral() = default;
+
+  AttributeKind k = AttributeKind::t;
+
+  int64_t i = 0;
+  std::string s = "";
+  double f = 0.0;
+  std::vector<int64_t> is;
+  std::vector<std::string> ss;
+  std::vector<double> fs;
+};
+
+struct VarWithType {
+  VarWithType() = default;
+  std::string name;
+  std::string type;
+};
+
+void parseIR(const std::string& str, torch::jit::Graph* graph) {
+  torch::jit::script::IRParser p(str, graph);
+  p.parse();
+}
+
+TypePtr parseType(const std::string& s) {
+  if (s == "Tensor") {
+    return TensorType::get();
+  }
+  if (s == "int") {
+    return IntType::get();
+  }
+  if (s == "float") {
+    return FloatType::get();
+  }
+  if (s == "string") {
+    return StringType::get();
+  }
+  // TODO: Support other types.
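+  // Any other annotation (e.g. bool or a list type) is rejected here rather
+  // than silently defaulting to Tensor.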
+  AT_ASSERTM(false, "Type not supported by parser:", s);
+}
+
+VarWithType IRParser::parseVarWithType() {
+  L.expect('%');
+  VarWithType r;
+  if (L.cur().kind == TK_IDENT) {
+    r.name = L.expect(TK_IDENT).text();
+  } else {
+    r.name = L.expect(TK_NUMBER).text();
+  }
+  r.type = "Tensor";
+  if (L.nextIf(':')) {
+    r.type = L.expect(TK_IDENT).text();
+  }
+  return r;
+}
+
+std::string IRParser::parseVar() {
+  L.expect('%');
+  if (L.cur().kind == TK_IDENT) {
+    return L.expect(TK_IDENT).text();
+  }
+  return L.expect(TK_NUMBER).text();
+}
+
+void IRParser::parseOperatorOutputs(std::vector<VarWithType>* outs) {
+  if (L.cur().kind != '%') {
+    return;
+  }
+  parseList(TK_NOTHING, ',', TK_NOTHING, [&] {
+    outs->push_back(parseVarWithType());
+  });
+  L.expect('=');
+}
+
+// Parse string or numeric literal and return it along with its type.
+ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
+  auto token = L.cur();
+  std::string str;
+  ParsedLiteral r;
+  switch (token.kind) {
+    case TK_STRINGLITERAL:
+      r.k = AttributeKind::s;
+      r.s = parseStringLiteral(token.range, token.text());
+      L.next();
+      return r;
+    case '-':
+      str = "-";
+      L.next();
+      L.expect(TK_NUMBER);
+      // Fallthrough
+    case TK_NUMBER:
+      str += L.cur().text();
+
+      if (str.find('.') != std::string::npos ||
+          str.find('e') != std::string::npos) {
+        r.k = AttributeKind::f;
+        r.f = std::stod(str);
+      } else {
+        r.k = AttributeKind::i;
+        r.i = std::stoll(str);
+      }
+      L.next();
+      return r;
+    default:
+      throw ErrorReport(token.range)
+          << "Could not parse literal" << token.text();
+  }
+}
+
+/** \brief Parse attribute and add it to the node N.
+ *
+ * The function determines the attribute type (string, int, float, list of
+ * strings, list of ints, list of floats, and a list of tensors (currently only
+ * for empty lists)).
+ * An attribute looks like the following:
+ *   AttrName=AttrValue
+ * Where AttrValue can be a list or a scalar literal, e.g.:
+ *   size = 27
+ *   name = "Bob"
+ *   coefs = [1.2, 3.4, 0.6]
+ */
+void IRParser::parseAttr(Node* n) {
+  std::string attrname = L.expect(TK_IDENT).text();
+  L.expect('=');
+  if (L.cur().kind == '[') {
+    // list
+    AttributeKind k = AttributeKind::ts;
+    std::vector<int64_t> is;
+    std::vector<std::string> ss;
+    std::vector<double> fs;
+    int elem_num = 0;
+    parseList('[', ',', ']', [&] {
+      ParsedLiteral r = parseScalarLiteral(n);
+      switch (r.k) {
+        case AttributeKind::s:
+          ss.push_back(r.s);
+          AT_ASSERT(!elem_num++ || k == AttributeKind::ss);
+          k = AttributeKind::ss;
+          break;
+        case AttributeKind::i:
+          is.push_back(r.i);
+          AT_ASSERT(!elem_num++ || k == AttributeKind::is);
+          k = AttributeKind::is;
+          break;
+        case AttributeKind::f:
+          fs.push_back(r.f);
+          AT_ASSERT(!elem_num++ || k == AttributeKind::fs);
+          k = AttributeKind::fs;
+          break;
+        default:
+          throw ErrorReport(L.cur().range) << "Unexpected attr type";
+      }
+    });
+    switch (k) {
+      case AttributeKind::ts:
+        n->ts_(Symbol::attr(attrname), {});
+        break;
+      case AttributeKind::ss:
+        n->ss_(Symbol::attr(attrname), ss);
+        break;
+      case AttributeKind::fs:
+        n->fs_(Symbol::attr(attrname), fs);
+        break;
+      case AttributeKind::is:
+        n->is_(Symbol::attr(attrname), is);
+        break;
+      default:
+        throw ErrorReport(L.cur().range) << "Unexpected attr type";
+    }
+  } else {
+    // scalar
+    ParsedLiteral r = parseScalarLiteral(n);
+    switch (r.k) {
+      case AttributeKind::s:
+        n->s_(Symbol::attr(attrname), r.s);
+        break;
+      case AttributeKind::i:
+        n->i_(Symbol::attr(attrname), r.i);
+        break;
+      case AttributeKind::f:
+        n->f_(Symbol::attr(attrname), r.f);
+        break;
+      default:
+        throw ErrorReport(L.cur().range) << "Unexpected attr type";
+    }
+    return;
+  }
+}
+
+void IRParser::parseAttrs(Node* n) {
+  parseList('[', ',', ']', [&] { parseAttr(n); });
+}
+
+void IRParser::parseOperatorInputs(Node* n) {
+  if (L.cur().kind == '[') {
+    parseAttrs(n);
+  }
+  parseList('(', ',', ')', [&] {
+    std::string var_name = parseVar();
+    AT_ASSERT(vmap.count(var_name));
+    n->addInput(vmap[var_name]);
+  });
+}
+
+void IRParser::parseBlocks(Node* parentNode) {
+  L.expect(TK_INDENT);
+  while (L.cur().kind != TK_DEDENT) {
+    parseBlock(parentNode);
+  }
+  L.expect(TK_DEDENT);
+}
+
+static bool isNumber(const std::string& s) {
+  return s.find_first_not_of("0123456789") == std::string::npos;
+}
+
+void IRParser::parseBlockInputs(Block* b) {
+  parseList('(', ',', ')', [&] {
+    VarWithType v = parseVarWithType();
+    // If the name is a number, don't use it
+    std::string uniq_name = isNumber(v.name) ? "" : v.name;
+    vmap[v.name] = b->addInput(uniq_name);
+    vmap[v.name]->setType(parseType(v.type));
+  });
+}
+
+void IRParser::parseBlockOutputs(Block* b) {
+  L.expect(TK_ARROW);
+  parseList('(', ',', ')', [&] {
+    std::string var_name = parseVar();
+    AT_ASSERT(vmap.count(var_name));
+    b->registerOutput(vmap[var_name]);
+  });
+  L.expect(TK_NEWLINE);
+  L.expect(TK_DEDENT);
+}
+
+/** \brief Parse a block.
+ *
+ * It should look like the following:
+ *   blockName(input1, input2, input3, ...):
+ *     op1
+ *     op2
+ *     ...
+ *     opN
+ *     -> (output1, output2, output3, ...)
+ */
+void IRParser::parseBlock(Node* parentNode) {
+  Block* b = parentNode->addBlock();
+  L.expect(TK_IDENT).text(); // Block name is not used anywhere.
+  parseBlockInputs(b);
+  L.expect(':');
+  parseOperatorsList(b);
+  parseBlockOutputs(b);
+}
+
+/** \brief Parse a list of statements.
+ *
+ * It is expected to be delimited by TK_NEWLINE and end with TK_RETURN or
+ * TK_ARROW.
+ */
+void IRParser::parseOperatorsList(Block* b) {
+  L.expect(TK_INDENT);
+  while (L.cur().kind != TK_ARROW && L.cur().kind != TK_RETURN) {
+    parseOperator(b);
+  }
+}
+
+std::string IRParser::parseOperatorName() {
+  std::string name = L.expect(TK_IDENT).text();
+  L.expect(':');
+  L.expect(':');
+  name += "::" + L.expect(TK_IDENT).text();
+  return name;
+}
+
+/** \brief Parse a statement.
+ *
+ * It should look like the following:
+ *   <outputs> = NodeName[<attrs>](<inputs>)
+ *
+ * Outputs, blocks and attributes are optional.
+ */
+void IRParser::parseOperator(Block* b) {
+  // Parse the left-hand side.
+  std::vector<VarWithType> outs;
+  parseOperatorOutputs(&outs);
+
+  // Parse the name and create the corresponding node in the graph.
+  std::string name = parseOperatorName();
+  Node* n = g->create(Symbol::fromQualString(name), {}, outs.size());
+
+  // Parse attributes and inputs.
+  parseOperatorInputs(n);
+
+  // Register outputs.
+  int idx = 0;
+  for (const VarWithType& v : outs) {
+    vmap[v.name] = n->outputs()[idx++];
+    vmap[v.name]->setType(parseType(v.type));
+  }
+
+  // Insert the new node into block B.
+  b->appendNode(n);
+
+  // If the statement has nested blocks, parse them:
+  if (L.cur().kind == TK_INDENT) {
+    parseBlocks(n);
+  }
+  L.nextIf(TK_NEWLINE);
+}
+
+void IRParser::parseGraphInputs() {
+  parseList('(', ',', ')', [&] {
+    VarWithType v = parseVarWithType();
+    // If the name is a number, don't use it
+    std::string uniq_name = isNumber(v.name) ? "" : v.name;
+    vmap[v.name] = g->addInput(uniq_name);
+    vmap[v.name]->setType(parseType(v.type));
+  });
+}
+
+/** \brief Parse return statement.
+ *
+ * It should look like the following:
+ *   return (x : TypeX, y : TypeY, z, ...)
+ */
+void IRParser::parseReturnOperator() {
+  L.expect(TK_RETURN);
+
+  // Parse output names and types
+  parseList('(', ',', ')', [&] {
+    std::string var_name = parseVar();
+    // Outputs should already be in VMAP, otherwise we're trying to return an
+    // undefined value.
+    AT_ASSERT(vmap.count(var_name));
+    g->registerOutput(vmap.at(var_name));
+  });
+
+  // Consume ending tokens
+  if (L.cur().kind != TK_EOF) {
+    L.expect(TK_NEWLINE);
+    L.expect(TK_DEDENT);
+  }
+}
+
+/** \brief Parse entire graph.
+ *
+ * It should look like the following:
+ *   graphName (input1, input2, ... inputN):
+ *     op1
+ *     op2
+ *     ...
+ *     opN
+ *     return (output1, output2, ... outputN)
+ */
+void IRParser::parse() {
+  // Parse graph definition; it should look like the following:
+  //   graphName (input1, input2, ... inputN):
+  std::string graphName = L.expect(TK_IDENT).text();
+  parseGraphInputs();
+  L.expect(':');
+
+  // After the definition we should have a list of statements, parse it:
+  parseOperatorsList(g->block());
+
+  // The last statement should be return, which specifies graph outputs
+  parseReturnOperator();
+}
+
+void IRParser::parseList(
+    int begin,
+    int sep,
+    int end,
+    const std::function<void()>& callback) {
+  if (begin != TK_NOTHING) {
+    L.expect(begin);
+  }
+  if (L.cur().kind != end) {
+    do {
+      callback();
+    } while (L.nextIf(sep));
+  }
+  if (end != TK_NOTHING) {
+    L.expect(end);
+  }
+}
+
+} // namespace script
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/irparser.h b/torch/csrc/jit/irparser.h
new file mode 100644
index 0000000..66b1ffd
--- /dev/null
+++ b/torch/csrc/jit/irparser.h
@@ -0,0 +1,15 @@
+#include
+
+namespace torch {
+namespace jit {
+
+struct Graph;
+
+namespace script {
+
+// \brief Parse IR from \p STR constructing the corresponding IR in \p GRAPH.
+void parseIR(const std::string& str, torch::jit::Graph* graph);
+
+} // namespace script
+} // namespace jit
+} // namespace torch
-- 
2.7.4
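For reference, a minimal sketch of how the new entry point is intended to be used,
mirroring what checkRoundtrip in test_irparser.h does. The include paths, the main()
driver, and the foo::add operator name are illustrative assumptions rather than part
of this patch:

#include <iostream>
#include <memory>

#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/irparser.h>

using namespace torch::jit;

int main() {
  // Build an empty graph and populate it from a textual IR description.
  auto graph = std::make_shared<Graph>();
  script::parseIR(
      R"IR(
graph(%a : Tensor, %b : Tensor):
  %c : Tensor = foo::add(%a, %b)
  return (%c))IR",
      graph.get());

  // Printing the graph reproduces the textual form; checkRoundtrip relies on
  // this to verify the parser against the printer.
  std::cout << *graph;
  return 0;
}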