Refactor Type Parser b/w Schemas & IRParser into a type common parser (#17383)
authorElias Ellison <eellison@fb.com>
Fri, 22 Feb 2019 21:34:48 +0000 (13:34 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Feb 2019 21:43:55 +0000 (13:43 -0800)
Summary:
Creates a new shared type parser to be shared between the IR parser and the Schema Parser.

Also adds parsing of CompleteTensorType and DimensionedTensorType, and feature-gates that for the IRParser.

Renames the existing type_parser for python annotations, python_type_parser, and names the new one jit_type_parser.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17383

Differential Revision: D14186438

Pulled By: eellison

fbshipit-source-id: bbd5e337917d8862c7c6fa0a0006efa101c76afe

test/cpp/jit/test_irparser.h
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/jit/irparser.cpp
torch/csrc/jit/operator.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/schema_type_parser.cpp [new file with mode: 0644]
torch/csrc/jit/script/schema_type_parser.h [new file with mode: 0644]
torch/csrc/jit/script/script_type_parser.cpp [moved from torch/csrc/jit/script/type_parser.cpp with 90% similarity]
torch/csrc/jit/script/script_type_parser.h [moved from torch/csrc/jit/script/type_parser.h with 100% similarity]
torch/csrc/jit/script/sugared_value.cpp

index cdd3263..b12da1f 100644 (file)
@@ -150,6 +150,67 @@ graph(%0 : Tensor,
   return (%7)
 )IR");
   }
+
+  {
+    checkRoundtrip(
+        R"IR(
+graph(%0 : Tensor,
+      %1 : Tensor,
+      %2 : Tensor):
+  %3 : int? = prim::Constant()
+  return (%3)
+)IR");
+  }
+
+  {
+    checkRoundtrip(
+        R"IR(
+graph(%0 : Tensor,
+      %1 : Tensor,
+      %2 : Tensor):
+  %3 : Float(*, *, *) = prim::Constant()
+  return (%3)
+)IR");
+  }
+
+  {
+    checkRoundtrip(
+        R"IR(
+graph(%0 : Tensor,
+      %1 : Tensor,
+      %2 : Tensor):
+  %3 : Long() = prim::Constant()
+  return (%3)
+)IR");
+  }
+
+  {
+    checkRoundtrip(
+        R"IR(
+graph(%0 : Tensor,
+      %1 : Tensor,
+      %2 : Tensor):
+  %3 : Double(4, 4, 5) = prim::Constant()
+  return (%3)
+)IR");
+  }
+
+  {
+    bool error_thrown = false;
+    try {
+      checkRoundtrip(
+          R"IR(
+graph(%0 : Tensor,
+    %1 : Tensor,
+    %2 : Tensor):
+  %3 : Double(4!, 4, 5) = prim::Constant()
+  return (%3)
+)IR");
+    } catch (const std::exception& error) {
+      error_thrown = true;
+    }
+    AT_ASSERT(error_thrown);
+  }
 }
 } // namespace jit
 } // namespace torch
