Turn script_type_parser into a class (#18211)
authorMichael Suo <suo@fb.com>
Fri, 22 Mar 2019 23:24:36 +0000 (16:24 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Mar 2019 23:30:05 +0000 (16:30 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18211
ghimport-source-id: 73b81e9ec631937b14db1da10991831788a6894b

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18296 [jit] Add namespacing for ScriptClasses
* #18284 [jit] make test module hook use save/load
* **#18211 [jit] Turn script_type_parser into a class**
* #18148 [jit] python interop for script classes

If we are namespacing classes, the type parser will need to carry around
some state about which namespaces to look in. This PR just wraps it in a
class in preparation.

Also, subscriptToType can no longer be static, since parseTypeFromExpr
may give different results depending on the namespaces available, so
it's been made a regular function instead of a static map lookup.

Reviewed By: eellison

Differential Revision: D14581128

fbshipit-source-id: 711315472ccde1920abf9fdb5a871ac27fb86787

torch/csrc/jit/import.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/script_type_parser.cpp
torch/csrc/jit/script/script_type_parser.h

index d375686..fd63c33 100644 (file)
@@ -258,6 +258,7 @@ void ScriptModuleDeserializer::convertModule(
       module->register_parameter(param_def.name(), tensor, /*is_buffer=*/false);
     }
   }
+  script::ScriptTypeParser typeParser;
   for (int i = 0; i < module_def.attributes_size(); ++i) {
     const torch::AttributeDef& attr_def = module_def.attributes(i);
     if (module->find_buffer(attr_def.name())) {
@@ -267,7 +268,7 @@ void ScriptModuleDeserializer::convertModule(
 
     module->register_attribute(
       attr_def.name(),
-      script::parseType(attr_def.type()),
+      typeParser.parseType(attr_def.type()),
       attribute_table_.at(attr_def.id())
     );
   }
index 0b1074f..c4796f5 100644 (file)
@@ -402,7 +402,8 @@ struct Environment {
           {"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
           {"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
           {"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
-          {"rangelist", std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
+          {"rangelist",
+           std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
       };
       auto it = globals.find(ident);
       if (it != globals.end()) {
@@ -543,6 +544,7 @@ struct to_ir {
   Resolver resolver;
   std::unordered_map<int64_t, Value*> integral_constants;
   std::unordered_map<double, Value*> fp_constants;
+  ScriptTypeParser typeParser_;
 
   // Singly-linked list of environments. This top element contains a member
   // `next` that points to the most immediate enclosing scope's value.
@@ -658,11 +660,12 @@ struct to_ir {
       c10::optional<int32_t> N;
 
       // BroadcastList list can only appear at the argument level
-      if (auto maybe_broad_list = parseBroadcastList(decl_arg.type())) {
+      if (auto maybe_broad_list =
+              typeParser_.parseBroadcastList(decl_arg.type())) {
         type = maybe_broad_list->first;
         N = maybe_broad_list->second;
       } else {
-        type = parseTypeFromExpr(decl_arg.type());
+        type = typeParser_.parseTypeFromExpr(decl_arg.type());
         N = c10::nullopt;
       }
       c10::optional<IValue> default_value = c10::nullopt;
@@ -688,10 +691,10 @@ struct to_ir {
     if (!decl.return_type().present())
       return {};
 
-    if (parseBroadcastList(decl.return_type().get()))
+    if (typeParser_.parseBroadcastList(decl.return_type().get()))
       throw ErrorReport(decl.return_type().range())
           << "Broadcastable lists cannot appear as a return type";
-    auto parsed_type = parseTypeFromExpr(decl.return_type().get());
+    auto parsed_type = typeParser_.parseTypeFromExpr(decl.return_type().get());
     return {Argument(
         "",
         parsed_type,
@@ -2006,7 +2009,7 @@ struct to_ir {
       return emitForkExpr(loc, forked, inputs, attributes);
     } else if (auto annotate_value = dynamic_cast<AnnotateValue*>(sv.get())) {
       checkApplyExpr(apply, loc);
-      TypePtr type = parseTypeFromExpr(apply.inputs()[0]);
+      TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]);
       Value* expr = tryConvertToType(
           apply.range(),
           *graph,
@@ -2055,7 +2058,7 @@ struct to_ir {
           }
           return false;
         }
-        auto type_name = parseBaseTypeName(classinfo);
+        auto type_name = typeParser_.parseBaseTypeName(classinfo);
         if (!type_name) {
           throw ErrorReport(classinfo.range())
               << "type must be a type identifier";
@@ -2073,7 +2076,7 @@ struct to_ir {
               << "Optional isinstance check is not supported, "
               << "consider use is/isnot None instead";
         } else {
-          TypePtr type = parseTypeFromExpr(classinfo);
+          TypePtr type = typeParser_.parseTypeFromExpr(classinfo);
           if (val->type()->isSubtypeOf(type)) {
             return true;
           }
@@ -2090,7 +2093,8 @@ struct to_ir {
         throw ErrorReport(loc) << "Only one argument to __new__ allowed";
       }
       return classNew->createObject(
-          apply.range(), method, Var(apply.inputs()[0]).name().name());;
+          apply.range(), method, Var(apply.inputs()[0]).name().name());
+      ;
     } else {
       auto inputs = getNamedValues(apply.inputs(), true);
       auto attributes = emitAttributes(apply.attributes());
index daffb9c..cc6ddf4 100644 (file)
@@ -1,13 +1,11 @@
-#include <torch/csrc/jit/script/script_type_parser.h>
 #include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/script/parser.h>
 #include <torch/csrc/jit/script/script_type_parser.h>
-#include <torch/csrc/jit/script/tree_views.h>
 
 namespace torch {
 namespace jit {
 namespace script {
 
+namespace {
 const std::unordered_map<std::string, TypePtr>& ident_to_type_lut() {
   static std::unordered_map<std::string, TypePtr> map = {
       {"Tensor", TensorType::get()},
@@ -24,73 +22,63 @@ const std::unordered_map<std::string, TypePtr>& ident_to_type_lut() {
   return map;
 }
 
-const std::unordered_map<std::string, std::function<TypePtr(Subscript)>>&
-subscript_to_type_fns() {
-  static std::unordered_map<std::string, std::function<TypePtr(Subscript)>>
-      map = {
-          {"Tuple",
-           [](Subscript subscript) -> TypePtr {
-             std::vector<TypePtr> subscript_expr_types;
-             for (auto expr : subscript.subscript_exprs()) {
-               subscript_expr_types.push_back(parseTypeFromExpr(expr));
-             }
-             return TupleType::create(subscript_expr_types);
-           }},
-          {"List",
-           [](Subscript subscript) -> TypePtr {
-             if (subscript.subscript_exprs().size() != 1) {
-               throw ErrorReport(subscript)
-                   << " expected exactly one element type but found "
-                   << subscript.subscript_exprs().size();
-             }
-             auto elem_type =
-                 parseTypeFromExpr(*subscript.subscript_exprs().begin());
-             return ListType::create(elem_type);
-           }},
-          {"Optional",
-           [](Subscript subscript) -> TypePtr {
-             if (subscript.subscript_exprs().size() != 1) {
-               throw ErrorReport(subscript)
-                   << " expected exactly one element type but found "
-                   << subscript.subscript_exprs().size();
-             }
-             auto elem_type =
-                 parseTypeFromExpr(*subscript.subscript_exprs().begin());
-             return OptionalType::create(elem_type);
-           }},
-          {"Future",
-           [](Subscript subscript) -> TypePtr {
-             if (subscript.subscript_exprs().size() != 1) {
-               throw ErrorReport(subscript)
-                   << " expected exactly one element type but found "
-                   << subscript.subscript_exprs().size();
-             }
-             auto elem_type =
-                 parseTypeFromExpr(*subscript.subscript_exprs().begin());
-             return FutureType::create(elem_type);
-           }},
-          {"Dict",
-           [](Subscript subscript) -> TypePtr {
-             if (subscript.subscript_exprs().size() != 2) {
-               throw ErrorReport(subscript)
-                   << " expected exactly 2 element types but found "
-                   << subscript.subscript_exprs().size();
-             }
-             auto key_type = parseTypeFromExpr(subscript.subscript_exprs()[0]);
-             auto value_type =
-                 parseTypeFromExpr(subscript.subscript_exprs()[1]);
-             return DictType::create(key_type, value_type);
-           }},
-      };
-  return map;
-}
-
 bool isTorch(const Expr& expr) {
   return expr.kind() == TK_VAR && Var(expr).name().name() == "torch";
 }
+} // namespace
 
-c10::optional<std::pair<TypePtr, int32_t>> parseBroadcastList(
-    const Expr& expr) {
+TypePtr ScriptTypeParser::subscriptToType(
+    const std::string& typeName,
+    const Subscript& subscript) const {
+  if (typeName == "Tuple") {
+    std::vector<TypePtr> subscript_expr_types;
+    for (auto expr : subscript.subscript_exprs()) {
+      subscript_expr_types.push_back(parseTypeFromExpr(expr));
+    }
+    return TupleType::create(subscript_expr_types);
+  } else if (typeName == "List") {
+    if (subscript.subscript_exprs().size() != 1) {
+      throw ErrorReport(subscript)
+          << " expected exactly one element type but found "
+          << subscript.subscript_exprs().size();
+    }
+    auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
+    return ListType::create(elem_type);
+
+  } else if (typeName == "Optional") {
+    if (subscript.subscript_exprs().size() != 1) {
+      throw ErrorReport(subscript)
+          << " expected exactly one element type but found "
+          << subscript.subscript_exprs().size();
+    }
+    auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
+    return OptionalType::create(elem_type);
+
+  } else if (typeName == "Future") {
+    if (subscript.subscript_exprs().size() != 1) {
+      throw ErrorReport(subscript)
+          << " expected exactly one element type but found "
+          << subscript.subscript_exprs().size();
+    }
+    auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
+    return FutureType::create(elem_type);
+  } else if (typeName == "Dict") {
+    if (subscript.subscript_exprs().size() != 2) {
+      throw ErrorReport(subscript)
+          << " expected exactly 2 element types but found "
+          << subscript.subscript_exprs().size();
+    }
+    auto key_type = parseTypeFromExpr(subscript.subscript_exprs()[0]);
+    auto value_type = parseTypeFromExpr(subscript.subscript_exprs()[1]);
+    return DictType::create(key_type, value_type);
+  } else {
+    throw ErrorReport(subscript.range())
+        << "Unknown type constructor " << typeName;
+  }
+}
+
+c10::optional<std::pair<TypePtr, int32_t>> ScriptTypeParser::parseBroadcastList(
+    const Expr& expr) const {
   if (expr.kind() != TK_SUBSCRIPT)
     return c10::nullopt;
   auto subscript = Subscript(expr);
@@ -114,7 +102,8 @@ c10::optional<std::pair<TypePtr, int32_t>> parseBroadcastList(
 
   if (subscript_exprs.size() != 1)
     throw ErrorReport(subscript.subscript_exprs().range())
-        << "BroadcastingList/Optional[BroadcastingList] must be subscripted with a type";
+        << "BroadcastingList/Optional[BroadcastingList] "
+           "must be subscripted with a type";
 
   auto typ = subscript_exprs[0];
   auto len = var.name().name().substr(strlen("BroadcastingList"));
@@ -144,7 +133,8 @@ c10::optional<std::pair<TypePtr, int32_t>> parseBroadcastList(
 
 // gets the base type name given namespaces where the types live
 // turns torch.Tensor -> Tensor, X -> X
-c10::optional<std::string> parseBaseTypeName(const Expr& expr) {
+c10::optional<std::string> ScriptTypeParser::parseBaseTypeName(
+    const Expr& expr) const {
   switch (expr.kind()) {
     case TK_VAR: {
       return Var(expr).name().name();
@@ -162,7 +152,7 @@ c10::optional<std::string> parseBaseTypeName(const Expr& expr) {
   return at::nullopt;
 }
 
-TypePtr parseTypeFromExpr(const Expr& expr) {
+TypePtr ScriptTypeParser::parseTypeFromExpr(const Expr& expr) const {
   if (expr.kind() == TK_SUBSCRIPT) {
     auto subscript = Subscript(expr);
     auto value_name = parseBaseTypeName(subscript.value());
@@ -170,11 +160,7 @@ TypePtr parseTypeFromExpr(const Expr& expr) {
       throw ErrorReport(subscript.value().range())
           << "Subscripted type must be a type identifier";
     }
-    if (!subscript_to_type_fns().count(*value_name)) {
-      throw ErrorReport(subscript.range())
-          << "Unknown type constructor " << *value_name;
-    }
-    return subscript_to_type_fns().at(*value_name)(subscript);
+    return subscriptToType(*value_name, subscript);
   } else if (auto name = parseBaseTypeName(expr)) {
     auto itr = ident_to_type_lut().find(*name);
     if (itr != ident_to_type_lut().end()) {
@@ -190,7 +176,7 @@ TypePtr parseTypeFromExpr(const Expr& expr) {
       << " cannot be used in a type expression";
 }
 
-TypePtr parseType(const std::string& str) {
+TypePtr ScriptTypeParser::parseType(const std::string& str) {
   Parser p(str);
   return parseTypeFromExpr(p.parseExp());
 }
index 65c0ca5..4eec6e6 100644 (file)
@@ -2,16 +2,33 @@
 #include <ATen/core/jit_type.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/script/parser.h>
-
+#include <torch/csrc/jit/script/tree_views.h>
 namespace torch {
 namespace jit {
 namespace script {
-struct Expr;
-TORCH_API c10::optional<std::string> parseBaseTypeName(const Expr& expr);
-TORCH_API c10::TypePtr parseTypeFromExpr(const Expr& expr);
-TORCH_API c10::optional<std::pair<c10::TypePtr, int32_t>> parseBroadcastList(
-    const Expr& expr);
-TORCH_API c10::TypePtr parseType(const std::string& str);
+
+/**
+ * class ScriptTypeParser
+ *
+ * Parses expressions in our typed AST format (TreeView) into types and
+ * typenames.
+ */
+class TORCH_API ScriptTypeParser {
+ public:
+  c10::optional<std::string> parseBaseTypeName(const Expr& expr) const;
+
+  c10::TypePtr parseTypeFromExpr(const Expr& expr) const;
+
+  c10::optional<std::pair<c10::TypePtr, int32_t>> parseBroadcastList(
+      const Expr& expr) const;
+
+  c10::TypePtr parseType(const std::string& str);
+
+ private:
+  at::TypePtr subscriptToType(
+      const std::string& typeName,
+      const Subscript& subscript) const;
+};
 } // namespace script
 } // namespace jit
 } // namespace torch