Support enough of closures to write autograd functions (#15411)
authorZachary DeVito <zdevito@fb.com>
Thu, 20 Dec 2018 22:26:06 +0000 (14:26 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 22:39:11 +0000 (14:39 -0800)
Summary:
This PR adds enough of the infra for supporting closures (inner script functions) in order to allow us to expression symbolic gradients using them. We do not actually ever run graphs that contain these closures. The symbolic_script infrastructure just extracts them out of the original forward graph and turns them into discrete forward/backward pairs. This cuts down on the type annotations necessary to write forward/backward pairs and aligns closely with the "differentiator" function approach to expression reverse-mode AD.

Example:

This code:
```
import torch

r = torch.jit.CompilationUnit(
'''
def mul_forward(self, other):
    def backward(grad_output):
        grad_self = (grad_output * other).sum_to_size(self.size())
        grad_other = (grad_output * self).sum_to_size(other.size())
        return grad_self, grad_other
    return self * other, backward
''')

print(r.module.code)
```

Will produce this graph (pretty printed for clarity):

```
def mul_forward(self,
    self: Tensor,
    other: Tensor) -> Tuple[Tensor, Tuple[None, Tuple[Tensor, Tensor]]]:
  backward = (self.__lambda, (other, self))
  return (torch.mul(self, other), backward)

def __lambda(self,
    context: Tuple[Tensor, Tensor],
    grad_output: Tensor) -> Tuple[Tensor, Tensor]:
  other, self, = context
  grad_self = torch.sum_to_size(torch.mul(grad_output, other), torch.size(self))
  grad_other = torch.sum_to_size(torch.mul(grad_output, self), torch.size(other))
  return (grad_self, grad_other)
```

symbolic_script will then do some modifications to remove the unsuppored prim::Function node, yielding:

```
def mul_forward(self,
    self: Tensor,
    other: Tensor) -> Tuple[Tensor, Tuple[None, Tuple[Tensor, Tensor]]]:
  return (torch.mul(self, other), (other, self))

def backward(self,
    context: Tuple[Tensor, Tensor],
    grad_output: Tensor) -> Tuple[Tensor, Tensor]:
  other, self, = context
  grad_self = torch.sum_to_size(torch.mul(grad_output, other), torch.size(self))
  grad_other = torch.sum_to_size(torch.mul(grad_output, self), torch.size(other))
  return (grad_self, grad_other)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15411

Differential Revision: D13523340

Pulled By: zdevito

fbshipit-source-id: 4d4a269460e595b16802c00ec55ae00e3e682d49

aten/src/ATen/core/interned_strings.h
test/expect/TestFuser.test_lstm_cuda-backward.expect
test/expect/TestFuser.test_milstm_cuda-backward.expect
test/expect/TestJit.test_cpp_cuda.expect
test/test_jit.py
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/parser.cpp
torch/csrc/jit/script/tree_views.h
torch/csrc/jit/symbolic_script.cpp

index e94f37d..29d0062 100644 (file)
@@ -75,6 +75,7 @@ namespace c10 {
   _(aten, __round_to_zero_floordiv)\
   _(prim, fork)                    \
   _(prim, RaiseException)          \
+  _(prim, Function)                \
   _(aten, append)                  \
   _(aten, format)                  \
   _(aten, __not__)                 \
index a9982b7..de28768 100644 (file)
@@ -27,26 +27,26 @@ graph(%0 : Float(*, *)
       %26 : Float(*, *)) {
   %27 : Float(*, *) = aten::mul(%0, %26)
   %28 : int[] = aten::size(%outgate)
-  %29 : Tensor = aten::sum_to_size(%27, %28)
+  %grad_self.1 : Tensor = aten::sum_to_size(%27, %28)
   %30 : Float(*, *) = aten::mul(%0, %outgate)
   %31 : int[] = aten::size(%26)
-  %32 : Tensor = aten::sum_to_size(%30, %31)
-  %33 : Tensor = prim::FusionGroup_0(%1, %32, %26)
+  %grad_other.1 : Tensor = aten::sum_to_size(%30, %31)
+  %33 : Tensor = prim::FusionGroup_0(%1, %grad_other.1, %26)
   %34 : Tensor = prim::SumToSize(%33, %24)
   %35 : Tensor = prim::SumToSize(%33, %25)
   %36 : Tensor = aten::mul(%35, %cellgate)
   %37 : int[] = aten::size(%ingate)
-  %38 : Tensor = aten::sum_to_size(%36, %37)
+  %grad_self.3 : Tensor = aten::sum_to_size(%36, %37)
   %39 : Tensor = aten::mul(%35, %ingate)
   %40 : int[] = aten::size(%cellgate)
-  %41 : Tensor = aten::sum_to_size(%39, %40)
+  %grad_other.3 : Tensor = aten::sum_to_size(%39, %40)
   %42 : Tensor = aten::mul(%34, %9)
   %43 : int[] = aten::size(%forgetgate)
-  %44 : Tensor = aten::sum_to_size(%42, %43)
+  %grad_self.5 : Tensor = aten::sum_to_size(%42, %43)
   %45 : Tensor = aten::mul(%34, %forgetgate)
   %46 : int[] = aten::size(%9)
-  %47 : Tensor = aten::sum_to_size(%45, %46)
-  %48 : Tensor = prim::FusionGroup_1(%38, %ingate, %44, %forgetgate, %41, %cellgate, %29, %outgate)
+  %grad_other.5 : Tensor = aten::sum_to_size(%45, %46)
+  %48 : Tensor = prim::FusionGroup_1(%grad_self.3, %ingate, %grad_self.5, %forgetgate, %grad_other.3, %cellgate, %grad_self.1, %outgate)
   %49 : Tensor = prim::SumToSize(%48, %19)
   %50 : Tensor = prim::SumToSize(%48, %17)
   %51 : Tensor = prim::SumToSize(%48, %14)
@@ -61,7 +61,7 @@ graph(%0 : Float(*, *)
   %60 : Float(*, *) = aten::t(%11)
   %61 : Float(*, *) = aten::mm(%60, %51)
   %62 : Float(*, *) = aten::t(%61)
-  return (%47, %49, %50, %54, %57, %59, %62);
+  return (%grad_other.5, %49, %50, %54, %57, %59, %62);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
index 56ddf8f..7f45d2c 100644 (file)
@@ -33,58 +33,58 @@ graph(%0 : Float(*, *)
   %32 : int = prim::Constant[value=1]()
   %33 : Float(*, *) = aten::mul(%0, %31)
   %34 : int[] = aten::size(%outgate)
-  %35 : Tensor = aten::sum_to_size(%33, %34)
+  %grad_self.1 : Tensor = aten::sum_to_size(%33, %34)
   %36 : Float(*, *) = aten::mul(%0, %outgate)
   %37 : int[] = aten::size(%31)
-  %38 : Tensor = aten::sum_to_size(%36, %37)
-  %39 : Tensor = prim::FusionGroup_0(%1, %38, %31)
+  %grad_other.1 : Tensor = aten::sum_to_size(%36, %37)
+  %39 : Tensor = prim::FusionGroup_0(%1, %grad_other.1, %31)
   %40 : Tensor = prim::SumToSize(%39, %29)
   %41 : Tensor = prim::SumToSize(%39, %30)
   %42 : Tensor = aten::mul(%41, %cellgate)
   %43 : int[] = aten::size(%ingate)
-  %44 : Tensor = aten::sum_to_size(%42, %43)
+  %grad_self.3 : Tensor = aten::sum_to_size(%42, %43)
   %45 : Tensor = aten::mul(%41, %ingate)
   %46 : int[] = aten::size(%cellgate)
-  %47 : Tensor = aten::sum_to_size(%45, %46)
+  %grad_other.3 : Tensor = aten::sum_to_size(%45, %46)
   %48 : Tensor = aten::mul(%40, %10)
   %49 : int[] = aten::size(%forgetgate)
-  %50 : Tensor = aten::sum_to_size(%48, %49)
-  %51 : Tensor = prim::FusionGroup_1(%44, %ingate, %50, %forgetgate, %47, %cellgate, %35, %outgate)
+  %grad_self.5 : Tensor = aten::sum_to_size(%48, %49)
+  %51 : Tensor = prim::FusionGroup_1(%grad_self.3, %ingate, %grad_self.5, %forgetgate, %grad_other.3, %cellgate, %grad_self.1, %outgate)
   %52 : Tensor = prim::SumToSize(%51, %24)
   %53 : Tensor = prim::SumToSize(%51, %22)
   %54 : Tensor = aten::mul(%53, %Uz)
   %55 : int[] = aten::size(%11)
-  %56 : Tensor = aten::sum_to_size(%54, %55)
+  %grad_self.7 : Tensor = aten::sum_to_size(%54, %55)
   %57 : Tensor = aten::mul(%53, %11)
   %58 : int[] = aten::size(%Uz)
-  %59 : Tensor = aten::sum_to_size(%57, %58)
+  %grad_other.7 : Tensor = aten::sum_to_size(%57, %58)
   %60 : Tensor = prim::SumToSize(%51, %19)
   %61 : Tensor = prim::SumToSize(%51, %20)
   %62 : Tensor = aten::mul(%61, %Wx)
   %63 : int[] = aten::size(%12)
-  %64 : Tensor = aten::sum_to_size(%62, %63)
+  %grad_self.9 : Tensor = aten::sum_to_size(%62, %63)
   %65 : Tensor = aten::mul(%61, %12)
   %66 : int[] = aten::size(%Wx)
-  %67 : Tensor = aten::sum_to_size(%65, %66)
+  %grad_other.9 : Tensor = aten::sum_to_size(%65, %66)
   %68 : Tensor = aten::mul(%60, %Uz)
   %69 : int[] = aten::size(%18)
-  %70 : Tensor = aten::sum_to_size(%68, %69)
+  %grad_self.11 : Tensor = aten::sum_to_size(%68, %69)
   %71 : Tensor = aten::mul(%60, %18)
-  %72 : Tensor = aten::sum_to_size(%71, %58)
-  %73 : Tensor = aten::add(%59, %72, %32)
-  %74 : Tensor = aten::mul(%70, %Wx)
+  %grad_other.11 : Tensor = aten::sum_to_size(%71, %58)
+  %73 : Tensor = aten::add(%grad_other.7, %grad_other.11, %32)
+  %74 : Tensor = aten::mul(%grad_self.11, %Wx)
   %75 : int[] = aten::size(%13)
-  %76 : Tensor = aten::sum_to_size(%74, %75)
-  %77 : Tensor = aten::mul(%70, %13)
-  %78 : Tensor = aten::sum_to_size(%77, %66)
-  %79 : Tensor = aten::add(%67, %78, %32)
+  %grad_self.13 : Tensor = aten::sum_to_size(%74, %75)
+  %77 : Tensor = aten::mul(%grad_self.11, %13)
+  %grad_other.13 : Tensor = aten::sum_to_size(%77, %66)
+  %79 : Tensor = aten::add(%grad_other.9, %grad_other.13, %32)
   %80 : Float(*, *) = aten::t(%14)
   %81 : Float(*, *) = aten::mm(%80, %73)
   %82 : Float(*, *) = aten::t(%81)
   %83 : Float(*, *) = aten::t(%15)
   %84 : Float(*, *) = aten::mm(%83, %79)
   %85 : Float(*, *) = aten::t(%84)
-  return (%52, %56, %64, %76, %82, %85);
+  return (%52, %grad_self.7, %grad_self.9, %grad_self.13, %82, %85);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
index 1872b15..5b3b092 100644 (file)
@@ -115,29 +115,29 @@ graph(%0 : Float(2, 3, 4)
       %12 : Tensor = prim::SumToSize(%11, %6)
       -> (%10, %12)
     }
-  %13 : Tensor, %14 : Tensor = prim::GradOf[name="aten::mul"](%8)
+  %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%8)
     block0() {
       %15 : Tensor = aten::mul(%8, %2)
       %16 : int[] = aten::size(%4)
-      %17 : Tensor = aten::sum_to_size(%15, %16)
+      %grad_self.1 : Tensor = aten::sum_to_size(%15, %16)
       %18 : Tensor = aten::mul(%8, %4)
       %19 : int[] = aten::size(%2)
-      %20 : Tensor = aten::sum_to_size(%18, %19)
-      -> (%17, %20)
+      %grad_other.1 : Tensor = aten::sum_to_size(%18, %19)
+      -> (%grad_self.1, %grad_other.1)
     }
-  %21 : Tensor = prim::AutogradAdd(%1, %13)
-  %22 : Tensor, %23 : Tensor = prim::GradOf[name="aten::mul"](%21)
+  %21 : Tensor = prim::AutogradAdd(%1, %grad_self.2)
+  %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%21)
     block0() {
       %24 : Tensor = aten::mul(%21, %3)
       %25 : int[] = aten::size(%2)
-      %26 : Tensor = aten::sum_to_size(%24, %25)
+      %grad_self.3 : Tensor = aten::sum_to_size(%24, %25)
       %27 : Tensor = aten::mul(%21, %2)
       %28 : int[] = aten::size(%3)
-      %29 : Tensor = aten::sum_to_size(%27, %28)
-      -> (%26, %29)
+      %grad_other.3 : Tensor = aten::sum_to_size(%27, %28)
+      -> (%grad_self.3, %grad_other.3)
     }
-  %30 : Tensor = prim::AutogradAdd(%14, %22)
-  %31 : Tensor = prim::AutogradAdd(%9, %23)
+  %30 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self)
+  %31 : Tensor = prim::AutogradAdd(%9, %grad_other)
   return (%30, %31);
 }
 
@@ -165,24 +165,24 @@ graph(%0 : Float(*)
       %8 : Tensor = prim::SumToSize(%0, %4)
       -> (%8)
     }
-  %9 : Tensor, %10 : Tensor = prim::GradOf[name="aten::mul"](%7)
+  %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%7)
     block0() {
       %11 : Tensor = aten::mul(%7, %2)
       %12 : int[] = aten::size(%3)
-      %13 : Tensor = aten::sum_to_size(%11, %12)
+      %grad_self.1 : Tensor = aten::sum_to_size(%11, %12)
       %14 : Tensor = aten::mul(%7, %3)
       %15 : int[] = aten::size(%2)
-      %16 : Tensor = aten::sum_to_size(%14, %15)
-      -> (%13, %16)
+      %grad_other.1 : Tensor = aten::sum_to_size(%14, %15)
+      -> (%grad_self.1, %grad_other.1)
     }
-  %17 : Tensor = prim::AutogradAdd(%1, %9)
+  %17 : Tensor = prim::AutogradAdd(%1, %grad_self)
   %18 : Tensor = prim::GradOf[name="aten::add"](%17)
     block0() {
       %19 : Tensor = aten::mul(%17, %6)
       %20 : Tensor = prim::SumToSize(%19, %5)
       -> (%20)
     }
-  %21 : Tensor = prim::AutogradAdd(%10, %18)
+  %21 : Tensor = prim::AutogradAdd(%grad_other, %18)
   return (%21);
 }
 
index 07690f3..07f3b8f 100644 (file)
@@ -259,7 +259,9 @@ class JitTestCase(TestCase):
             try:
                 pp, constant_table = module._python_print()
             except RuntimeError as e:
-                if "could not export python function" not in str(e):
+                se = str(e)
+                if "could not export python function" not in se and \
+                   "closures are not exportable" not in se:
                     raise
                 else:
                     return
index 24b8ed4..895e68a 100644 (file)
@@ -312,6 +312,7 @@ void AliasDb::analyze(Node* node) {
     case prim::None:
     case prim::BroadcastSizes:
     case prim::ChunkSizes:
+    case prim::Function:
       return analyzeCreator(node);
     case prim::TupleUnpack:
     case prim::TupleIndex:
@@ -591,6 +592,5 @@ void AliasDb::giveFreshAlias(const Value* value) {
   }
   addAlias(value, getFreshAlias());
 }
-
 } // namespace jit
 } // namespace torch
index 27e7594..a478bd8 100644 (file)
@@ -768,6 +768,17 @@ struct PythonPrintPass {
         }
         stmt << ")";
       } break;
+      case prim::Function: {
+        if (enforce_importable_) {
+          throw script::ErrorReport(node->getSourceLocation()) << "closures are not exportable";
+        }
+        auto name = genMethodName("__lambda");
+        std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
+        worklist.emplace_back([graph, name, this] {
+          printFunctionDefinition(*graph, name);
+        });
+        stmt << "self." << name;
+      } break;
       default: {
         Symbol kind = node->kind();
         if (kind.is_aten()) {
index 58817ba..2c3782e 100644 (file)
@@ -367,15 +367,13 @@ c10::optional<std::pair<TypePtr, int32_t>> handleBroadcastList(const Expr& expr)
 
 struct to_ir {
   to_ir(
-      Def def_,
+      Def def,
       Resolver resolver_,
-      SugaredValuePtr self_,
+      const SugaredValuePtr& self,
       Method& method) // method being constructed
       : method(method)
       , graph(method.graph())
-      , def(std::move(def_))
       , resolver(std::move(resolver_))
-      , self(std::move(self_))
       , environment_stack(nullptr) {
     JIT_ASSERT(resolver);
     pushFrame(graph->block());
@@ -386,34 +384,15 @@ struct to_ir {
     if (self && def.decl().params().size() == 0) {
       throw ErrorReport(def.decl().params().range()) << "methods must have a self argument";
     }
-    auto schema = extractSchemaFromDef(def);
-    std::vector<Argument> arguments = emitFormalArguments(self, schema);
 
-    // body
-    auto stmts = def.statements();
-    auto stmts_begin = stmts.begin();
-    auto stmts_end = stmts.end();
-    c10::optional<Return> return_stmt;
-    if (stmts_begin != stmts_end && (*std::prev(stmts_end)).kind() == TK_RETURN) {
-      --stmts_end;
-      return_stmt = Return(*stmts_end);
-    }
-    emitStatements(stmts_begin, stmts_end);
-    std::vector<Argument> returns = {emitReturn(
-        return_stmt ? return_stmt->range() : def.range(), return_stmt, schema)};
-
-    method.setSchema({def.name().name(), std::move(arguments), std::move(returns)});
-    // remove any uses of tuples that we inserted that are not needed
-    LowerSimpleTuples(graph);
-    ConstantPooling(graph);
+    method.setSchema(emitDef(def, self, graph->block()));
+    runCleanupPasses(graph);
   }
 
 private:
   Method& method;
   std::shared_ptr<Graph> graph;
-  Def def;
   Resolver resolver;
-  SugaredValuePtr self;
   std::unordered_map<int64_t, Value*> integral_constants;
   std::unordered_map<double, Value*> fp_constants;
 
@@ -430,6 +409,31 @@ private:
     return old_frame;
   }
 
+  void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
+    // remove any uses of tuples that we inserted that are not needed
+    LowerSimpleTuples(to_clean);
+    ConstantPooling(to_clean);
+  }
+
+  FunctionSchema emitDef(const Def& def, const SugaredValuePtr& self, Block* block) {
+    auto schema = extractSchemaFromDef(def, self);
+    std::vector<Argument> arguments = emitFormalArguments(def, self, schema, block);
+
+    // body
+    auto stmts = def.statements();
+    auto stmts_begin = stmts.begin();
+    auto stmts_end = stmts.end();
+    c10::optional<Return> return_stmt;
+    if (stmts_begin != stmts_end && (*std::prev(stmts_end)).kind() == TK_RETURN) {
+      --stmts_end;
+      return_stmt = Return(*stmts_end);
+    }
+    emitStatements(stmts_begin, stmts_end);
+    const SourceRange& range = return_stmt ? return_stmt->range() : def.range();
+    std::vector<Argument> returns = {emitReturn(range, return_stmt, schema, block)};
+    return {def.name().name(), std::move(arguments), std::move(returns)};
+  }
+
   std::vector<IValue> evaluateDefaults(const SourceRange& r, const std::vector<Expr>& default_types, const std::vector<Expr>& default_exprs) {
     std::vector<IValue> default_values;
     if (default_exprs.empty())
@@ -461,7 +465,7 @@ private:
     return stack.at(0).toTuple()->elements();
   }
 
-  std::vector<Argument> parseArgsFromDecl(const Decl& decl) {
+  std::vector<Argument> parseArgsFromDecl(const Decl& decl, const SugaredValuePtr& self) {
     auto params_begin = decl.params().begin();
     auto params_end = decl.params().end();
     if (self)
@@ -529,14 +533,14 @@ private:
         /*default_value =*/c10::nullopt,
         /*kwarg_only =*/false)};
   }
-  FunctionSchema extractSchemaFromDef(const Def &def) {
+  FunctionSchema extractSchemaFromDef(const Def &def, const SugaredValuePtr& self) {
       auto name = def.name().name();
-      std::vector<Argument> args = parseArgsFromDecl(def.decl());
+      std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
       std::vector<Argument> returns = parseReturnFromDecl(def.decl());
       return FunctionSchema(name, std::move(args), std::move(returns), false, false);
   }
 
-  std::vector<Argument> emitFormalArguments(const SugaredValuePtr& self, const FunctionSchema& schema) {
+  std::vector<Argument> emitFormalArguments(const Def& def, const SugaredValuePtr& self, const FunctionSchema& schema, Block* block) {
     std::vector<Argument> arguments; // for schema
     // inputs
     auto it = def.decl().params().begin();
@@ -557,7 +561,7 @@ private:
     for(;it != end; ++it) {
       auto& name = (*it).ident().name();
       // Add the input to the graph
-      Value *new_input = graph->addInput();
+      Value *new_input = block->addInput();
       if (meaningfulName(name)) {
         new_input->setUniqueName(name);
       }
@@ -570,7 +574,7 @@ private:
     return arguments;
   }
 
-  Argument emitReturn(const SourceRange& range, c10::optional<Return> return_stmt, const FunctionSchema& schema) {
+  Argument emitReturn(const SourceRange& range, c10::optional<Return> return_stmt, const FunctionSchema& schema, Block* block) {
     JIT_ASSERT(schema.returns().size() <= 1);
     // outputs
     Value* result = return_stmt ? emitExpr(return_stmt->expr())
@@ -588,12 +592,66 @@ private:
       throw ErrorReport(range) << "Return value was annotated as having type " << result_type->python_str()
         << " but is actually of type " << result->type()->python_str();
     }
-    graph->registerOutput(result);
+    block->registerOutput(result);
     return Argument("", result_type);
   }
   void emitStatements(const List<Stmt>& statements) {
     return emitStatements(statements.begin(), statements.end());
   }
+  std::pair<std::shared_ptr<Graph>, Value*> lambdaLift(Block* block) {
+      auto subgraph = std::make_shared<Graph>();
+      // note: type is set later on pack_context and context when we know it
+      Node* pack_context = graph->insertNode(graph->create(prim::TupleConstruct, {}, 1));
+      Value* context = subgraph->addInput("context");
+      // cannot use createTupleUnpack because the type is not known yet
+      Node* unpack_context = subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));
+
+      std::unordered_map<Value*, Value*> captures;
+      auto env = [&](Value* v) -> Value* {
+        auto it = captures.find(v);
+        if (it != captures.end()) {
+            return it->second;
+        }
+        pack_context->addInput(v);
+        Value* r = unpack_context->addOutput()->copyMetadata(v);
+        captures[v] = r;
+        return r;
+      };
+      subgraph->block()->cloneFrom(block, env);
+      auto context_type = TupleType::create(
+          fmap(pack_context->inputs(), [](Value* v) { return v->type(); }));
+      pack_context->output()->setType(context_type);
+      context->setType(context_type);
+      return std::make_pair(std::move(subgraph), pack_context->output());
+  }
+  // XXX - right now closures are used _only_ for defining gradients internally
+  // There are several unfinished aspects that make them unusable generally
+  // 1. We do not have a type, ivalue, operator to represent prim::Function, so closure_node has type None
+  //    and any graphs that contain it cannot be run
+  // 2. There is no export logic for it yet, so it cannot be exported/python_printed
+  // 3. There is nothing preventing the assignment of already existing variables inside the closures
+  //    the changes to those variables will just get forgotten.
+  // 4. There is no parsing support in frontend.py, this is intentional since it
+  //    prevents people from accidentally using this feature.
+  void emitClosure(const Def& def) {
+    Node* closure_node = graph->insertNode(graph->create(prim::Function, 1));
+    closure_node->output()->setType(NoneType::get()); //it is not a real thing yet, so just say the type is none.
+    Block* block = closure_node->addBlock();
+    {
+      WithInsertPoint guard(block);
+      pushFrame(block);
+      emitDef(def, nullptr, block); //ignore schema return, we just wont use it for now since we never create a Method for the closure
+      popFrame();
+    }
+    std::shared_ptr<Graph> subgraph;
+    Value* context;
+    std::tie(subgraph, context) = lambdaLift(block);
+    runCleanupPasses(subgraph);
+    closure_node->eraseBlock(0);
+    closure_node->g_(attr::Subgraph, std::move(subgraph));
+    auto tup = graph->insertNode(graph->createTuple({closure_node->output(), context}))->output();
+    environment_stack->setVar(def.name().range(), def.name().name(), tup);
+  }
   void emitStatements(List<Stmt>::const_iterator begin, List<Stmt>::const_iterator end) {
     for (; begin != end; ++begin) {
       auto stmt = *begin;
@@ -637,6 +695,9 @@ private:
         case TK_PASS:
           // Emit nothing for pass
           break;
+        case TK_DEF:
+          emitClosure(Def(stmt));
+          break;
         default:
           throw ErrorReport(stmt)
               << "Unrecognized statement kind " << kindToString(stmt.kind());
index 65201af..95ae7e2 100644 (file)
@@ -392,6 +392,9 @@ struct ParserImpl {
         L.expect(TK_NEWLINE);
         return Pass::create(range);
       }
+      case TK_DEF: {
+        return parseFunction(/*is_method=*/false);
+      }
       default: {
         auto lhs = parseExpOrExpTuple();
         if (L.cur().kind != TK_NEWLINE) {
index 6a8f7b0..bb0f151 100644 (file)
@@ -33,6 +33,7 @@ namespace script {
 //       | Return(List<Expr> values)                                    TK_RETURN
 //       | ExprStmt(List<Expr> expr)                                    TK_EXPR_STMT
 //       | Raise(Expr expr)                                             TK_RAISE
+//       | Def                                                          TK_DEF
 //
 // Expr  = TernaryIf(Expr cond, Expr true_expr, Expr false_expr)        TK_IF_EXPR
 //       | BinOp(Expr lhs, Expr rhs)
@@ -218,6 +219,7 @@ struct Stmt : public TreeView {
       case TK_RAISE:
       case TK_ASSERT:
       case TK_PASS:
+      case TK_DEF:
         return;
       default:
         throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid Stmt";
index 9ae53d4..dd670ed 100644 (file)
@@ -1,19 +1,21 @@
 #include <torch/csrc/jit/symbolic_script.h>
 
+
+
 namespace torch { namespace jit {
   namespace {
     std::mutex lock;
-    const std::unordered_map<std::string, std::string> symbolic_scripts({
-      {"aten::mul(Tensor self, Tensor other) -> Tensor",
-R"(
-def forward(self, other):
-    return self * other, (self, other)
-def backward(ctx, grad_output):
-    # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor]
-    self, other = ctx
-    return (grad_output * other).sum_to_size(self.size()), (grad_output * self).sum_to_size(other.size())
-)"},
-      });
+    const std::vector<std::string> functions = {
+      R"(
+        def mul(self, other):
+            def backward(grad_output):
+                grad_self = (grad_output * other).sum_to_size(self.size())
+                grad_other = (grad_output * self).sum_to_size(other.size())
+                return grad_self, grad_other
+            return self * other, backward
+      )"
+    };
+    std::unordered_map<std::string, GradientPair> schema_to_graphs;
 
     // This map is a workaround to cache compiled gradient_pairs. Ideally this graph
     // should be compiled only once and saved in Operator structure.
@@ -21,39 +23,93 @@ def backward(ctx, grad_output):
     std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
   } // anonymous namespace
 
+  std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
+    AT_CHECK(closure->node()->kind() == prim::TupleConstruct, "closure must be a literal tuple construct");
+    Value* fn = closure->node()->inputs().at(0);
+    Value* context = closure->node()->inputs().at(1);
+
+    AT_CHECK(fn->node()->kind() == prim::Function, "closure tuple must contain a prim::Function");
+    return std::make_pair(fn->node()->g(attr::Subgraph), context);
+  }
+
+  Argument originalReturnType(const TupleTypePtr& tup) {
+    AT_CHECK(tup->elements().size() > 1);
+    if(tup->elements().size() == 2)
+      return Argument("", tup->elements().at(0));
+    std::vector<TypePtr> types = tup->elements().vec();
+    types.pop_back();
+    return Argument("", TupleType::create(std::move(types)));
+  }
+
+  void loadModule(const std::shared_ptr<script::Module>& module) {
+    for(const auto& method_ : module->get_methods()) {
+      const auto& method = method_.value();
+      GradientPair pair;
+      pair.forward = method->graph();
+
+      // lookup the backward function
+      Node* forward_tuple = pair.forward->outputs().at(0)->node();
+
+      if (forward_tuple->kind() != prim::TupleConstruct) {
+        throw script::ErrorReport(forward_tuple->getSourceLocation()) << "gradient must return literal a tuple";
+      }
+
+      Value* context;
+      std::tie(pair.backward, context) = extractClosure(forward_tuple->inputs().back());
+
+      // do surgery on the forward function to remove the closure tuple and replace it with the
+      // context variable:
+      //  backward = (<lambda>, context_tuple)
+      //  return original, backward
+      //  -----
+      //  return original, context_tuple
+      std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
+      new_inputs.back() = context;
+      Value* new_tuple = pair.forward->appendNode(pair.forward->createTuple(new_inputs))->output();
+      pair.forward->eraseOutput(0);
+      pair.forward->registerOutput(new_tuple);
+      forward_tuple->destroy();
+
+      // derive schema from original function's schema:
+      const FunctionSchema& loaded_schema = method->getSchema();
+      FunctionSchema actual_schema(Symbol::aten(loaded_schema.name()),
+        loaded_schema.arguments(),
+        {originalReturnType(new_tuple->type()->expect<TupleType>())}
+      );
+      std::string key = canonicalSchemaString(actual_schema);
+      schema_to_graphs[key] = std::move(pair);
+    }
+  }
+
+  void loadFunctions() {
+    for(const std::string& str : functions) {
+      auto cu = std::make_shared<script::Module>();
+      script::defineMethodsInModule(cu, str, script::nativeResolver, nullptr);
+      loadModule(cu);
+    }
+  }
+
   c10::optional<GradientPair> gradientInfoForSchema(const FunctionSchema& schema) {
     std::lock_guard<std::mutex> guard(lock);
+    if (schema_to_graphs.size() == 0) {
+      loadFunctions();
+    }
     auto cache_it = cached_gradient_pairs.find(&schema);
     if (cache_it != cached_gradient_pairs.end()) {
       return cache_it->second;
     } else {
       auto schema_str = canonicalSchemaString(schema);
-      auto sym_script_it = symbolic_scripts.find(schema_str);
-
-      if (sym_script_it != symbolic_scripts.end()) {
-        // Compile the python code to a script module
-        auto cu = std::make_shared<script::Module>();
-        script::defineMethodsInModule(cu, symbolic_scripts.at(schema_str), script::nativeResolver, nullptr);
-        auto fw_graph = cu->find_method("forward")->graph();
-        auto bw_graph = cu->find_method("backward")->graph();
-
-        GradientPair compiled_graphs{fw_graph, bw_graph};
-        cached_gradient_pairs.emplace_hint(cache_it, &schema, compiled_graphs);
-        return compiled_graphs;
+      auto sym_script_it = schema_to_graphs.find(schema_str);
+      if (sym_script_it != schema_to_graphs.end()) {
+        cached_gradient_pairs.emplace_hint(cache_it, &schema, sym_script_it->second);
+        return sym_script_it->second;
       }
     }
     return c10::nullopt;
   }
 
   bool hasGradientInfoForSchema(const FunctionSchema& schema) {
-    std::lock_guard<std::mutex> guard(lock);
-    auto cache_it = cached_gradient_pairs.find(&schema);
-    if (cache_it == cached_gradient_pairs.end()) {
-      auto schema_str = canonicalSchemaString(schema);
-      auto sym_script_it = symbolic_scripts.find(schema_str);
-      return !(sym_script_it == symbolic_scripts.end());
-    }
-    return true;
+    return gradientInfoForSchema(schema).has_value();
   }
-}}
 
+}}