index 4273e1b..15846af 100644 (file)
@@ -94,7 +94,8 @@ libtorch_sources = [
     "torch/csrc/jit/script/compiler.cpp",
     "torch/csrc/jit/script/edit_distance.cpp",
     "torch/csrc/jit/script/final_returns.cpp",
-    "torch/csrc/jit/script/type_parser.cpp",
+    "torch/csrc/jit/script/schema_type_parser.cpp",
+    "torch/csrc/jit/script/script_type_parser.cpp",
     "torch/csrc/jit/script/sugared_value.cpp",
     "torch/csrc/jit/script/schema_matching.cpp",
     "torch/csrc/jit/script/parser.cpp",
index 3faf02c..2f71b24 100644 (file)
@@ -174,7 +174,8 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp
-  ${TORCH_SRC_DIR}/csrc/jit/script/type_parser.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/script/schema_type_parser.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/script/script_type_parser.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/sugared_value.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/parser.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
index 24e23b1..0b64bc1 100644 (file)
@@ -2,6 +2,7 @@
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/script/lexer.h>
 #include <torch/csrc/jit/script/parse_string_literal.h>
+#include <torch/csrc/jit/script/schema_type_parser.h>
 
 #include <string>
 #include <vector>
@@ -16,7 +17,9 @@ struct ParsedLiteral;
 class IRParser {
   friend void parseIR(const std::string& str, torch::jit::Graph* graph);
   IRParser(const std::string& str, torch::jit::Graph* graph)
-      : L(str), g(graph) {}
+      : L(str),
+        g(graph),
+        type_parser(L, /*parse_complete_tensor_types*/ true) {}
 
   std::string parseVar();
   VarWithType parseVarWithType();
@@ -48,6 +51,7 @@ class IRParser {
   torch::jit::script::Lexer L;
   torch::jit::Graph* g = nullptr;
   std::unordered_map<std::string, Value*> vmap;
+  SchemaTypeParser type_parser;
 };
 
 struct ParsedLiteral {
@@ -66,7 +70,7 @@ struct ParsedLiteral {
 struct VarWithType {
   VarWithType() = default;
   std::string name;
-  std::string type;
+  TypePtr type;
 };
 
 void parseIR(const std::string& str, torch::jit::Graph* graph) {
@@ -74,23 +78,6 @@ void parseIR(const std::string& str, torch::jit::Graph* graph) {
   p.parse();
 }
 
-TypePtr parseType(const std::string& s) {
-  if (s == "Tensor") {
-    return TensorType::get();
-  }
-  if (s == "int") {
-    return IntType::get();
-  }
-  if (s == "float") {
-    return FloatType::get();
-  }
-  if (s == "string") {
-    return StringType::get();
-  }
-  // TODO: Support other types.
-  AT_ASSERTM(false, "Type not supported by parser:", s);
-}
-
 VarWithType IRParser::parseVarWithType() {
   L.expect('%');
   VarWithType r;
@@ -99,9 +86,11 @@ VarWithType IRParser::parseVarWithType() {
   } else {
     r.name = L.expect(TK_NUMBER).text();
   }
-  r.type = "Tensor";
+  r.type = TensorType::get();
   if (L.nextIf(':')) {
-    r.type = L.expect(TK_IDENT).text();
+    auto type_alias = type_parser.parseType();
+    AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
+    r.type = type_alias.first;
   }
   return r;
 }
@@ -268,7 +257,7 @@ void IRParser::parseBlockInputs(Block* b) {
     // If the name isn't valid, don't use it
     std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
     vmap[v.name] = b->addInput(uniq_name);
-    vmap[v.name]->setType(parseType(v.type));
+    vmap[v.name]->setType(v.type);
   });
 }
 
@@ -345,7 +334,7 @@ void IRParser::parseOperator(Block* b) {
   int idx = 0;
   for (const VarWithType& v : outs) {
     vmap[v.name] = n->outputs()[idx++];
-    vmap[v.name]->setType(parseType(v.type));
+    vmap[v.name]->setType(v.type);
   }
 
   // Insert the new node into block B.
@@ -364,7 +353,7 @@ void IRParser::parseGraphInputs() {
     // If the name isn't valid, don't use it
     std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
     vmap[v.name] = g->addInput(uniq_name);
-    vmap[v.name]->setType(parseType(v.type));
+    vmap[v.name]->setType(v.type);
   });
 }
 
@@ -433,7 +422,6 @@ void IRParser::parseList(
     L.expect(end);
   }
 }
-
 } // namespace script
 } // namespace jit
 } // namespace torch
index ff2d404..b452358 100644 (file)
@@ -1,13 +1,14 @@
+#include <torch/csrc/jit/operator.h>
 #include <ATen/ATen.h>
 #include <torch/csrc/jit/alias_info.h>
-#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
+#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/script/edit_distance.h>
 #include <torch/csrc/jit/script/error_report.h>
 #include <torch/csrc/jit/script/lexer.h>
 #include <torch/csrc/jit/script/parse_string_literal.h>
+#include <torch/csrc/jit/script/schema_type_parser.h>
 #include <torch/csrc/jit/script/tree.h>
-#include <torch/csrc/jit/script/edit_distance.h>
 
 #include <functional>
 #include <memory>
