From: Michael Suo Date: Fri, 22 Mar 2019 23:24:36 +0000 (-0700) Subject: Turn script_type_parser into a class (#18211) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~669 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ff3ecfec893e31d8c1696f278d4c136536449a3b;p=platform%2Fupstream%2Fpytorch.git Turn script_type_parser into a class (#18211) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18211 ghimport-source-id: 73b81e9ec631937b14db1da10991831788a6894b Stack from [ghstack](https://github.com/ezyang/ghstack): * #18296 [jit] Add namespacing for ScriptClasses * #18284 [jit] make test module hook use save/load * **#18211 [jit] Turn script_type_parser into a class** * #18148 [jit] python interop for script classes If we are namespacing classes, the type parser will need to carry around some state about which namespaces to look in. This PR just wraps it in a class in preparation. Also, subscriptToType can no longer be static, since parseTypeFromExpr may give different results depending on the namespaces available, so it's been made a regular function instead of a static map lookup. Reviewed By: eellison Differential Revision: D14581128 fbshipit-source-id: 711315472ccde1920abf9fdb5a871ac27fb86787 --- diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index d375686..fd63c33 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -258,6 +258,7 @@ void ScriptModuleDeserializer::convertModule( module->register_parameter(param_def.name(), tensor, /*is_buffer=*/false); } } + script::ScriptTypeParser typeParser; for (int i = 0; i < module_def.attributes_size(); ++i) { const torch::AttributeDef& attr_def = module_def.attributes(i); if (module->find_buffer(attr_def.name())) { @@ -267,7 +268,7 @@ void ScriptModuleDeserializer::convertModule( module->register_attribute( attr_def.name(), - script::parseType(attr_def.type()), + typeParser.parseType(attr_def.type()), attribute_table_.at(attr_def.id()) ); } diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 0b1074f..c4796f5 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -402,7 +402,8 @@ struct Environment { {"min", std::make_shared(prim::min, at::nullopt)}, {"max", std::make_shared(prim::max, at::nullopt)}, {"list", std::make_shared(aten::list, at::nullopt)}, - {"rangelist", std::make_shared(prim::rangelist, at::nullopt)}, + {"rangelist", + std::make_shared(prim::rangelist, at::nullopt)}, }; auto it = globals.find(ident); if (it != globals.end()) { @@ -543,6 +544,7 @@ struct to_ir { Resolver resolver; std::unordered_map integral_constants; std::unordered_map fp_constants; + ScriptTypeParser typeParser_; // Singly-linked list of environments. This top element contains a member // `next` that points to the most immediate enclosing scope's value. @@ -658,11 +660,12 @@ struct to_ir { c10::optional N; // BroadcastList list can only appear at the argument level - if (auto maybe_broad_list = parseBroadcastList(decl_arg.type())) { + if (auto maybe_broad_list = + typeParser_.parseBroadcastList(decl_arg.type())) { type = maybe_broad_list->first; N = maybe_broad_list->second; } else { - type = parseTypeFromExpr(decl_arg.type()); + type = typeParser_.parseTypeFromExpr(decl_arg.type()); N = c10::nullopt; } c10::optional default_value = c10::nullopt; @@ -688,10 +691,10 @@ struct to_ir { if (!decl.return_type().present()) return {}; - if (parseBroadcastList(decl.return_type().get())) + if (typeParser_.parseBroadcastList(decl.return_type().get())) throw ErrorReport(decl.return_type().range()) << "Broadcastable lists cannot appear as a return type"; - auto parsed_type = parseTypeFromExpr(decl.return_type().get()); + auto parsed_type = typeParser_.parseTypeFromExpr(decl.return_type().get()); return {Argument( "", parsed_type, @@ -2006,7 +2009,7 @@ struct to_ir { return emitForkExpr(loc, forked, inputs, attributes); } else if (auto annotate_value = dynamic_cast(sv.get())) { checkApplyExpr(apply, loc); - TypePtr type = parseTypeFromExpr(apply.inputs()[0]); + TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]); Value* expr = tryConvertToType( apply.range(), *graph, @@ -2055,7 +2058,7 @@ struct to_ir { } return false; } - auto type_name = parseBaseTypeName(classinfo); + auto type_name = typeParser_.parseBaseTypeName(classinfo); if (!type_name) { throw ErrorReport(classinfo.range()) << "type must be a type identifier"; @@ -2073,7 +2076,7 @@ struct to_ir { << "Optional isinstance check is not supported, " << "consider use is/isnot None instead"; } else { - TypePtr type = parseTypeFromExpr(classinfo); + TypePtr type = typeParser_.parseTypeFromExpr(classinfo); if (val->type()->isSubtypeOf(type)) { return true; } @@ -2090,7 +2093,8 @@ struct to_ir { throw ErrorReport(loc) << "Only one argument to __new__ allowed"; } return classNew->createObject( - apply.range(), method, Var(apply.inputs()[0]).name().name());; + apply.range(), method, Var(apply.inputs()[0]).name().name()); + ; } else { auto inputs = getNamedValues(apply.inputs(), true); auto attributes = emitAttributes(apply.attributes()); diff --git a/torch/csrc/jit/script/script_type_parser.cpp b/torch/csrc/jit/script/script_type_parser.cpp index daffb9c..cc6ddf4 100644 --- a/torch/csrc/jit/script/script_type_parser.cpp +++ b/torch/csrc/jit/script/script_type_parser.cpp @@ -1,13 +1,11 @@ -#include #include -#include #include -#include namespace torch { namespace jit { namespace script { +namespace { const std::unordered_map& ident_to_type_lut() { static std::unordered_map map = { {"Tensor", TensorType::get()}, @@ -24,73 +22,63 @@ const std::unordered_map& ident_to_type_lut() { return map; } -const std::unordered_map>& -subscript_to_type_fns() { - static std::unordered_map> - map = { - {"Tuple", - [](Subscript subscript) -> TypePtr { - std::vector subscript_expr_types; - for (auto expr : subscript.subscript_exprs()) { - subscript_expr_types.push_back(parseTypeFromExpr(expr)); - } - return TupleType::create(subscript_expr_types); - }}, - {"List", - [](Subscript subscript) -> TypePtr { - if (subscript.subscript_exprs().size() != 1) { - throw ErrorReport(subscript) - << " expected exactly one element type but found " - << subscript.subscript_exprs().size(); - } - auto elem_type = - parseTypeFromExpr(*subscript.subscript_exprs().begin()); - return ListType::create(elem_type); - }}, - {"Optional", - [](Subscript subscript) -> TypePtr { - if (subscript.subscript_exprs().size() != 1) { - throw ErrorReport(subscript) - << " expected exactly one element type but found " - << subscript.subscript_exprs().size(); - } - auto elem_type = - parseTypeFromExpr(*subscript.subscript_exprs().begin()); - return OptionalType::create(elem_type); - }}, - {"Future", - [](Subscript subscript) -> TypePtr { - if (subscript.subscript_exprs().size() != 1) { - throw ErrorReport(subscript) - << " expected exactly one element type but found " - << subscript.subscript_exprs().size(); - } - auto elem_type = - parseTypeFromExpr(*subscript.subscript_exprs().begin()); - return FutureType::create(elem_type); - }}, - {"Dict", - [](Subscript subscript) -> TypePtr { - if (subscript.subscript_exprs().size() != 2) { - throw ErrorReport(subscript) - << " expected exactly 2 element types but found " - << subscript.subscript_exprs().size(); - } - auto key_type = parseTypeFromExpr(subscript.subscript_exprs()[0]); - auto value_type = - parseTypeFromExpr(subscript.subscript_exprs()[1]); - return DictType::create(key_type, value_type); - }}, - }; - return map; -} - bool isTorch(const Expr& expr) { return expr.kind() == TK_VAR && Var(expr).name().name() == "torch"; } +} // namespace -c10::optional> parseBroadcastList( - const Expr& expr) { +TypePtr ScriptTypeParser::subscriptToType( + const std::string& typeName, + const Subscript& subscript) const { + if (typeName == "Tuple") { + std::vector subscript_expr_types; + for (auto expr : subscript.subscript_exprs()) { + subscript_expr_types.push_back(parseTypeFromExpr(expr)); + } + return TupleType::create(subscript_expr_types); + } else if (typeName == "List") { + if (subscript.subscript_exprs().size() != 1) { + throw ErrorReport(subscript) + << " expected exactly one element type but found " + << subscript.subscript_exprs().size(); + } + auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin()); + return ListType::create(elem_type); + + } else if (typeName == "Optional") { + if (subscript.subscript_exprs().size() != 1) { + throw ErrorReport(subscript) + << " expected exactly one element type but found " + << subscript.subscript_exprs().size(); + } + auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin()); + return OptionalType::create(elem_type); + + } else if (typeName == "Future") { + if (subscript.subscript_exprs().size() != 1) { + throw ErrorReport(subscript) + << " expected exactly one element type but found " + << subscript.subscript_exprs().size(); + } + auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin()); + return FutureType::create(elem_type); + } else if (typeName == "Dict") { + if (subscript.subscript_exprs().size() != 2) { + throw ErrorReport(subscript) + << " expected exactly 2 element types but found " + << subscript.subscript_exprs().size(); + } + auto key_type = parseTypeFromExpr(subscript.subscript_exprs()[0]); + auto value_type = parseTypeFromExpr(subscript.subscript_exprs()[1]); + return DictType::create(key_type, value_type); + } else { + throw ErrorReport(subscript.range()) + << "Unknown type constructor " << typeName; + } +} + +c10::optional> ScriptTypeParser::parseBroadcastList( + const Expr& expr) const { if (expr.kind() != TK_SUBSCRIPT) return c10::nullopt; auto subscript = Subscript(expr); @@ -114,7 +102,8 @@ c10::optional> parseBroadcastList( if (subscript_exprs.size() != 1) throw ErrorReport(subscript.subscript_exprs().range()) - << "BroadcastingList/Optional[BroadcastingList] must be subscripted with a type"; + << "BroadcastingList/Optional[BroadcastingList] " + "must be subscripted with a type"; auto typ = subscript_exprs[0]; auto len = var.name().name().substr(strlen("BroadcastingList")); @@ -144,7 +133,8 @@ c10::optional> parseBroadcastList( // gets the base type name given namespaces where the types live // turns torch.Tensor -> Tensor, X -> X -c10::optional parseBaseTypeName(const Expr& expr) { +c10::optional ScriptTypeParser::parseBaseTypeName( + const Expr& expr) const { switch (expr.kind()) { case TK_VAR: { return Var(expr).name().name(); @@ -162,7 +152,7 @@ c10::optional parseBaseTypeName(const Expr& expr) { return at::nullopt; } -TypePtr parseTypeFromExpr(const Expr& expr) { +TypePtr ScriptTypeParser::parseTypeFromExpr(const Expr& expr) const { if (expr.kind() == TK_SUBSCRIPT) { auto subscript = Subscript(expr); auto value_name = parseBaseTypeName(subscript.value()); @@ -170,11 +160,7 @@ TypePtr parseTypeFromExpr(const Expr& expr) { throw ErrorReport(subscript.value().range()) << "Subscripted type must be a type identifier"; } - if (!subscript_to_type_fns().count(*value_name)) { - throw ErrorReport(subscript.range()) - << "Unknown type constructor " << *value_name; - } - return subscript_to_type_fns().at(*value_name)(subscript); + return subscriptToType(*value_name, subscript); } else if (auto name = parseBaseTypeName(expr)) { auto itr = ident_to_type_lut().find(*name); if (itr != ident_to_type_lut().end()) { @@ -190,7 +176,7 @@ TypePtr parseTypeFromExpr(const Expr& expr) { << " cannot be used in a type expression"; } -TypePtr parseType(const std::string& str) { +TypePtr ScriptTypeParser::parseType(const std::string& str) { Parser p(str); return parseTypeFromExpr(p.parseExp()); } diff --git a/torch/csrc/jit/script/script_type_parser.h b/torch/csrc/jit/script/script_type_parser.h index 65c0ca5..4eec6e6 100644 --- a/torch/csrc/jit/script/script_type_parser.h +++ b/torch/csrc/jit/script/script_type_parser.h @@ -2,16 +2,33 @@ #include #include #include - +#include namespace torch { namespace jit { namespace script { -struct Expr; -TORCH_API c10::optional parseBaseTypeName(const Expr& expr); -TORCH_API c10::TypePtr parseTypeFromExpr(const Expr& expr); -TORCH_API c10::optional> parseBroadcastList( - const Expr& expr); -TORCH_API c10::TypePtr parseType(const std::string& str); + +/** + * class ScriptTypeParser + * + * Parses expressions in our typed AST format (TreeView) into types and + * typenames. + */ +class TORCH_API ScriptTypeParser { + public: + c10::optional parseBaseTypeName(const Expr& expr) const; + + c10::TypePtr parseTypeFromExpr(const Expr& expr) const; + + c10::optional> parseBroadcastList( + const Expr& expr) const; + + c10::TypePtr parseType(const std::string& str); + + private: + at::TypePtr subscriptToType( + const std::string& typeName, + const Subscript& subscript) const; +}; } // namespace script } // namespace jit } // namespace torch