return (%7)
)IR");
}
+
+ {
+ checkRoundtrip(
+ R"IR(
+graph(%0 : Tensor,
+ %1 : Tensor,
+ %2 : Tensor):
+ %3 : int? = prim::Constant()
+ return (%3)
+)IR");
+ }
+
+ {
+ checkRoundtrip(
+ R"IR(
+graph(%0 : Tensor,
+ %1 : Tensor,
+ %2 : Tensor):
+ %3 : Float(*, *, *) = prim::Constant()
+ return (%3)
+)IR");
+ }
+
+ {
+ checkRoundtrip(
+ R"IR(
+graph(%0 : Tensor,
+ %1 : Tensor,
+ %2 : Tensor):
+ %3 : Long() = prim::Constant()
+ return (%3)
+)IR");
+ }
+
+ {
+ checkRoundtrip(
+ R"IR(
+graph(%0 : Tensor,
+ %1 : Tensor,
+ %2 : Tensor):
+ %3 : Double(4, 4, 5) = prim::Constant()
+ return (%3)
+)IR");
+ }
+
+ {
+ bool error_thrown = false;
+ try {
+ checkRoundtrip(
+ R"IR(
+graph(%0 : Tensor,
+ %1 : Tensor,
+ %2 : Tensor):
+ %3 : Double(4!, 4, 5) = prim::Constant()
+ return (%3)
+)IR");
+ } catch (const std::exception& error) {
+ error_thrown = true;
+ }
+ AT_ASSERT(error_thrown);
+ }
}
} // namespace jit
} // namespace torch
"torch/csrc/jit/script/compiler.cpp",
"torch/csrc/jit/script/edit_distance.cpp",
"torch/csrc/jit/script/final_returns.cpp",
- "torch/csrc/jit/script/type_parser.cpp",
+ "torch/csrc/jit/script/schema_type_parser.cpp",
+ "torch/csrc/jit/script/script_type_parser.cpp",
"torch/csrc/jit/script/sugared_value.cpp",
"torch/csrc/jit/script/schema_matching.cpp",
"torch/csrc/jit/script/parser.cpp",
${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp
${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp
${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp
- ${TORCH_SRC_DIR}/csrc/jit/script/type_parser.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/script/schema_type_parser.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/script/script_type_parser.cpp
${TORCH_SRC_DIR}/csrc/jit/script/sugared_value.cpp
${TORCH_SRC_DIR}/csrc/jit/script/parser.cpp
${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/lexer.h>
#include <torch/csrc/jit/script/parse_string_literal.h>
+#include <torch/csrc/jit/script/schema_type_parser.h>
#include <string>
#include <vector>
class IRParser {
friend void parseIR(const std::string& str, torch::jit::Graph* graph);
IRParser(const std::string& str, torch::jit::Graph* graph)
- : L(str), g(graph) {}
+ : L(str),
+ g(graph),
+ type_parser(L, /*parse_complete_tensor_types*/ true) {}
std::string parseVar();
VarWithType parseVarWithType();
torch::jit::script::Lexer L;
torch::jit::Graph* g = nullptr;
std::unordered_map<std::string, Value*> vmap;
+ SchemaTypeParser type_parser;
};
struct ParsedLiteral {
struct VarWithType {
VarWithType() = default;
std::string name;
- std::string type;
+ TypePtr type;
};
void parseIR(const std::string& str, torch::jit::Graph* graph) {
p.parse();
}
-TypePtr parseType(const std::string& s) {
- if (s == "Tensor") {
- return TensorType::get();
- }
- if (s == "int") {
- return IntType::get();
- }
- if (s == "float") {
- return FloatType::get();
- }
- if (s == "string") {
- return StringType::get();
- }
- // TODO: Support other types.
- AT_ASSERTM(false, "Type not supported by parser:", s);
-}
-
VarWithType IRParser::parseVarWithType() {
L.expect('%');
VarWithType r;
} else {
r.name = L.expect(TK_NUMBER).text();
}
- r.type = "Tensor";
+ r.type = TensorType::get();
if (L.nextIf(':')) {
- r.type = L.expect(TK_IDENT).text();
+ auto type_alias = type_parser.parseType();
+ AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
+ r.type = type_alias.first;
}
return r;
}
// If the name isn't valid, don't use it
std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
vmap[v.name] = b->addInput(uniq_name);
- vmap[v.name]->setType(parseType(v.type));
+ vmap[v.name]->setType(v.type);
});
}
int idx = 0;
for (const VarWithType& v : outs) {
vmap[v.name] = n->outputs()[idx++];
- vmap[v.name]->setType(parseType(v.type));
+ vmap[v.name]->setType(v.type);
}
// Insert the new node into block B.
// If the name isn't valid, don't use it
std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
vmap[v.name] = g->addInput(uniq_name);
- vmap[v.name]->setType(parseType(v.type));
+ vmap[v.name]->setType(v.type);
});
}
L.expect(end);
}
}
-
} // namespace script
} // namespace jit
} // namespace torch
+#include <torch/csrc/jit/operator.h>
#include <ATen/ATen.h>
#include <torch/csrc/jit/alias_info.h>
-#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
+#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/script/edit_distance.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/script/lexer.h>
#include <torch/csrc/jit/script/parse_string_literal.h>
+#include <torch/csrc/jit/script/schema_type_parser.h>
#include <torch/csrc/jit/script/tree.h>
-#include <torch/csrc/jit/script/edit_distance.h>
#include <functional>
#include <memory>
namespace script {
struct SchemaParser {
- SchemaParser(const std::string& str) : L(str) {}
+ SchemaParser(const std::string& str)
+ : L(str), type_parser(L, /*parse_complete_tensor_types*/ false) {}
FunctionSchema parseDeclaration() {
auto name = L.expect(TK_IDENT).text();
TreeRef parseIdent() {
return String::create(L.expect(TK_IDENT).text());
}
- using TypeAndAlias = std::pair<TypePtr, c10::optional<AliasInfo>>;
- TypeAndAlias parseBaseType() {
- static std::unordered_map<std::string, TypePtr> type_map = {
- {"Generator", GeneratorType::get()},
- {"ScalarType", IntType::get()},
- {"Layout", IntType::get()},
- {"Device", DeviceObjType::get()},
- {"Scalar", NumberType::get()},
- {"str", StringType::get()},
- {"float", FloatType::get()},
- {"int", IntType::get()},
- {"bool", BoolType::get()},
- };
- auto tok = L.expect(TK_IDENT);
- auto text = tok.text();
- auto it = type_map.find(text);
- if (it == type_map.end()) {
- if (text.size() > 0 && islower(text[0])) {
- // lower case identifiers that are not otherwise valid types
- // are treated as type variables
- return TypeAndAlias(VarType::create(text), parseAliasAnnotation());
- }
- throw ErrorReport(tok.range) << "unknown type specifier";
- }
- return TypeAndAlias(it->second, c10::nullopt);
- }
- // Examples:
- // Tensor(a) // Tensor is in set a
- // Tensor(a!) // it is also written to
- // Tensor! // shorthand for Tensor(fresh_identifier!)
- // Tensor(a! -> a|b) // Tensor is in set a, written to,
- // and after the write is in set a AND b.
- c10::optional<AliasInfo> parseAliasAnnotation() {
- std::set<Symbol> sets;
- AliasInfo alias_info;
- if (L.nextIf('(')) {
- // optional 'alias set annotation'
- parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
- if (L.nextIf('*')) {
- alias_info = AliasInfo::createWildcard();
-
- // If we found a wildcard, ignore all subsequent annotations
- } else if (!alias_info.isWildcard()) {
- alias_info.addSet(
- Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
- }
- });
- if (L.nextIf('!')) {
- alias_info.setIsWrite(true);
- }
- L.expect(')');
- } else if (L.nextIf('!')) {
- alias_info.addSet(
- Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
- alias_info.setIsWrite(true);
- } else {
- return c10::nullopt;
- }
-
- return alias_info;
- }
-
- std::pair<TypePtr, c10::optional<AliasInfo>> parseType() {
- TypePtr value;
- c10::optional<AliasInfo> alias_info;
- // Tuple type
- if (L.cur().kind == '(') {
- std::vector<TypePtr> types;
- parseList('(', ',', ')', [&] {
- auto r = parseType();
- types.push_back(std::move(r.first));
- if (alias_info && r.second) {
- alias_info->addContainedType(std::move(*r.second));
- }
- });
- value = TupleType::create(std::move(types));
- } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
- L.next(); // Future
- L.expect('(');
- auto p = parseType();
- auto subtype = std::move(p.first);
- auto subalias = std::move(p.second);
- L.expect(')');
- value = FutureType::create(subtype);
- } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
- L.next();
- value = TensorType::get();
- alias_info = parseAliasAnnotation();
- } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
- L.next();
- L.expect('(');
- auto key_type = parseType().first;
- L.expect(',');
- auto value_type = parseType().first;
- L.expect(')');
- alias_info = parseAliasAnnotation();
-
- value = DictType::create(key_type, value_type);
- } else {
- auto value_alias = parseBaseType();
- value = value_alias.first;
- alias_info = value_alias.second;
- }
- while (true) {
- if (L.cur().kind == '[' && L.lookahead().kind == ']') {
- L.next(); // [
- L.next(); // ]
- value = ListType::create(value);
- auto container = parseAliasAnnotation();
- if (container && alias_info) {
- container->addContainedType(std::move(*alias_info));
- }
- alias_info = std::move(container);
- } else if (L.nextIf('?')) {
- value = OptionalType::create(value);
- } else {
- break;
- }
- }
- return std::make_pair(std::move(value), std::move(alias_info));
- }
Argument parseArgument(size_t idx, bool is_return, bool kwarg_only) {
Argument result;
- auto p = parseType();
+ auto p = type_parser.parseType();
auto type = std::move(p.first);
auto alias_info = std::move(p.second);
c10::optional<int32_t> N;
type = ListType::create(type);
N = std::stoll(L.expect(TK_NUMBER).text());
L.expect(']');
- auto container = parseAliasAnnotation();
+ auto container = type_parser.parseAliasAnnotation();
if (container && alias_info) {
container->addContainedType(std::move(*alias_info));
}
L.expect(end);
}
Lexer L;
- size_t next_id = 0;
+ SchemaTypeParser type_parser;
};
-
} // namespace script
namespace {
return lhs.first > rhs.first;
};
- std::priority_queue<EntryPair, std::vector<EntryPair>, decltype(cmp)> rankings(cmp);
+ std::priority_queue<EntryPair, std::vector<EntryPair>, decltype(cmp)>
+ rankings(cmp);
static constexpr size_t MAX_EDIT_DIST = 2u;
for (const auto& op : operators) {
auto edit_dist = script::ComputeEditDistance(
static OperatorRegistry r;
return r;
}
-
} // anonymous namespace
void registerOperator(Operator&& op) {
}
return nullptr;
}
-
} // namespace jit
} // namespace torch
+#include <torch/csrc/jit/script/compiler.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/hooks_for_testing.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
-#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/final_returns.h>
#include <torch/csrc/jit/script/parser.h>
#include <torch/csrc/jit/script/schema_matching.h>
-#include <torch/csrc/jit/script/type_parser.h>
+#include <torch/csrc/jit/script/script_type_parser.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/jit/constants.h>
fork_node->g_(attr::Subgraph, forked_graph);
fork_node->eraseBlock(0);
}
-
} // namespace script
} // namespace jit
} // namespace torch
--- /dev/null
+#include <torch/csrc/jit/script/schema_type_parser.h>
+#include <ATen/core/interned_strings.h>
+#include <torch/csrc/jit/alias_info.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/lexer.h>
+#include <torch/csrc/jit/script/parse_string_literal.h>
+#include <string>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+// Parses a single non-structural type name (e.g. "int", "Scalar", "Device")
+// into a TypePtr. Unrecognized lowercase identifiers are treated as type
+// variables; anything else is a parse error.
+TypeAndAlias SchemaTypeParser::parseBaseType() {
+  // Fixed mapping of schema identifiers to JIT types. ScalarType and Layout
+  // are represented as plain ints at the JIT level.
+  static std::unordered_map<std::string, TypePtr> type_map = {
+      {"Generator", GeneratorType::get()},
+      {"ScalarType", IntType::get()},
+      {"Layout", IntType::get()},
+      {"Device", DeviceObjType::get()},
+      {"Scalar", NumberType::get()},
+      {"str", StringType::get()},
+      {"float", FloatType::get()},
+      {"int", IntType::get()},
+      {"bool", BoolType::get()},
+  };
+  auto tok = L.expect(TK_IDENT);
+  auto text = tok.text();
+  auto it = type_map.find(text);
+  if (it == type_map.end()) {
+    if (text.size() > 0 && islower(text[0])) {
+      // lower case identifiers that are not otherwise valid types
+      // are treated as type variables
+      return TypeAndAlias(VarType::create(text), parseAliasAnnotation());
+    }
+    throw ErrorReport(tok.range) << "unknown type specifier";
+  }
+  // Known base types never carry alias info at this point.
+  return TypeAndAlias(it->second, c10::nullopt);
+}
+
+// Examples:
+// Tensor(a) // Tensor is in set a
+// Tensor(a!) // it is also written to
+// Tensor! // shorthand for Tensor(fresh_identifier!)
+// Tensor(a! -> a|b) // Tensor is in set a, written to,
+// and after the write is in set a AND b.
+c10::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
+  // NOTE(review): 'sets' appears unused in this function — candidate for
+  // removal in a follow-up.
+  std::set<Symbol> sets;
+  AliasInfo alias_info;
+  if (L.nextIf('(')) {
+    // optional 'alias set annotation'
+    parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
+      if (L.nextIf('*')) {
+        alias_info = AliasInfo::createWildcard();
+
+        // If we found a wildcard, ignore all subsequent annotations
+      } else if (!alias_info.isWildcard()) {
+        // Each identifier names an alias set, e.g. "alias::a".
+        alias_info.addSet(
+            Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
+      }
+    });
+    // Trailing '!' inside the parens marks the value as written-to.
+    if (L.nextIf('!')) {
+      alias_info.setIsWrite(true);
+    }
+    L.expect(')');
+  } else if (L.nextIf('!')) {
+    // Bare '!' shorthand: written-to member of a fresh, uniquely named set.
+    alias_info.addSet(
+        Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
+    alias_info.setIsWrite(true);
+  } else {
+    // No annotation present.
+    return c10::nullopt;
+  }
+
+  return alias_info;
+}
+
+// Maps a dtype name (e.g. "Float", "Long") to its at::ScalarType, or
+// nullopt if the string is not a known scalar type. Used both to detect
+// refined tensor types and to decode them.
+c10::optional<at::ScalarType> SchemaTypeParser::parseTensorDType(
+    const std::string& dtype) {
+// Expands to one {"Name", at::ScalarType::Name} entry per scalar type.
+#define DEFINE_SCALAR_TYPE(_1, n, _2) {#n, at::ScalarType::n},
+
+  static std::unordered_map<std::string, at::ScalarType> type_map = {
+      AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_SCALAR_TYPE)};
+
+  auto type = type_map.find(dtype);
+  if (type != type_map.end()) {
+    return type->second;
+  }
+  return c10::nullopt;
+}
+
+// Parses a refined tensor type such as "Float(*, *, *)" (dimensioned only)
+// or "Double(4, 4, 5)" (complete sizes). The caller is expected to have
+// verified the leading identifier is a valid dtype (hence the AT_ASSERT).
+// NOTE(review): the device is hard-coded to CPU here — confirm that is the
+// intended default for parsed IR.
+TypePtr SchemaTypeParser::parseRefinedTensor() {
+  auto maybe_dtype = parseTensorDType(L.expect(TK_IDENT).text());
+  AT_ASSERT(maybe_dtype);
+  at::ScalarType dtype = *maybe_dtype;
+  TypePtr ptr;
+  L.expect('(');
+  // NOTE(review): 'tensor_type' appears unused — candidate for removal.
+  TypePtr tensor_type;
+  if (L.cur().kind == '*') {
+    // "Float(*, *, *)": only the rank is known; count the '*' entries.
+    size_t num_dims = 0;
+    parseList(TK_NOTHING, ',', ')', [&] {
+      L.expect('*');
+      num_dims++;
+    });
+    ptr = DimensionedTensorType::create(dtype, at::DeviceType::CPU, num_dims);
+  } else {
+    // "Double(4, 4, 5)": concrete sizes for every dimension.
+    std::vector<int64_t> dims;
+    parseList(TK_NOTHING, ',', ')', [&] {
+      const std::string& num = L.expect(TK_NUMBER).text();
+      std::string::size_type num_len;
+      size_t dim = std::stoi(num, &num_len);
+      // stoi stops at the first non-digit; any trailing characters (e.g. the
+      // '!' in "4!" stride syntax) make num_len < num.size() and fail here.
+      AT_ASSERTM(
+          num_len == num.size(),
+          "Bad tensor dimension size. Strides not yet supported in parsing",
+          num);
+      dims.push_back(dim);
+    });
+    at::IntArrayRef dims_ref(dims);
+    ptr =
+        CompleteTensorType::create(dtype, at::DeviceType::CPU, dims_ref, false);
+  }
+  return ptr;
+}
+
+// Parses a full type expression, returning the type plus any alias
+// annotation attached to it. Handles tuples, Future(...), Tensor (with
+// optional alias set), Dict(K, V), refined tensor types such as
+// "Float(*, *)" (only when complete_tensor_types is enabled), base types /
+// type variables, and the trailing "[]" (list) and "?" (optional) suffixes.
+std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
+  TypePtr value;
+  c10::optional<AliasInfo> alias_info;
+  // Tuple type
+  if (L.cur().kind == '(') {
+    std::vector<TypePtr> types;
+    parseList('(', ',', ')', [&] {
+      auto r = parseType();
+      types.push_back(std::move(r.first));
+      // NOTE(review): alias_info is never set before this check, so element
+      // alias annotations inside a tuple are currently dropped — confirm
+      // whether that is intentional.
+      if (alias_info && r.second) {
+        alias_info->addContainedType(std::move(*r.second));
+      }
+    });
+    value = TupleType::create(std::move(types));
+  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
+    L.next(); // Future
+    L.expect('(');
+    auto p = parseType();
+    auto subtype = std::move(p.first);
+    // NOTE(review): the subtype's alias info is parsed but discarded here.
+    auto subalias = std::move(p.second);
+    L.expect(')');
+    value = FutureType::create(subtype);
+  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
+    L.next();
+    value = TensorType::get();
+    alias_info = parseAliasAnnotation();
+  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
+    L.next();
+    L.expect('(');
+    auto key_type = parseType().first;
+    L.expect(',');
+    auto value_type = parseType().first;
+    L.expect(')');
+    alias_info = parseAliasAnnotation();
+    value = DictType::create(key_type, value_type);
+  } else if (
+      // Refined tensor types ("Float(4, 4)") are only recognized when the
+      // parser was constructed with parse_complete_tensor_types == true
+      // (IR parsing); schema parsing leaves them off.
+      complete_tensor_types && L.cur().kind == TK_IDENT &&
+      parseTensorDType(L.cur().text())) {
+    value = parseRefinedTensor();
+    alias_info = parseAliasAnnotation();
+  } else {
+    auto value_alias = parseBaseType();
+    value = value_alias.first;
+    alias_info = value_alias.second;
+  }
+  // Apply any number of trailing "[]" / "?" suffixes, innermost first.
+  while (true) {
+    if (L.cur().kind == '[' && L.lookahead().kind == ']') {
+      L.next(); // [
+      L.next(); // ]
+      value = ListType::create(value);
+      // The list's own annotation wraps the element's annotation.
+      auto container = parseAliasAnnotation();
+      if (container && alias_info) {
+        container->addContainedType(std::move(*alias_info));
+      }
+      alias_info = std::move(container);
+    } else if (L.nextIf('?')) {
+      value = OptionalType::create(value);
+    } else {
+      break;
+    }
+  }
+  return std::make_pair(std::move(value), std::move(alias_info));
+}
+
+// Generic delimiter-separated list helper: consumes 'begin', then invokes
+// 'callback' for each element separated by 'sep', then consumes 'end'.
+// Pass TK_NOTHING for 'begin'/'end' to skip consuming that delimiter
+// (the element loop still stops when the current token equals 'end').
+void SchemaTypeParser::parseList(
+    int begin,
+    int sep,
+    int end,
+    const std::function<void()>& callback) {
+  // NOTE(review): 'r' is captured for error reporting but never used here.
+  auto r = L.cur().range;
+  if (begin != TK_NOTHING)
+    L.expect(begin);
+  if (L.cur().kind != end) {
+    do {
+      callback();
+    } while (L.nextIf(sep));
+  }
+  if (end != TK_NOTHING)
+    L.expect(end);
+}
+} // namespace script
+} // namespace jit
+} // namespace torch
--- /dev/null
+#include <ATen/ATen.h>
+#include <ATen/core/jit_type.h>
+#include <torch/csrc/jit/alias_info.h>
+#include <torch/csrc/jit/script/lexer.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+using TypePtr = c10::TypePtr;
+using TypeAndAlias = std::pair<TypePtr, c10::optional<AliasInfo>>;
+
+// Shared type-expression parser used by both the function-schema parser and
+// the IR parser. Operates on a caller-owned Lexer (held by reference — the
+// lexer must outlive this object).
+struct SchemaTypeParser {
+  // Parses a simple named type (or type variable) with optional alias info.
+  TypeAndAlias parseBaseType();
+  // Parses an optional "(a!)"-style alias annotation; nullopt if absent.
+  c10::optional<AliasInfo> parseAliasAnnotation();
+  // Parses a full type expression (tuples, Future, Dict, lists, optionals).
+  std::pair<TypePtr, c10::optional<AliasInfo>> parseType();
+  // Maps a dtype name ("Float", ...) to its ScalarType; nullopt if unknown.
+  c10::optional<at::ScalarType> parseTensorDType(const std::string& dtype);
+  // Parses a refined tensor type like "Float(*, *)" or "Double(4, 4)".
+  TypePtr parseRefinedTensor();
+
+  // parse_complete_tensor_types: enable refined tensor syntax (IR parsing
+  // turns this on; schema parsing turns it off).
+  SchemaTypeParser(Lexer& L, bool parse_complete_tensor_types) : L(L) {
+    complete_tensor_types = parse_complete_tensor_types;
+  }
+
+ private:
+  void parseList(
+      int begin,
+      int sep,
+      int end,
+      const std::function<void()>& callback);
+
+  bool complete_tensor_types;
+  Lexer& L;
+  // Counter for generating fresh alias-set names for the bare "!" shorthand.
+  size_t next_id = 0;
+};
+} // namespace script
+} // namespace jit
+} // namespace torch
+#include <torch/csrc/jit/script/script_type_parser.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/tree_views.h>
-#include <torch/csrc/jit/script/type_parser.h>
namespace torch {
namespace jit {
parseTypeFromExpr(*subscript.subscript_exprs().begin());
return FutureType::create(elem_type);
}},
- {"Dict",
- [](Subscript subscript) -> TypePtr {
- if (subscript.subscript_exprs().size() != 2) {
- throw ErrorReport(subscript)
- << " expected exactly 2 element types but found "
- << subscript.subscript_exprs().size();
- }
- auto key_type =
- parseTypeFromExpr(subscript.subscript_exprs()[0]);
- auto value_type =
- parseTypeFromExpr(subscript.subscript_exprs()[1]);
- return DictType::create(key_type, value_type);
- }},
+ {"Dict",
+ [](Subscript subscript) -> TypePtr {
+ if (subscript.subscript_exprs().size() != 2) {
+ throw ErrorReport(subscript)
+ << " expected exactly 2 element types but found "
+ << subscript.subscript_exprs().size();
+ }
+ auto key_type = parseTypeFromExpr(subscript.subscript_exprs()[0]);
+ auto value_type =
+ parseTypeFromExpr(subscript.subscript_exprs()[1]);
+ return DictType::create(key_type, value_type);
+ }},
};
return map;
}
+#include <torch/csrc/jit/script/sugared_value.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/schema_matching.h>
-#include <torch/csrc/jit/script/sugared_value.h>
#include <torch/csrc/jit/script/tree_views.h>
-#include <torch/csrc/jit/script/type_parser.h>
namespace torch {
namespace jit {
throw ErrorReport(loc) << value->type()->str()
<< " cannot be used as a tuple";
}
-
} // namespace script
} // namespace jit
} // namespace torch