@@ -20,7 +21,8 @@ namespace jit {
 
 namespace script {
 struct SchemaParser {
-  SchemaParser(const std::string& str) : L(str) {}
+  SchemaParser(const std::string& str)
+      : L(str), type_parser(L, /*parse_complete_tensor_types*/ false) {}
 
   FunctionSchema parseDeclaration() {
     auto name = L.expect(TK_IDENT).text();
@@ -73,131 +75,10 @@ struct SchemaParser {
   TreeRef parseIdent() {
     return String::create(L.expect(TK_IDENT).text());
   }
-  using TypeAndAlias = std::pair<TypePtr, c10::optional<AliasInfo>>;
-  TypeAndAlias parseBaseType() {
-    static std::unordered_map<std::string, TypePtr> type_map = {
-        {"Generator", GeneratorType::get()},
-        {"ScalarType", IntType::get()},
-        {"Layout", IntType::get()},
-        {"Device", DeviceObjType::get()},
-        {"Scalar", NumberType::get()},
-        {"str", StringType::get()},
-        {"float", FloatType::get()},
-        {"int", IntType::get()},
-        {"bool", BoolType::get()},
-    };
-    auto tok = L.expect(TK_IDENT);
-    auto text = tok.text();
-    auto it = type_map.find(text);
-    if (it == type_map.end()) {
-      if (text.size() > 0 && islower(text[0])) {
-        // lower case identifiers that are not otherwise valid types
-        // are treated as type variables
-        return TypeAndAlias(VarType::create(text), parseAliasAnnotation());
-      }
-      throw ErrorReport(tok.range) << "unknown type specifier";
-    }
-    return TypeAndAlias(it->second, c10::nullopt);
-  }
-  // Examples:
-  // Tensor(a) // Tensor is in set a
-  // Tensor(a!) // it is also written to
-  // Tensor!  // shorthand for Tensor(fresh_identifier!)
-  // Tensor(a! -> a|b) // Tensor is in set a, written to,
-  //                      and after the write is in set a AND b.
-  c10::optional<AliasInfo> parseAliasAnnotation() {
-    std::set<Symbol> sets;
-    AliasInfo alias_info;
-    if (L.nextIf('(')) {
-      // optional 'alias set annotation'
-      parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
-        if (L.nextIf('*')) {
-          alias_info = AliasInfo::createWildcard();
-
-          // If we found a wildcard, ignore all subsequent annotations
-        } else if (!alias_info.isWildcard()) {
-          alias_info.addSet(
-              Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
-        }
-      });
-      if (L.nextIf('!')) {
-        alias_info.setIsWrite(true);
-      }
-      L.expect(')');
-    } else if (L.nextIf('!')) {
-      alias_info.addSet(
-          Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
-      alias_info.setIsWrite(true);
-    } else {
-      return c10::nullopt;
-    }
-
-    return alias_info;
-  }
-
-  std::pair<TypePtr, c10::optional<AliasInfo>> parseType() {
-    TypePtr value;
-    c10::optional<AliasInfo> alias_info;
-    // Tuple type
-    if (L.cur().kind == '(') {
-      std::vector<TypePtr> types;
-      parseList('(', ',', ')', [&] {
-        auto r = parseType();
-        types.push_back(std::move(r.first));
-        if (alias_info && r.second) {
-          alias_info->addContainedType(std::move(*r.second));
-        }
-      });
-      value = TupleType::create(std::move(types));
-    } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
-      L.next(); // Future
-      L.expect('(');
-      auto p = parseType();
-      auto subtype = std::move(p.first);
-      auto subalias = std::move(p.second);
-      L.expect(')');
-      value = FutureType::create(subtype);
-    } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
-      L.next();
-      value = TensorType::get();
-      alias_info = parseAliasAnnotation();
-    } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
-      L.next();
-      L.expect('(');
-      auto key_type = parseType().first;
-      L.expect(',');
-      auto value_type = parseType().first;
-      L.expect(')');
-      alias_info = parseAliasAnnotation();
-
-      value = DictType::create(key_type, value_type);
-    } else {
-      auto value_alias = parseBaseType();
-      value = value_alias.first;
-      alias_info = value_alias.second;
-    }
-    while (true) {
-      if (L.cur().kind == '[' && L.lookahead().kind == ']') {
-        L.next(); // [
-        L.next(); // ]
-        value = ListType::create(value);
-        auto container = parseAliasAnnotation();
-        if (container && alias_info) {
-          container->addContainedType(std::move(*alias_info));
-        }
-        alias_info = std::move(container);
-      } else if (L.nextIf('?')) {
-        value = OptionalType::create(value);
-      } else {
-        break;
-      }
-    }
-    return std::make_pair(std::move(value), std::move(alias_info));
-  }
 
   Argument parseArgument(size_t idx, bool is_return, bool kwarg_only) {
     Argument result;
-    auto p = parseType();
+    auto p = type_parser.parseType();
     auto type = std::move(p.first);
     auto alias_info = std::move(p.second);
     c10::optional<int32_t> N;
@@ -209,7 +90,7 @@ struct SchemaParser {
       type = ListType::create(type);
       N = std::stoll(L.expect(TK_NUMBER).text());
       L.expect(']');
-      auto container = parseAliasAnnotation();
+      auto container = type_parser.parseAliasAnnotation();
       if (container && alias_info) {
         container->addContainedType(std::move(*alias_info));
       }
@@ -369,9 +250,8 @@ struct SchemaParser {
       L.expect(end);
   }
   Lexer L;
-  size_t next_id = 0;
+  SchemaTypeParser type_parser;
 };
-
 } // namespace script
 
 namespace {
@@ -457,7 +337,8 @@ struct OperatorRegistry {
       return lhs.first > rhs.first;
     };
 
-    std::priority_queue<EntryPair, std::vector<EntryPair>, decltype(cmp)> rankings(cmp);
+    std::priority_queue<EntryPair, std::vector<EntryPair>, decltype(cmp)>
+        rankings(cmp);
     static constexpr size_t MAX_EDIT_DIST = 2u;
     for (const auto& op : operators) {
       auto edit_dist = script::ComputeEditDistance(
@@ -479,7 +360,6 @@ OperatorRegistry& getRegistry() {
   static OperatorRegistry r;
   return r;
 }
-
 } // anonymous namespace
 
 void registerOperator(Operator&& op) {
@@ -641,6 +521,5 @@ Operator* OperatorSet::find(const Node* n) const {
   }
   return nullptr;
 }
-
 } // namespace jit
 } // namespace torch
index 1bda8dd..79e23f5 100644 (file)
@@ -1,3 +1,4 @@
+#include <torch/csrc/jit/script/compiler.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/hooks_for_testing.h>
 #include <torch/csrc/jit/interpreter.h>
@@ -5,11 +6,10 @@
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/lower_tuples.h>
-#include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/script/final_returns.h>
 #include <torch/csrc/jit/script/parser.h>
 #include <torch/csrc/jit/script/schema_matching.h>
-#include <torch/csrc/jit/script/type_parser.h>
+#include <torch/csrc/jit/script/script_type_parser.h>
 #include <torch/csrc/utils/object_ptr.h>
 
 #include <torch/csrc/jit/constants.h>
@@ -2696,7 +2696,6 @@ void lambdaLiftFork(Node* fork_node) {
   fork_node->g_(attr::Subgraph, forked_graph);
   fork_node->eraseBlock(0);
 }
-
 } // namespace script
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/script/schema_type_parser.cpp b/torch/csrc/jit/script/schema_type_parser.cpp
new file mode 100644 (file)
index 0000000..c71c0b6
--- /dev/null
@@ -0,0 +1,204 @@
+#include <torch/csrc/jit/script/schema_type_parser.h>
+#include <ATen/core/interned_strings.h>
+#include <torch/csrc/jit/alias_info.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/lexer.h>
+#include <torch/csrc/jit/script/parse_string_literal.h>
+#include <string>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+TypeAndAlias SchemaTypeParser::parseBaseType() {
+  static std::unordered_map<std::string, TypePtr> type_map = {
+      {"Generator", GeneratorType::get()},
+      {"ScalarType", IntType::get()},
+      {"Layout", IntType::get()},
+      {"Device", DeviceObjType::get()},
+      {"Scalar", NumberType::get()},
+      {"str", StringType::get()},
+      {"float", FloatType::get()},
+      {"int", IntType::get()},
+      {"bool", BoolType::get()},
+  };
+  auto tok = L.expect(TK_IDENT);
+  auto text = tok.text();
+  auto it = type_map.find(text);
+  if (it == type_map.end()) {
+    if (text.size() > 0 && islower(text[0])) {
+      // lower case identifiers that are not otherwise valid types
+      // are treated as type variables
+      return TypeAndAlias(VarType::create(text), parseAliasAnnotation());
+    }
+    throw ErrorReport(tok.range) << "unknown type specifier";
+  }
+  return TypeAndAlias(it->second, c10::nullopt);
+}
+
+// Examples:
+// Tensor(a) // Tensor is in set a
+// Tensor(a!) // it is also written to
+// Tensor!  // shorthand for Tensor(fresh_identifier!)
+// Tensor(a! -> a|b) // Tensor is in set a, written to,
+//                      and after the write is in set a AND b.
+c10::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
+  std::set<Symbol> sets;
+  AliasInfo alias_info;
+  if (L.nextIf('(')) {
+    // optional 'alias set annotation'
+    parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
+      if (L.nextIf('*')) {
+        alias_info = AliasInfo::createWildcard();
+
+        // If we found a wildcard, ignore all subsequent annotations
+      } else if (!alias_info.isWildcard()) {
+        alias_info.addSet(
+            Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
+      }
+    });
+    if (L.nextIf('!')) {
+      alias_info.setIsWrite(true);
+    }
+    L.expect(')');
+  } else if (L.nextIf('!')) {
+    alias_info.addSet(
+        Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
+    alias_info.setIsWrite(true);
+  } else {
+    return c10::nullopt;
+  }
+
+  return alias_info;
+}
+
+c10::optional<at::ScalarType> SchemaTypeParser::parseTensorDType(
+    const std::string& dtype) {
+#define DEFINE_SCALAR_TYPE(_1, n, _2) {#n, at::ScalarType::n},
+
+  static std::unordered_map<std::string, at::ScalarType> type_map = {
+      AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_SCALAR_TYPE)};
+
+  auto type = type_map.find(dtype);
+  if (type != type_map.end()) {
+    return type->second;
+  }
+  return c10::nullopt;
+}
+
+TypePtr SchemaTypeParser::parseRefinedTensor() {
+  auto maybe_dtype = parseTensorDType(L.expect(TK_IDENT).text());
+  AT_ASSERT(maybe_dtype);
+  at::ScalarType dtype = *maybe_dtype;
+  TypePtr ptr;
+  L.expect('(');
+  TypePtr tensor_type;
+  if (L.cur().kind == '*') {
+    size_t num_dims = 0;
+    parseList(TK_NOTHING, ',', ')', [&] {
+      L.expect('*');
+      num_dims++;
+    });
+    ptr = DimensionedTensorType::create(dtype, at::DeviceType::CPU, num_dims);
+  } else {
+    std::vector<int64_t> dims;
+    parseList(TK_NOTHING, ',', ')', [&] {
+      const std::string& num = L.expect(TK_NUMBER).text();
+      std::string::size_type num_len;
+      size_t dim = std::stoi(num, &num_len);
+      AT_ASSERTM(
+          num_len == num.size(),
+          "Bad tensor dimension size. Strides not yet supported in parsing",
+          num);
+      dims.push_back(dim);
+    });
+    at::IntArrayRef dims_ref(dims);
+    ptr =
+        CompleteTensorType::create(dtype, at::DeviceType::CPU, dims_ref, false);
+  }
+  return ptr;
+}
+
+std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
+  TypePtr value;
+  c10::optional<AliasInfo> alias_info;
+  // Tuple type
+  if (L.cur().kind == '(') {
+    std::vector<TypePtr> types;
+    parseList('(', ',', ')', [&] {
+      auto r = parseType();
+      types.push_back(std::move(r.first));
+      if (alias_info && r.second) {
+        alias_info->addContainedType(std::move(*r.second));
+      }
+    });
+    value = TupleType::create(std::move(types));
+  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
+    L.next(); // Future
+    L.expect('(');
+    auto p = parseType();
+    auto subtype = std::move(p.first);
+    auto subalias = std::move(p.second);
+    L.expect(')');
+    value = FutureType::create(subtype);
+  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
+    L.next();
+    value = TensorType::get();
+    alias_info = parseAliasAnnotation();
+  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
+    L.next();
+    L.expect('(');
+    auto key_type = parseType().first;
+    L.expect(',');
+    auto value_type = parseType().first;
+    L.expect(')');
+    alias_info = parseAliasAnnotation();
+    value = DictType::create(key_type, value_type);
+  } else if (
+      complete_tensor_types && L.cur().kind == TK_IDENT &&
+      parseTensorDType(L.cur().text())) {
+    value = parseRefinedTensor();
+    alias_info = parseAliasAnnotation();
+  } else {
+    auto value_alias = parseBaseType();
+    value = value_alias.first;
+    alias_info = value_alias.second;
+  }
+  while (true) {
+    if (L.cur().kind == '[' && L.lookahead().kind == ']') {
+      L.next(); // [
+      L.next(); // ]
+      value = ListType::create(value);
+      auto container = parseAliasAnnotation();
+      if (container && alias_info) {
+        container->addContainedType(std::move(*alias_info));
+      }
+      alias_info = std::move(container);
+    } else if (L.nextIf('?')) {
+      value = OptionalType::create(value);
+    } else {
+      break;
+    }
+  }
+  return std::make_pair(std::move(value), std::move(alias_info));
+}
+
+void SchemaTypeParser::parseList(
+    int begin,
+    int sep,
+    int end,
+    const std::function<void()>& callback) {
+  auto r = L.cur().range;
+  if (begin != TK_NOTHING)
+    L.expect(begin);
+  if (L.cur().kind != end) {
+    do {
+      callback();
+    } while (L.nextIf(sep));
+  }
+  if (end != TK_NOTHING)
+    L.expect(end);
+}
+} // namespace script
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/schema_type_parser.h b/torch/csrc/jit/script/schema_type_parser.h
new file mode 100644 (file)
index 0000000..5677626
--- /dev/null
@@ -0,0 +1,37 @@
+#include <ATen/ATen.h>
+#include <ATen/core/jit_type.h>
+#include <torch/csrc/jit/alias_info.h>
+#include <torch/csrc/jit/script/lexer.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+using TypePtr = c10::TypePtr;
+using TypeAndAlias = std::pair<TypePtr, c10::optional<AliasInfo>>;
+
+struct SchemaTypeParser {
+  TypeAndAlias parseBaseType();
+  c10::optional<AliasInfo> parseAliasAnnotation();
+  std::pair<TypePtr, c10::optional<AliasInfo>> parseType();
+  c10::optional<at::ScalarType> parseTensorDType(const std::string& dtype);
+  TypePtr parseRefinedTensor();
+
+  SchemaTypeParser(Lexer& L, bool parse_complete_tensor_types) : L(L) {
+    complete_tensor_types = parse_complete_tensor_types;
+  }
+
+ private:
+  void parseList(
+      int begin,
+      int sep,
+      int end,
+      const std::function<void()>& callback);
+
+  bool complete_tensor_types;
+  Lexer& L;
+  size_t next_id = 0;
+};
+} // namespace script
+} // namespace jit
+} // namespace torch
similarity index 90%
rename from torch/csrc/jit/script/type_parser.cpp
rename to torch/csrc/jit/script/script_type_parser.cpp
index fb5726f..aadc0a7 100644 (file)
@@ -1,6 +1,6 @@
+#include <torch/csrc/jit/script/script_type_parser.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/script/tree_views.h>
-#include <torch/csrc/jit/script/type_parser.h>
 
 namespace torch {
 namespace jit {
@@ -67,19 +67,18 @@ subscript_to_type_fns() {
                  parseTypeFromExpr(*subscript.subscript_exprs().begin());
              return FutureType::create(elem_type);
            }},
-           {"Dict",
-            [](Subscript subscript) -> TypePtr {
-              if (subscript.subscript_exprs().size() != 2) {
-                throw ErrorReport(subscript)
-                    << " expected exactly 2 element types but found "
-                    << subscript.subscript_exprs().size();
-              }
-              auto key_type =
-                  parseTypeFromExpr(subscript.subscript_exprs()[0]);
-              auto value_type =
-                  parseTypeFromExpr(subscript.subscript_exprs()[1]);
-              return DictType::create(key_type, value_type);
-            }},
+          {"Dict",
+           [](Subscript subscript) -> TypePtr {
+             if (subscript.subscript_exprs().size() != 2) {
+               throw ErrorReport(subscript)
+                   << " expected exactly 2 element types but found "
+                   << subscript.subscript_exprs().size();
+             }
+             auto key_type = parseTypeFromExpr(subscript.subscript_exprs()[0]);
+             auto value_type =
+                 parseTypeFromExpr(subscript.subscript_exprs()[1]);
+             return DictType::create(key_type, value_type);
+           }},
       };
   return map;
 }
index 282d2fc..1ebb98e 100644 (file)
@@ -1,8 +1,7 @@
+#include <torch/csrc/jit/script/sugared_value.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/script/schema_matching.h>
-#include <torch/csrc/jit/script/sugared_value.h>
 #include <torch/csrc/jit/script/tree_views.h>
-#include <torch/csrc/jit/script/type_parser.h>
 
 namespace torch {
 namespace jit {
@@ -142,7 +141,6 @@ std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
   throw ErrorReport(loc) << value->type()->str()
                          << " cannot be used as a tuple";
 }
-
 } // namespace script
 } // namespace jit
 } // namespace torch