From 496a3339dc6cd89345aac2f5da841f71fe635a39 Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Mon, 11 Mar 2019 19:07:58 -0700 Subject: [PATCH] add support for parsing class defs to the string frontend (#17628) MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17628 This is not hooked up anywhere yet, just adding support. This shares the same restrictions as the python frontend—namely, that the only exprs allowed right now are method defs. Reviewed By: shannonzhu Differential Revision: D14291654 fbshipit-source-id: 7798e5ff412a52ef8803c7bae8f439e50968a73a --- test/cpp/jit/gtest.cpp | 3 +++ test/cpp/jit/no-gtest.cpp | 3 +++ test/cpp/jit/test_class_parser.h | 34 ++++++++++++++++++++++++++++++++++ torch/csrc/jit/script/parser.cpp | 20 ++++++++++++++++++++ torch/csrc/jit/script/parser.h | 1 + 5 files changed, 61 insertions(+) create mode 100644 test/cpp/jit/test_class_parser.h diff --git a/test/cpp/jit/gtest.cpp b/test/cpp/jit/gtest.cpp index 8f6c2a8..2c1edae 100644 --- a/test/cpp/jit/gtest.cpp +++ b/test/cpp/jit/gtest.cpp @@ -1,12 +1,14 @@ #include #include +#include #include #include #include using namespace torch; using namespace torch::jit; +using namespace torch::jit::script; #define JIT_TEST(name) \ TEST(JitTest, name) { \ @@ -44,6 +46,7 @@ JIT_TEST(NetDefConverter) JIT_TEST(THNNConv) JIT_TEST(ATenNativeBatchNorm) JIT_TEST(NoneSchemaMatch) +JIT_TEST(ClassParser) #define JIT_TEST_CUDA(name) \ TEST(JitTest, name##_CUDA) { \ diff --git a/test/cpp/jit/no-gtest.cpp b/test/cpp/jit/no-gtest.cpp index 00b6892..3c6ca37 100644 --- a/test/cpp/jit/no-gtest.cpp +++ b/test/cpp/jit/no-gtest.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -6,6 +7,7 @@ #include #include +using namespace torch::jit::script; namespace torch { namespace jit { std::string runJITCPPTests() { @@ -44,6 +46,7 @@ std::string runJITCPPTests() { testMemoryDAG(); testNetDefConverter(out); testIRParser(out); + testClassParser(); return out.str(); } diff --git a/test/cpp/jit/test_class_parser.h b/test/cpp/jit/test_class_parser.h new file mode 100644 index 0000000..50a87ab --- /dev/null +++ b/test/cpp/jit/test_class_parser.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +namespace torch { +namespace jit { +namespace script { +const auto testSource = R"JIT( + class FooTest: + def __init__(self, x): + self.x = x + + def get_x(self): + return self.x +)JIT"; + +void testClassParser() { + auto cu = std::make_shared(); + Parser p(testSource); + std::vector definitions; + std::vector resolvers; + + const auto classDef = ClassDef(p.parseClass()); + p.lexer().expect(TK_EOF); + + ASSERT_EQ(classDef.name().name(), "FooTest"); + ASSERT_EQ(classDef.defs().size(), 2); + ASSERT_EQ(classDef.defs()[0].name().name(), "__init__"); + ASSERT_EQ(classDef.defs()[1].name().name(), "get_x"); +} +} // namespace script +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/script/parser.cpp b/torch/csrc/jit/script/parser.cpp index b47a0b6..e00a503 100644 --- a/torch/csrc/jit/script/parser.cpp +++ b/torch/csrc/jit/script/parser.cpp @@ -549,6 +549,23 @@ struct ParserImpl { paramlist.range(), List(paramlist), return_annotation); } + TreeRef parseClass() { + L.expect(TK_CLASS_DEF); + const auto name = parseIdent(); + // TODO no inheritance or () allowed right now + L.expect(':'); + + L.expect(TK_INDENT); + std::vector methods; + while (L.cur().kind != TK_DEDENT) { + methods.push_back(Def(parseFunction(/*is_method=*/true))); + } + L.expect(TK_DEDENT); + + return ClassDef::create( + name.range(), name, List::create(name.range(), methods)); + } + TreeRef parseFunction(bool is_method) { L.expect(TK_DEF); auto name = parseIdent(); @@ -590,6 +607,9 @@ Parser::~Parser() = default; TreeRef Parser::parseFunction(bool is_method) { return pImpl->parseFunction(is_method); } +TreeRef Parser::parseClass() { + return pImpl->parseClass(); +} Lexer& Parser::lexer() { return pImpl->lexer(); } diff --git a/torch/csrc/jit/script/parser.h b/torch/csrc/jit/script/parser.h index b695e2c..5515b61 100644 --- a/torch/csrc/jit/script/parser.h +++ b/torch/csrc/jit/script/parser.h @@ -19,6 +19,7 @@ TORCH_API Decl mergeTypesFromTypeComment( struct TORCH_API Parser { explicit Parser(const std::string& str); TreeRef parseFunction(bool is_method); + TreeRef parseClass(); Decl parseTypeComment(); Lexer& lexer(); ~Parser(); -- 2.7.4