allow non-final returns (#15463)
authorZachary DeVito <zdevito@fb.com>
Fri, 21 Dec 2018 21:46:12 +0000 (13:46 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 21 Dec 2018 22:01:33 +0000 (14:01 -0800)
Summary:
This PR allows a subclass of programs that have return statements that are not final in the graph.

`final_returns.h` contains the a comment describing how this is accomplished.
To minimize complexity in `compiler.cpp`, this pass is done as an AST-to-AST rewrite before the compiler runs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15463

Differential Revision: D13538962

Pulled By: zdevito

fbshipit-source-id: 67105ca873351825b4a364092ab1873779f3e462

test/test_jit.py
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/final_returns.cpp [new file with mode: 0644]
torch/csrc/jit/script/final_returns.h [new file with mode: 0644]
torch/csrc/jit/script/tree.h
torch/csrc/jit/script/tree_views.h
torch/csrc/jit/script/type_parser.cpp
torch/csrc/jit/script/type_parser.h

index 7dab0b5..ac5d387 100644 (file)
@@ -6825,13 +6825,12 @@ a")
             MethodNoSelf()
 
     def test_return_stmt_not_at_end(self):
-        with self.assertRaisesRegex(RuntimeError, 'return statements can appear only at the end of the function body'):
-            @torch.jit.script
-            def return_stmt_wrong(x):
-                if bool(x > 3):
-                    return 3
-                else:
-                    return x
+        def return_stmt(x):
+            if bool(x > 3):
+                return x + 3
+            else:
+                return x
+        self.checkScript(return_stmt, (torch.rand(1),))
 
     def test_for_range_no_arg(self):
         with self.assertRaisesRegex(RuntimeError, r'range\(\) expects 1 argument but got 0'):
@@ -8880,6 +8879,84 @@ a")
         with self.capture_stdout() as captured:
             print(fn(x, scale, shift))
 
+    def test_non_final_return(self):
+
+        def simple(x):
+            if bool(x > 3):
+                return x + 1
+            else:
+                return x + 2
+            raise RuntimeError("nope")
+
+        def nest(x):
+            x = x + 1
+            if bool(x > 3):
+                if bool(x > 4):
+                    x += 1
+                return x + 1
+            else:
+                return x + 2
+
+        def early_ret(x):
+            x = x + 1
+            if bool(x > 3):
+                return x + 1
+            x = x + 1
+            return x + 2
+
+        def nest_early_ret(x):
+            x = x + 1
+            if bool(x > 3):
+                if bool(x > 4):
+                    return x + 2
+                return x + 1
+            x = x + 1
+            return x + 2
+
+        self.checkScript(simple, torch.rand(1))
+        self.checkScript(nest, torch.rand(1))
+        self.checkScript(early_ret, torch.rand(1))
+        self.checkScript(nest_early_ret, torch.rand(1))
+
+        with self.assertRaisesRegex(RuntimeError, "early"):
+            @torch.jit.script
+            def not_early_ret(x):
+                if bool(x > 3):
+                    if bool(x > 4):
+                        return 1
+                    print("foo")
+                else:
+                    print("5")
+                return 7
+
+        with self.assertRaisesRegex(RuntimeError, "some paths"):
+            @torch.jit.script
+            def not_total_ret(x):
+                if bool(x > 3):
+                    if bool(x > 4):
+                        return 1
+                    else:
+                        return 2
+                else:
+                    print("5")
+                return 7
+
+        with self.assertRaisesRegex(RuntimeError, "from a loop"):
+            @torch.jit.script
+            def nest_while_ret(x):
+                while bool(x > 4):
+                    if bool(x < 3):
+                        return 4
+                return 5
+
+        with self.assertRaisesRegex(RuntimeError, "from a loop"):
+            @torch.jit.script
+            def nest_for_ret(x):
+                for i in range(3):
+                    if bool(x < 3):
+                        return 4
+                return 5
+
 
 class MnistNet(nn.Module):
     def __init__(self):
index 3324221..9319565 100644 (file)
@@ -85,6 +85,7 @@ torch_sources_no_python_default = [
     "torch/csrc/jit/register_special_ops.cpp",
     "torch/csrc/jit/scope.cpp",
     "torch/csrc/jit/script/compiler.cpp",
+    "torch/csrc/jit/script/final_returns.cpp",
     "torch/csrc/jit/script/type_parser.cpp",
     "torch/csrc/jit/script/sugared_value.cpp",
     "torch/csrc/jit/script/schema_matching.cpp",
index 8d44c20..7de5815 100644 (file)
@@ -189,6 +189,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/scope.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/type_parser.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/sugared_value.cpp
index 2c3782e..a2064b2 100644 (file)
@@ -1,6 +1,8 @@
 #include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/script/schema_matching.h>
+#include <torch/csrc/jit/script/final_returns.h>
 #include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/script/type_parser.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/interpreter.h>
@@ -40,6 +42,8 @@ static Value* asSimple(const SugaredValuePtr& value) {
 static bool meaningfulName(const std::string& name) {
   if (name.size() == 0)
     return false;
+  if (name[0] == '$')
+    return false;
   if (name[0] != '_')
     return true;
   for (size_t i = 1; i < name.size(); ++i) {
@@ -361,13 +365,17 @@ inline bool isSupportedListElementType(const TypePtr& type) {
       type->isSubtypeOf(NumberType::get());
 }
 
-c10::optional<std::string> parseBaseTypeName(const Expr& expr);
-TypePtr parseTypeFromExpr(const Expr& expr);
-c10::optional<std::pair<TypePtr, int32_t>> handleBroadcastList(const Expr& expr);
+// Information for each def being emitted.
+// Defs can be nested to support closures so we need a stack of this information
+// Currently records information about the functions return type.
+struct DefContext {
+  TypePtr declared_return_type_; // nullptr if not annotated
+  TypePtr merged_return_type_; // nullptr if a Return has not been seen yet
+};
 
 struct to_ir {
   to_ir(
-      Def def,
+      const Def& def,
       Resolver resolver_,
       const SugaredValuePtr& self,
       Method& method) // method being constructed
@@ -376,7 +384,7 @@ struct to_ir {
       , resolver(std::move(resolver_))
       , environment_stack(nullptr) {
     JIT_ASSERT(resolver);
-    pushFrame(graph->block());
+    pushFrame(graph->block(), /*starts_def=*/true);
 
     // Type annotations exclude explicitly typing the "self" parameter, so in the
     // case that this is a method with self we expect one fewer parameter annotation
@@ -399,13 +407,20 @@ private:
   // Singly-linked list of environments. This top element contains a member
   // `next` that points to the most immediate enclosing scope's value.
   std::shared_ptr<Environment> environment_stack;
+  std::vector<DefContext> def_stack_;
 
-  void pushFrame(Block * b) {
+  void pushFrame(Block * b, bool starts_def=false) {
+    if (starts_def) {
+      def_stack_.emplace_back();
+    }
     environment_stack = std::make_shared<Environment>(method, resolver, b, environment_stack);
   }
-  std::shared_ptr<Environment> popFrame() {
+  std::shared_ptr<Environment> popFrame(bool ends_def=false) {
     auto old_frame = environment_stack;
     environment_stack = environment_stack->next;
+    if(ends_def) {
+      def_stack_.pop_back();
+    }
     return old_frame;
   }
 
@@ -417,20 +432,16 @@ private:
 
   FunctionSchema emitDef(const Def& def, const SugaredValuePtr& self, Block* block) {
     auto schema = extractSchemaFromDef(def, self);
+    if (schema.returns().size() == 1) {
+      def_stack_.back().declared_return_type_ = schema.returns().at(0).type();
+    }
     std::vector<Argument> arguments = emitFormalArguments(def, self, schema, block);
 
+
     // body
-    auto stmts = def.statements();
-    auto stmts_begin = stmts.begin();
-    auto stmts_end = stmts.end();
-    c10::optional<Return> return_stmt;
-    if (stmts_begin != stmts_end && (*std::prev(stmts_end)).kind() == TK_RETURN) {
-      --stmts_end;
-      return_stmt = Return(*stmts_end);
-    }
-    emitStatements(stmts_begin, stmts_end);
-    const SourceRange& range = return_stmt ? return_stmt->range() : def.range();
-    std::vector<Argument> returns = {emitReturn(range, return_stmt, schema, block)};
+    auto stmts_list = moveAllReturnsToEnd(def.statements());
+    emitStatements(stmts_list.begin(), stmts_list.end());
+    std::vector<Argument> returns = {emitOutput(def.range(), schema, block)};
     return {def.name().name(), std::move(arguments), std::move(returns)};
   }
 
@@ -493,7 +504,7 @@ private:
       c10::optional<int32_t> N;
 
       //BroadcastList list can only appear at the argument level
-      if (auto maybe_broad_list = handleBroadcastList(decl_arg.type())) {
+      if (auto maybe_broad_list = parseBroadcastList(decl_arg.type())) {
         type = maybe_broad_list->first;
         N = maybe_broad_list->second;
       } else {
@@ -523,7 +534,7 @@ private:
     if(!decl.return_type().present())
       return {};
 
-    if (handleBroadcastList(decl.return_type().get()))
+    if (parseBroadcastList(decl.return_type().get()))
       throw ErrorReport(decl.return_type().range()) << "Broadcastable lists cannot appear as a return type";
     auto parsed_type = parseTypeFromExpr(decl.return_type().get());
     return {Argument(
@@ -574,27 +585,15 @@ private:
     return arguments;
   }
 
-  Argument emitReturn(const SourceRange& range, c10::optional<Return> return_stmt, const FunctionSchema& schema, Block* block) {
-    JIT_ASSERT(schema.returns().size() <= 1);
+  Argument emitOutput(const SourceRange& range, const FunctionSchema& schema, Block* block) {
+    // rewrites ensure there is always a return statement in program
+    JIT_ASSERT(def_stack_.back().merged_return_type_);
     // outputs
-    Value* result = return_stmt ? emitExpr(return_stmt->expr())
-                                : graph->insertConstant(IValue(), range);
-    TypePtr result_type = schema.returns().size() > 0
-        ? schema.returns().at(0).type()
-        : result->type();
-
-    if (return_stmt) {
-      result = tryConvertToType(
-          range, *graph, result_type, result, /*allow_conversions=*/true);
-    }
-
-    if (!result->type()->isSubtypeOf(result_type)) {
-      throw ErrorReport(range) << "Return value was annotated as having type " << result_type->python_str()
-        << " but is actually of type " << result->type()->python_str();
-    }
+    Value* result = environment_stack->getVar("$return", range);
     block->registerOutput(result);
-    return Argument("", result_type);
+    return Argument("", def_stack_.back().merged_return_type_);
   }
+
   void emitStatements(const List<Stmt>& statements) {
     return emitStatements(statements.begin(), statements.end());
   }
@@ -639,9 +638,9 @@ private:
     Block* block = closure_node->addBlock();
     {
       WithInsertPoint guard(block);
-      pushFrame(block);
+      pushFrame(block, /*starts_def=*/true);
       emitDef(def, nullptr, block); //ignore schema return, we just wont use it for now since we never create a Method for the closure
-      popFrame();
+      popFrame(/*ends_def=*/true);
     }
     std::shared_ptr<Graph> subgraph;
     Value* context;
@@ -652,6 +651,42 @@ private:
     auto tup = graph->insertNode(graph->createTuple({closure_node->output(), context}))->output();
     environment_stack->setVar(def.name().range(), def.name().name(), tup);
   }
+
+  void emitReturn(const Return& stmt) {
+    Value* result = emitExpr(stmt.expr());
+    TypePtr result_type = def_stack_.back().declared_return_type_;
+    // result type is annotated, every return must convert to that type
+    if (result_type) {
+      // this guard skips implicit conversion from None -> Tensor for the return type.
+      // otherwise forgetting a return a function returning a tensor will cause a None to be
+      // converted to a tensor.
+      if (!(result_type->isSubtypeOf(DynamicType::get()) && result->type()->isSubtypeOf(NoneType::get()))) {
+        result = tryConvertToType(
+            stmt.range(), *graph, result_type, result, /*allow_conversions=*/true);
+      }
+
+      if (!result->type()->isSubtypeOf(result_type)) {
+        throw ErrorReport(stmt.range()) << "Return value was annotated as having type " << result_type->python_str()
+          << " but is actually of type " << result->type()->python_str();
+      }
+    } else {
+      result_type = def_stack_.back().merged_return_type_;
+      if (!result_type) {
+        result_type = result->type();
+      }
+      if(!unifyTypes(result_type, result->type())) {
+        throw ErrorReport(stmt.range())
+            << "Previous return statement returned a value of type "
+            << result_type->python_str()
+            << " but this return statement returns a value of type "
+            << result->type()->python_str();
+      }
+    }
+    JIT_ASSERT(result_type);
+    def_stack_.back().merged_return_type_ = result_type;
+    environment_stack->setVar(stmt.range(), "$return", result);
+  }
+
   void emitStatements(List<Stmt>::const_iterator begin, List<Stmt>::const_iterator end) {
     for (; begin != end; ++begin) {
       auto stmt = *begin;
@@ -688,10 +723,9 @@ private:
         case TK_ASSERT:
           emitAssert(Assert(stmt));
           break;
-        case TK_RETURN:
-          throw ErrorReport(stmt) << "return statements can appear only at the end "
-                                  << "of the function body";
-          break;
+        case TK_RETURN: {
+          emitReturn(Return(stmt));
+        } break;
         case TK_PASS:
           // Emit nothing for pass
           break;
@@ -773,16 +807,17 @@ private:
     emit_if_expr(true_block, std::move(true_expr));
     emit_if_expr(false_block, std::move(false_expr));
 
-    auto true_type = unshapedType(true_block->outputs().at(0)->type());
-    auto false_type = unshapedType(false_block->outputs().at(0)->type());
-    if (*true_type != *false_type) {
+    auto true_type = true_block->outputs().at(0)->type();
+    auto false_type = false_block->outputs().at(0)->type();
+    auto unified = unifyTypes(true_type, false_type);
+    if (!unified) {
       throw ErrorReport(range)
           << "if-expression's true branch has type " << true_type->str()
           << " but false branch has type " << false_type->str();
     }
 
     // Add op outputs
-    auto expr_value = n->addOutput()->setType(true_type); // Resulting value
+    auto expr_value = n->addOutput()->setType(*unified); // Resulting value
 
     return expr_value;
   }
@@ -2181,7 +2216,6 @@ private:
   }
 };
 
-
 void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::vector<Def>& definitions, const std::vector<Resolver>& resolvers, const SugaredValuePtr& self) {
   JIT_ASSERT(definitions.size() == resolvers.size());
   auto resolver_it = resolvers.begin();
diff --git a/torch/csrc/jit/script/final_returns.cpp b/torch/csrc/jit/script/final_returns.cpp
new file mode 100644 (file)
index 0000000..5541f27
--- /dev/null
@@ -0,0 +1,95 @@
+#include <torch/csrc/jit/script/final_returns.h>
+#include <torch/csrc/jit/ir.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+struct ReturnInfo {
+  bool returns_; // true - all paths through stmts_ always return
+                 // false - all paths through stmts_ do not return
+  List<Stmt> stmts_;
+};
+
+void checkNoReturn(const TreeRef& ref) {
+  if (ref->kind() == TK_RETURN)
+    throw ErrorReport(ref) << "return is not allowed from a loop.";
+  for(const TreeRef& child : ref->trees()) {
+    checkNoReturn(child);
+  }
+}
+
+// transform stmts so that its last action is to return or report that it
+// never returns.
+// return_none - if true, add an implicit `return None` to the end of the block
+//   this handles the case where the return is implicit at the end of the function.
+ReturnInfo makeReturnsFinal(const SourceRange& range, at::ArrayRef<TreeRef> stmts, bool return_none);
+ReturnInfo makeReturnsFinal(const List<Stmt>& stmts, bool return_none) {
+  return makeReturnsFinal(stmts.range(), stmts.get()->trees(), return_none);
+}
+ReturnInfo makeReturnsFinal(const SourceRange& range, at::ArrayRef<TreeRef> stmts, bool return_none) {
+  std::vector<TreeRef> changed;
+  changed.reserve(stmts.size());
+  for(size_t i = 0; i < stmts.size(); ++i) {
+    const TreeRef& stmt = stmts[i];
+    switch(stmt->kind()) {
+      case TK_IF: {
+        auto if_stmt = If(stmt);
+        auto true_final = makeReturnsFinal(if_stmt.trueBranch(), false);
+        // (3) early return an if statement without an else block:
+        if (true_final.returns_ && if_stmt.falseBranch().size() == 0) {
+          auto rest_final = makeReturnsFinal(range, stmts.slice(i + 1), return_none);
+          if (!rest_final.returns_) {
+            throw ErrorReport(if_stmt)
+                  << "This if statement performs an early return, but the block of code that follows it does not return."
+                  << " Early returns are only allowed when the block following them also returns.";
+          }
+          changed.emplace_back(if_stmt.withNewBranches(true_final.stmts_, rest_final.stmts_));
+          return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
+        }
+
+        auto false_final = makeReturnsFinal(if_stmt.falseBranch(), false);
+        // (1) neither branch returns just keep processing the block
+        if (!true_final.returns_ && !false_final.returns_) {
+          changed.emplace_back(if_stmt);
+          break;
+        }
+        // (2) all branches return
+        if (true_final.returns_ && false_final.returns_) {
+          changed.emplace_back(if_stmt.withNewBranches(true_final.stmts_, false_final.stmts_));
+          return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
+        }
+        throw ErrorReport(if_stmt)
+              << "This if statement contains some paths that return and some paths that do not. "
+              << "If statements must either entirely return or never return.";
+      } break;
+      case TK_WHILE:
+      case TK_FOR:
+        changed.emplace_back(stmt);
+        checkNoReturn(stmt);
+        break;
+      case TK_RETURN:
+        changed.emplace_back(stmt);
+        // ignore the rest the the block, it is dead.
+        return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
+      default:
+        changed.emplace_back(stmt);
+        break;
+    }
+  }
+  if (return_none) {
+    // add an implicit return none node
+    changed.emplace_back(Return::create(range, Expr(Compound::create(TK_NONE, range, {}))));
+  }
+  // we reach the end of the block, no returns have happened
+  // unless we just inserted a return_none implicit return.
+  return {return_none, List<Stmt>::unsafeCreate(range, std::move(changed))};
+}
+
+List<Stmt> moveAllReturnsToEnd(const List<Stmt>& stmts) {
+  return makeReturnsFinal(stmts, true).stmts_;
+}
+
+} // namespace script
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/final_returns.h b/torch/csrc/jit/script/final_returns.h
new file mode 100644 (file)
index 0000000..c2960a4
--- /dev/null
@@ -0,0 +1,60 @@
+#pragma once
+#include <functional>
+#include <memory>
+#include <string>
+
+#include <torch/csrc/jit/script/error_report.h>
+#include <torch/csrc/jit/script/tree_views.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+// This is an AST-to-AST transform that ensures that all return statements
+// are at the end of the natural control-flow of the program.
+//
+// Since the return is at the end of the function, it is equivalent
+// to simply assigning the returned value to to a special `$return` variable
+// that is universally set to be the output of the function.
+//
+// This transform is only intended to support a subset of control-flow
+// structures to make the transformation both easy to do _and_ easy to
+// explain to users of TorchScript. The second constraint is important: if
+// it is unclear what is allowed users will get the impression that the
+// subset is difficult to use.
+//
+//   if <cond>:
+//     <true>
+//   else:
+//     <false>
+//   <rest>
+//
+// In particular we allow:
+// 1. If statements where neither <true> nor <false> branch returns.
+// 2. If statements where both <true> and <false> always return.
+// 3. An 'early return' if statement where <true> always returns <false> is empty, and <rest>
+// always returns.
+//
+// We do not allow returns from loops in any case.
+//
+// This pass handles the following cases as follows:
+//
+// 1. Neither branch returns so we can just leave the branches as is
+// 2. Both branches return, so we recursively transform the program such that
+// <true> and <false>'s final action is to return. We then delete <rest>
+// because the code is dead. The remaining program preserves the inductive
+// property that its last action is to return since both branches end in a return.
+// 3. In this case we know that <true> and <rest> always returns, and <false> is empty.
+//    We transform the graph to:
+//    if <cond>:
+//       <true>
+//     else:
+//       <rest>
+//    Now it is another instance of case (2).
+
+TORCH_API List<Stmt> moveAllReturnsToEnd(const List<Stmt>& stmts);
+
+}
+} // namespace jit
+} // namespace torch
index 787c0df..3b6c5d1 100644 (file)
@@ -144,6 +144,7 @@ struct Compound : public Tree {
     }
     return Compound::create(kind(), range(), std::move(trees_));
   }
+
   const SourceRange& range() const override {
     return range_;
   }
index bb0f151..1640f48 100644 (file)
@@ -159,6 +159,9 @@ struct List : public TreeView {
     TreeList type_erased_sub {subtrees.begin(), subtrees.end()};
     return List(Compound::create(TK_LIST, range, std::move(type_erased_sub)));
   }
+  static List unsafeCreate(const SourceRange& range, TreeList&& subtrees) {
+    return List(Compound::create(TK_LIST, range, std::move(subtrees)));
+  }
   size_t size() const {
     return tree_->trees().size();
   }
@@ -380,6 +383,9 @@ struct If : public Stmt {
   List<Stmt> falseBranch() const {
     return List<Stmt>(subtree(2));
   }
+  If withNewBranches(const List<Stmt>& true_branch, const List<Stmt>& false_branch) const {
+    return create(range(), cond(), true_branch, false_branch);
+  }
   static If create(
       const SourceRange& range,
       const Expr& cond,
index 5b4ce12..e55c6de 100644 (file)
@@ -62,7 +62,7 @@ bool isTorch(const Expr& expr) {
 
 
 
-c10::optional<std::pair<TypePtr, int32_t>> handleBroadcastList(const Expr& expr) {
+c10::optional<std::pair<TypePtr, int32_t>> parseBroadcastList(const Expr& expr) {
   if (expr.kind() != TK_SUBSCRIPT)
     return c10::nullopt;
   auto subscript = Subscript(expr);
@@ -73,7 +73,7 @@ c10::optional<std::pair<TypePtr, int32_t>> handleBroadcastList(const Expr& expr)
 
   // handle the case where the BroadcastingList is wrapped in a Optional type
   if(var.name().name() == "Optional") {
-    auto broadcast_list = handleBroadcastList(subscript_exprs[0]);
+    auto broadcast_list = parseBroadcastList(subscript_exprs[0]);
     if (broadcast_list) {
       TypePtr opt_type = OptionalType::create(broadcast_list->first);
       return std::pair<TypePtr, int32_t>(opt_type, broadcast_list->second);
index b583405..50e2ea8 100644 (file)
@@ -7,6 +7,7 @@ namespace script {
 struct Expr;
 TORCH_API c10::optional<std::string> parseBaseTypeName(const Expr& expr);
 TORCH_API c10::TypePtr parseTypeFromExpr(const Expr& expr);
+TORCH_API c10::optional<std::pair<c10::TypePtr, int32_t>> parseBroadcastList(const Expr& expr);
 }
 } // namespace jit
 } // namespace torch