#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/schema_matching.h>
+#include <torch/csrc/jit/script/final_returns.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/script/type_parser.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/interpreter.h>
static bool meaningfulName(const std::string& name) {
if (name.size() == 0)
return false;
+ if (name[0] == '$')
+ return false;
if (name[0] != '_')
return true;
for (size_t i = 1; i < name.size(); ++i) {
type->isSubtypeOf(NumberType::get());
}
-c10::optional<std::string> parseBaseTypeName(const Expr& expr);
-TypePtr parseTypeFromExpr(const Expr& expr);
-c10::optional<std::pair<TypePtr, int32_t>> handleBroadcastList(const Expr& expr);
+// Information for each def being emitted.
+// Defs can be nested to support closures so we need a stack of this information
+// Currently records information about the functions return type.
+struct DefContext {
+ TypePtr declared_return_type_; // nullptr if not annotated
+ TypePtr merged_return_type_; // nullptr if a Return has not been seen yet
+};
struct to_ir {
to_ir(
- Def def,
+ const Def& def,
Resolver resolver_,
const SugaredValuePtr& self,
Method& method) // method being constructed
, resolver(std::move(resolver_))
, environment_stack(nullptr) {
JIT_ASSERT(resolver);
- pushFrame(graph->block());
+ pushFrame(graph->block(), /*starts_def=*/true);
// Type annotations exclude explicitly typing the "self" parameter, so in the
// case that this is a method with self we expect one fewer parameter annotation
// Singly-linked list of environments. This top element contains a member
// `next` that points to the most immediate enclosing scope's value.
std::shared_ptr<Environment> environment_stack;
+ std::vector<DefContext> def_stack_;
- void pushFrame(Block * b) {
+ void pushFrame(Block * b, bool starts_def=false) {
+ if (starts_def) {
+ def_stack_.emplace_back();
+ }
environment_stack = std::make_shared<Environment>(method, resolver, b, environment_stack);
}
- std::shared_ptr<Environment> popFrame() {
+ std::shared_ptr<Environment> popFrame(bool ends_def=false) {
auto old_frame = environment_stack;
environment_stack = environment_stack->next;
+ if(ends_def) {
+ def_stack_.pop_back();
+ }
return old_frame;
}
FunctionSchema emitDef(const Def& def, const SugaredValuePtr& self, Block* block) {
auto schema = extractSchemaFromDef(def, self);
+ if (schema.returns().size() == 1) {
+ def_stack_.back().declared_return_type_ = schema.returns().at(0).type();
+ }
std::vector<Argument> arguments = emitFormalArguments(def, self, schema, block);
+
// body
- auto stmts = def.statements();
- auto stmts_begin = stmts.begin();
- auto stmts_end = stmts.end();
- c10::optional<Return> return_stmt;
- if (stmts_begin != stmts_end && (*std::prev(stmts_end)).kind() == TK_RETURN) {
- --stmts_end;
- return_stmt = Return(*stmts_end);
- }
- emitStatements(stmts_begin, stmts_end);
- const SourceRange& range = return_stmt ? return_stmt->range() : def.range();
- std::vector<Argument> returns = {emitReturn(range, return_stmt, schema, block)};
+ auto stmts_list = moveAllReturnsToEnd(def.statements());
+ emitStatements(stmts_list.begin(), stmts_list.end());
+ std::vector<Argument> returns = {emitOutput(def.range(), schema, block)};
return {def.name().name(), std::move(arguments), std::move(returns)};
}
c10::optional<int32_t> N;
//BroadcastList list can only appear at the argument level
- if (auto maybe_broad_list = handleBroadcastList(decl_arg.type())) {
+ if (auto maybe_broad_list = parseBroadcastList(decl_arg.type())) {
type = maybe_broad_list->first;
N = maybe_broad_list->second;
} else {
if(!decl.return_type().present())
return {};
- if (handleBroadcastList(decl.return_type().get()))
+ if (parseBroadcastList(decl.return_type().get()))
throw ErrorReport(decl.return_type().range()) << "Broadcastable lists cannot appear as a return type";
auto parsed_type = parseTypeFromExpr(decl.return_type().get());
return {Argument(
return arguments;
}
- Argument emitReturn(const SourceRange& range, c10::optional<Return> return_stmt, const FunctionSchema& schema, Block* block) {
- JIT_ASSERT(schema.returns().size() <= 1);
+ Argument emitOutput(const SourceRange& range, const FunctionSchema& schema, Block* block) {
+ // rewrites ensure there is always a return statement in program
+ JIT_ASSERT(def_stack_.back().merged_return_type_);
// outputs
- Value* result = return_stmt ? emitExpr(return_stmt->expr())
- : graph->insertConstant(IValue(), range);
- TypePtr result_type = schema.returns().size() > 0
- ? schema.returns().at(0).type()
- : result->type();
-
- if (return_stmt) {
- result = tryConvertToType(
- range, *graph, result_type, result, /*allow_conversions=*/true);
- }
-
- if (!result->type()->isSubtypeOf(result_type)) {
- throw ErrorReport(range) << "Return value was annotated as having type " << result_type->python_str()
- << " but is actually of type " << result->type()->python_str();
- }
+ Value* result = environment_stack->getVar("$return", range);
block->registerOutput(result);
- return Argument("", result_type);
+ return Argument("", def_stack_.back().merged_return_type_);
}
+
void emitStatements(const List<Stmt>& statements) {
return emitStatements(statements.begin(), statements.end());
}
Block* block = closure_node->addBlock();
{
WithInsertPoint guard(block);
- pushFrame(block);
+ pushFrame(block, /*starts_def=*/true);
emitDef(def, nullptr, block); //ignore schema return, we just wont use it for now since we never create a Method for the closure
- popFrame();
+ popFrame(/*ends_def=*/true);
}
std::shared_ptr<Graph> subgraph;
Value* context;
auto tup = graph->insertNode(graph->createTuple({closure_node->output(), context}))->output();
environment_stack->setVar(def.name().range(), def.name().name(), tup);
}
+
+ void emitReturn(const Return& stmt) {
+ Value* result = emitExpr(stmt.expr());
+ TypePtr result_type = def_stack_.back().declared_return_type_;
+ // result type is annotated, every return must convert to that type
+ if (result_type) {
+ // this guard skips implicit conversion from None -> Tensor for the return type.
+ // otherwise forgetting a return a function returning a tensor will cause a None to be
+ // converted to a tensor.
+ if (!(result_type->isSubtypeOf(DynamicType::get()) && result->type()->isSubtypeOf(NoneType::get()))) {
+ result = tryConvertToType(
+ stmt.range(), *graph, result_type, result, /*allow_conversions=*/true);
+ }
+
+ if (!result->type()->isSubtypeOf(result_type)) {
+ throw ErrorReport(stmt.range()) << "Return value was annotated as having type " << result_type->python_str()
+ << " but is actually of type " << result->type()->python_str();
+ }
+ } else {
+ result_type = def_stack_.back().merged_return_type_;
+ if (!result_type) {
+ result_type = result->type();
+ }
+ if(!unifyTypes(result_type, result->type())) {
+ throw ErrorReport(stmt.range())
+ << "Previous return statement returned a value of type "
+ << result_type->python_str()
+ << " but this return statement returns a value of type "
+ << result->type()->python_str();
+ }
+ }
+ JIT_ASSERT(result_type);
+ def_stack_.back().merged_return_type_ = result_type;
+ environment_stack->setVar(stmt.range(), "$return", result);
+ }
+
void emitStatements(List<Stmt>::const_iterator begin, List<Stmt>::const_iterator end) {
for (; begin != end; ++begin) {
auto stmt = *begin;
case TK_ASSERT:
emitAssert(Assert(stmt));
break;
- case TK_RETURN:
- throw ErrorReport(stmt) << "return statements can appear only at the end "
- << "of the function body";
- break;
+ case TK_RETURN: {
+ emitReturn(Return(stmt));
+ } break;
case TK_PASS:
// Emit nothing for pass
break;
emit_if_expr(true_block, std::move(true_expr));
emit_if_expr(false_block, std::move(false_expr));
- auto true_type = unshapedType(true_block->outputs().at(0)->type());
- auto false_type = unshapedType(false_block->outputs().at(0)->type());
- if (*true_type != *false_type) {
+ auto true_type = true_block->outputs().at(0)->type();
+ auto false_type = false_block->outputs().at(0)->type();
+ auto unified = unifyTypes(true_type, false_type);
+ if (!unified) {
throw ErrorReport(range)
<< "if-expression's true branch has type " << true_type->str()
<< " but false branch has type " << false_type->str();
}
// Add op outputs
- auto expr_value = n->addOutput()->setType(true_type); // Resulting value
+ auto expr_value = n->addOutput()->setType(*unified); // Resulting value
return expr_value;
}
}
};
-
void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::vector<Def>& definitions, const std::vector<Resolver>& resolvers, const SugaredValuePtr& self) {
JIT_ASSERT(definitions.size() == resolvers.size());
auto resolver_it = resolvers.begin();
--- /dev/null
+#include <torch/csrc/jit/script/final_returns.h>
+#include <torch/csrc/jit/ir.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+struct ReturnInfo {
+ bool returns_; // true - all paths through stmts_ always return
+ // false - all paths through stmts_ do not return
+ List<Stmt> stmts_;
+};
+
+void checkNoReturn(const TreeRef& ref) {
+ if (ref->kind() == TK_RETURN)
+ throw ErrorReport(ref) << "return is not allowed from a loop.";
+ for(const TreeRef& child : ref->trees()) {
+ checkNoReturn(child);
+ }
+}
+
+// transform stmts so that its last action is to return or report that it
+// never returns.
+// return_none - if true, add an implicit `return None` to the end of the block
+// this handles the case where the return is implicit at the end of the function.
+ReturnInfo makeReturnsFinal(const SourceRange& range, at::ArrayRef<TreeRef> stmts, bool return_none);
+ReturnInfo makeReturnsFinal(const List<Stmt>& stmts, bool return_none) {
+ return makeReturnsFinal(stmts.range(), stmts.get()->trees(), return_none);
+}
+ReturnInfo makeReturnsFinal(const SourceRange& range, at::ArrayRef<TreeRef> stmts, bool return_none) {
+ std::vector<TreeRef> changed;
+ changed.reserve(stmts.size());
+ for(size_t i = 0; i < stmts.size(); ++i) {
+ const TreeRef& stmt = stmts[i];
+ switch(stmt->kind()) {
+ case TK_IF: {
+ auto if_stmt = If(stmt);
+ auto true_final = makeReturnsFinal(if_stmt.trueBranch(), false);
+ // (3) early return an if statement without an else block:
+ if (true_final.returns_ && if_stmt.falseBranch().size() == 0) {
+ auto rest_final = makeReturnsFinal(range, stmts.slice(i + 1), return_none);
+ if (!rest_final.returns_) {
+ throw ErrorReport(if_stmt)
+ << "This if statement performs an early return, but the block of code that follows it does not return."
+ << " Early returns are only allowed when the block following them also returns.";
+ }
+ changed.emplace_back(if_stmt.withNewBranches(true_final.stmts_, rest_final.stmts_));
+ return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
+ }
+
+ auto false_final = makeReturnsFinal(if_stmt.falseBranch(), false);
+ // (1) neither branch returns just keep processing the block
+ if (!true_final.returns_ && !false_final.returns_) {
+ changed.emplace_back(if_stmt);
+ break;
+ }
+ // (2) all branches return
+ if (true_final.returns_ && false_final.returns_) {
+ changed.emplace_back(if_stmt.withNewBranches(true_final.stmts_, false_final.stmts_));
+ return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
+ }
+ throw ErrorReport(if_stmt)
+ << "This if statement contains some paths that return and some paths that do not. "
+ << "If statements must either entirely return or never return.";
+ } break;
+ case TK_WHILE:
+ case TK_FOR:
+ changed.emplace_back(stmt);
+ checkNoReturn(stmt);
+ break;
+ case TK_RETURN:
+ changed.emplace_back(stmt);
+ // ignore the rest the the block, it is dead.
+ return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
+ default:
+ changed.emplace_back(stmt);
+ break;
+ }
+ }
+ if (return_none) {
+ // add an implicit return none node
+ changed.emplace_back(Return::create(range, Expr(Compound::create(TK_NONE, range, {}))));
+ }
+ // we reach the end of the block, no returns have happened
+ // unless we just inserted a return_none implicit return.
+ return {return_none, List<Stmt>::unsafeCreate(range, std::move(changed))};
+}
+
+List<Stmt> moveAllReturnsToEnd(const List<Stmt>& stmts) {
+ return makeReturnsFinal(stmts, true).stmts_;
+}
+
+} // namespace script
+} // namespace jit
+} // namespace torch