Autograd using torchscript (#14604)
authorAiling Zhang <ailzhang@fb.com>
Wed, 19 Dec 2018 02:56:06 +0000 (18:56 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 03:10:57 +0000 (19:10 -0800)
Summary:
This PR enables autodiff to use the forward/backward graph compiled from python code, instead of using symbolic gradients(modifying the original graph directly).

We put the map in a separate .h file for now to wait for the native_functions.yaml and derivatives.yaml merge. This should ideally go into native_functions.yaml eventually.

This PR should be enough to unblock us for now, we can start writing gradients for aten functions in python.

Differential Revision: D13494635

Pulled By: ailzhang

fbshipit-source-id: f8d51a15243ac46afd09d930c573ccdfcd9fdaaf

21 files changed:
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/TensorShape.cpp
aten/src/ATen/native/native_functions.yaml
test/cpp/jit/tests.h
test/expect/TestFuser.test_lstm_cuda-backward.expect
test/expect/TestFuser.test_milstm_cuda-backward.expect
test/expect/TestJit.test_cpp_cuda.expect
tools/build_variables.py
torch/CMakeLists.txt
torch/_tensor_docs.py
torch/csrc/jit/autodiff.cpp
torch/csrc/jit/operator.cpp
torch/csrc/jit/operator.h
torch/csrc/jit/passes/dead_code_elimination.cpp
torch/csrc/jit/passes/lower_tuples.cpp
torch/csrc/jit/passes/lower_tuples.h
torch/csrc/jit/symbolic_script.cpp [new file with mode: 0644]
torch/csrc/jit/symbolic_script.h [new file with mode: 0644]

index b455ee3..2a2422c 100644 (file)
@@ -456,6 +456,7 @@ public:
   Tensor sum(IntList dim, bool keepdim, ScalarType dtype) const;
   Tensor sum(IntList dim, bool keepdim=false) const;
   Tensor sum(IntList dim, ScalarType dtype) const;
+  Tensor sum_to_size(IntList size) const;
   Tensor sqrt() const;
   Tensor & sqrt_();
   Tensor std(bool unbiased=true) const;
index 6bc6f1f..a107111 100644 (file)
@@ -589,6 +589,9 @@ inline Tensor Tensor::sum(IntList dim, bool keepdim) const {
 inline Tensor Tensor::sum(IntList dim, ScalarType dtype) const {
     return type().sum(*this, dim, dtype);
 }
+inline Tensor Tensor::sum_to_size(IntList size) const {
+    return type().sum_to_size(*this, size);
+}
 inline Tensor Tensor::sqrt() const {
     return type().sqrt(*this);
 }
index 56885dd..ea57210 100644 (file)
@@ -363,6 +363,7 @@ struct CAFFE2_API Type {
   virtual Tensor sum(const Tensor & self, IntList dim, bool keepdim, ScalarType dtype) const = 0;
   virtual Tensor sum(const Tensor & self, IntList dim, bool keepdim) const = 0;
   virtual Tensor sum(const Tensor & self, IntList dim, ScalarType dtype) const = 0;
+  virtual Tensor sum_to_size(const Tensor & self, IntList size) const = 0;
   virtual Tensor sqrt(const Tensor & self) const = 0;
   virtual Tensor & sqrt_(Tensor & self) const = 0;
   virtual Tensor std(const Tensor & self, bool unbiased) const = 0;
index 951867d..4500bfa 100644 (file)
@@ -626,6 +626,7 @@ _(aten, sub) \
 _(aten, sub_) \
 _(aten, rsub) \
 _(aten, sum) \
+_(aten, sum_to_size) \
 _(aten, svd) \
 _(aten, symeig) \
 _(aten, t) \
index c570fba..02a96ae 100644 (file)
@@ -171,7 +171,7 @@ static Tensor cat_sparse(TensorList tensors, int64_t dim) {
 }
 
 Tensor cat(TensorList tensors, int64_t dim) {
-  if (tensors.size() > 0 && 
+  if (tensors.size() > 0 &&
         tensors[0].is_sparse()) {
     return cat_sparse(tensors, dim);
   }
@@ -291,6 +291,13 @@ Tensor expand_as(const Tensor& self, const Tensor& other) {
   return self.expand(other.sizes());
 }
 
+Tensor sum_to_size(const Tensor& self, IntList size) {
+  AT_CHECK(is_expandable_to(size, self.sizes()),
+           "size {", size, "} is not expandable to size {", self.sizes(), "}.");
+
+  return sum_to(self, size);
+}
+
 Tensor as_strided(const Tensor& self, IntList size, IntList stride, int64_t storage_offset) {
   auto tid = self.type_id();
   AT_CHECK(
index 29c095d..337c450 100644 (file)
 
 - func: sum_out(Tensor result, Tensor self, IntList[1] dim, *, ScalarType dtype) -> Tensor
 
+- func: sum_to_size(Tensor self, IntList size) -> Tensor
+  variants: method
+  device_guard: false
+
 - func: sqrt(Tensor self) -> Tensor
   variants: function, method
 
index e91c4f5..42e33d8 100644 (file)
@@ -52,6 +52,7 @@
 #include "torch/csrc/jit/passes/shape_analysis.h"
 #include "torch/csrc/jit/passes/utils/subgraph_utils.h"
 #include "torch/csrc/jit/symbolic_variable.h"
+#include "torch/csrc/jit/symbolic_script.h"
 #include "torch/csrc/jit/tracer.h"
 #include "torch/csrc/utils/hash.h"
 #include "torch/csrc/autograd/generated/variable_factories.h"
index af8dedf..a9982b7 100644 (file)
@@ -25,28 +25,28 @@ graph(%0 : Float(*, *)
       %24 : int[]
       %25 : int[]
       %26 : Float(*, *)) {
-  %27 : int[] = aten::size(%26)
+  %27 : Float(*, *) = aten::mul(%0, %26)
   %28 : int[] = aten::size(%outgate)
-  %29 : int[] = aten::size(%cellgate)
-  %30 : int[] = aten::size(%ingate)
-  %31 : int[] = aten::size(%9)
-  %32 : int[] = aten::size(%forgetgate)
-  %33 : Float(*, *) = aten::mul(%0, %26)
-  %34 : Tensor = prim::SumToSize(%33, %28)
-  %35 : Float(*, *) = aten::mul(%0, %outgate)
-  %36 : Tensor = prim::SumToSize(%35, %27)
-  %37 : Tensor = prim::FusionGroup_0(%1, %36, %26)
-  %38 : Tensor = prim::SumToSize(%37, %24)
-  %39 : Tensor = prim::SumToSize(%37, %25)
-  %40 : Tensor = aten::mul(%39, %cellgate)
-  %41 : Tensor = prim::SumToSize(%40, %30)
-  %42 : Tensor = aten::mul(%39, %ingate)
-  %43 : Tensor = prim::SumToSize(%42, %29)
-  %44 : Tensor = aten::mul(%38, %9)
-  %45 : Tensor = prim::SumToSize(%44, %32)
-  %46 : Tensor = aten::mul(%38, %forgetgate)
-  %47 : Tensor = prim::SumToSize(%46, %31)
-  %48 : Tensor = prim::FusionGroup_1(%41, %ingate, %45, %forgetgate, %43, %cellgate, %34, %outgate)
+  %29 : Tensor = aten::sum_to_size(%27, %28)
+  %30 : Float(*, *) = aten::mul(%0, %outgate)
+  %31 : int[] = aten::size(%26)
+  %32 : Tensor = aten::sum_to_size(%30, %31)
+  %33 : Tensor = prim::FusionGroup_0(%1, %32, %26)
+  %34 : Tensor = prim::SumToSize(%33, %24)
+  %35 : Tensor = prim::SumToSize(%33, %25)
+  %36 : Tensor = aten::mul(%35, %cellgate)
+  %37 : int[] = aten::size(%ingate)
+  %38 : Tensor = aten::sum_to_size(%36, %37)
+  %39 : Tensor = aten::mul(%35, %ingate)
+  %40 : int[] = aten::size(%cellgate)
+  %41 : Tensor = aten::sum_to_size(%39, %40)
+  %42 : Tensor = aten::mul(%34, %9)
+  %43 : int[] = aten::size(%forgetgate)
+  %44 : Tensor = aten::sum_to_size(%42, %43)
+  %45 : Tensor = aten::mul(%34, %forgetgate)
+  %46 : int[] = aten::size(%9)
+  %47 : Tensor = aten::sum_to_size(%45, %46)
+  %48 : Tensor = prim::FusionGroup_1(%38, %ingate, %44, %forgetgate, %41, %cellgate, %29, %outgate)
   %49 : Tensor = prim::SumToSize(%48, %19)
   %50 : Tensor = prim::SumToSize(%48, %17)
   %51 : Tensor = prim::SumToSize(%48, %14)
index 10b0d5d..56ddf8f 100644 (file)
@@ -31,60 +31,60 @@ graph(%0 : Float(*, *)
       %30 : int[]
       %31 : Float(*, *)) {
   %32 : int = prim::Constant[value=1]()
-  %33 : int[] = aten::size(%31)
+  %33 : Float(*, *) = aten::mul(%0, %31)
   %34 : int[] = aten::size(%outgate)
-  %35 : int[] = aten::size(%cellgate)
-  %36 : int[] = aten::size(%ingate)
-  %37 : int[] = aten::size(%forgetgate)
-  %38 : int[] = aten::size(%11)
-  %39 : int[] = aten::size(%12)
-  %40 : int[] = aten::size(%Uz)
-  %41 : int[] = aten::size(%18)
-  %42 : int[] = aten::size(%Wx)
-  %43 : int[] = aten::size(%13)
-  %44 : Float(*, *) = aten::mul(%0, %31)
-  %45 : Tensor = prim::SumToSize(%44, %34)
-  %46 : Float(*, *) = aten::mul(%0, %outgate)
-  %47 : Tensor = prim::SumToSize(%46, %33)
-  %48 : Tensor = prim::FusionGroup_0(%1, %47, %31)
-  %49 : Tensor = prim::SumToSize(%48, %29)
-  %50 : Tensor = prim::SumToSize(%48, %30)
-  %51 : Tensor = aten::mul(%50, %cellgate)
-  %52 : Tensor = prim::SumToSize(%51, %36)
-  %53 : Tensor = aten::mul(%50, %ingate)
-  %54 : Tensor = prim::SumToSize(%53, %35)
-  %55 : Tensor = aten::mul(%49, %10)
-  %56 : Tensor = prim::SumToSize(%55, %37)
-  %57 : Tensor = prim::FusionGroup_1(%52, %ingate, %56, %forgetgate, %54, %cellgate, %45, %outgate)
-  %58 : Tensor = prim::SumToSize(%57, %24)
-  %59 : Tensor = prim::SumToSize(%57, %22)
-  %60 : Tensor = aten::mul(%59, %Uz)
-  %61 : Tensor = prim::SumToSize(%60, %38)
-  %62 : Tensor = aten::mul(%59, %11)
-  %63 : Tensor = prim::SumToSize(%62, %40)
-  %64 : Tensor = prim::SumToSize(%57, %19)
-  %65 : Tensor = prim::SumToSize(%57, %20)
-  %66 : Tensor = aten::mul(%65, %Wx)
-  %67 : Tensor = prim::SumToSize(%66, %39)
-  %68 : Tensor = aten::mul(%65, %12)
-  %69 : Tensor = prim::SumToSize(%68, %42)
-  %70 : Tensor = aten::mul(%64, %Uz)
-  %71 : Tensor = prim::SumToSize(%70, %41)
-  %72 : Tensor = aten::mul(%64, %18)
-  %73 : Tensor = prim::SumToSize(%72, %40)
-  %74 : Tensor = aten::add(%63, %73, %32)
-  %75 : Tensor = aten::mul(%71, %Wx)
-  %76 : Tensor = prim::SumToSize(%75, %43)
-  %77 : Tensor = aten::mul(%71, %13)
-  %78 : Tensor = prim::SumToSize(%77, %42)
-  %79 : Tensor = aten::add(%69, %78, %32)
+  %35 : Tensor = aten::sum_to_size(%33, %34)
+  %36 : Float(*, *) = aten::mul(%0, %outgate)
+  %37 : int[] = aten::size(%31)
+  %38 : Tensor = aten::sum_to_size(%36, %37)
+  %39 : Tensor = prim::FusionGroup_0(%1, %38, %31)
+  %40 : Tensor = prim::SumToSize(%39, %29)
+  %41 : Tensor = prim::SumToSize(%39, %30)
+  %42 : Tensor = aten::mul(%41, %cellgate)
+  %43 : int[] = aten::size(%ingate)
+  %44 : Tensor = aten::sum_to_size(%42, %43)
+  %45 : Tensor = aten::mul(%41, %ingate)
+  %46 : int[] = aten::size(%cellgate)
+  %47 : Tensor = aten::sum_to_size(%45, %46)
+  %48 : Tensor = aten::mul(%40, %10)
+  %49 : int[] = aten::size(%forgetgate)
+  %50 : Tensor = aten::sum_to_size(%48, %49)
+  %51 : Tensor = prim::FusionGroup_1(%44, %ingate, %50, %forgetgate, %47, %cellgate, %35, %outgate)
+  %52 : Tensor = prim::SumToSize(%51, %24)
+  %53 : Tensor = prim::SumToSize(%51, %22)
+  %54 : Tensor = aten::mul(%53, %Uz)
+  %55 : int[] = aten::size(%11)
+  %56 : Tensor = aten::sum_to_size(%54, %55)
+  %57 : Tensor = aten::mul(%53, %11)
+  %58 : int[] = aten::size(%Uz)
+  %59 : Tensor = aten::sum_to_size(%57, %58)
+  %60 : Tensor = prim::SumToSize(%51, %19)
+  %61 : Tensor = prim::SumToSize(%51, %20)
+  %62 : Tensor = aten::mul(%61, %Wx)
+  %63 : int[] = aten::size(%12)
+  %64 : Tensor = aten::sum_to_size(%62, %63)
+  %65 : Tensor = aten::mul(%61, %12)
+  %66 : int[] = aten::size(%Wx)
+  %67 : Tensor = aten::sum_to_size(%65, %66)
+  %68 : Tensor = aten::mul(%60, %Uz)
+  %69 : int[] = aten::size(%18)
+  %70 : Tensor = aten::sum_to_size(%68, %69)
+  %71 : Tensor = aten::mul(%60, %18)
+  %72 : Tensor = aten::sum_to_size(%71, %58)
+  %73 : Tensor = aten::add(%59, %72, %32)
+  %74 : Tensor = aten::mul(%70, %Wx)
+  %75 : int[] = aten::size(%13)
+  %76 : Tensor = aten::sum_to_size(%74, %75)
+  %77 : Tensor = aten::mul(%70, %13)
+  %78 : Tensor = aten::sum_to_size(%77, %66)
+  %79 : Tensor = aten::add(%67, %78, %32)
   %80 : Float(*, *) = aten::t(%14)
-  %81 : Float(*, *) = aten::mm(%80, %74)
+  %81 : Float(*, *) = aten::mm(%80, %73)
   %82 : Float(*, *) = aten::t(%81)
   %83 : Float(*, *) = aten::t(%15)
   %84 : Float(*, *) = aten::mm(%83, %79)
   %85 : Float(*, *) = aten::t(%84)
-  return (%58, %61, %67, %76, %82, %85);
+  return (%52, %56, %64, %76, %82, %85);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Tensor
index 1c69ed1..1872b15 100644 (file)
@@ -106,37 +106,39 @@ graph(%0 : Float(2, 3, 4)
       %3 : Float(2, 3, 4)
       %4 : Float(2, 3, 4)
       %5 : int[]) {
-  %9 : int = prim::Constant[value=1]()
-  %6 : int[] = aten::size(%4)
-  %7 : int[] = aten::size(%3)
-  %8 : int[] = aten::size(%2)
-  %10 : Tensor, %11 : Tensor = prim::GradOf[name="aten::add"](%0)
+  %7 : int = prim::Constant[value=1]()
+  %6 : int[] = aten::size(%3)
+  %8 : Tensor, %9 : Tensor = prim::GradOf[name="aten::add"](%0)
     block0() {
-      %12 : Tensor = prim::SumToSize(%0, %5)
-      %13 : Float(2, 3, 4) = aten::mul(%0, %9)
-      %14 : Tensor = prim::SumToSize(%13, %7)
-      -> (%12, %14)
+      %10 : Tensor = prim::SumToSize(%0, %5)
+      %11 : Float(2, 3, 4) = aten::mul(%0, %7)
+      %12 : Tensor = prim::SumToSize(%11, %6)
+      -> (%10, %12)
     }
-  %15 : Tensor, %16 : Tensor = prim::GradOf[name="aten::mul"](%10)
+  %13 : Tensor, %14 : Tensor = prim::GradOf[name="aten::mul"](%8)
     block0() {
-      %17 : Tensor = aten::mul(%10, %2)
-      %18 : Tensor = prim::SumToSize(%17, %6)
-      %19 : Tensor = aten::mul(%10, %4)
-      %20 : Tensor = prim::SumToSize(%19, %8)
-      -> (%18, %20)
+      %15 : Tensor = aten::mul(%8, %2)
+      %16 : int[] = aten::size(%4)
+      %17 : Tensor = aten::sum_to_size(%15, %16)
+      %18 : Tensor = aten::mul(%8, %4)
+      %19 : int[] = aten::size(%2)
+      %20 : Tensor = aten::sum_to_size(%18, %19)
+      -> (%17, %20)
     }
-  %21 : Tensor = prim::AutogradAdd(%1, %15)
+  %21 : Tensor = prim::AutogradAdd(%1, %13)
   %22 : Tensor, %23 : Tensor = prim::GradOf[name="aten::mul"](%21)
     block0() {
       %24 : Tensor = aten::mul(%21, %3)
-      %25 : Tensor = prim::SumToSize(%24, %8)
-      %26 : Tensor = aten::mul(%21, %2)
-      %27 : Tensor = prim::SumToSize(%26, %7)
-      -> (%25, %27)
+      %25 : int[] = aten::size(%2)
+      %26 : Tensor = aten::sum_to_size(%24, %25)
+      %27 : Tensor = aten::mul(%21, %2)
+      %28 : int[] = aten::size(%3)
+      %29 : Tensor = aten::sum_to_size(%27, %28)
+      -> (%26, %29)
     }
-  %28 : Tensor = prim::AutogradAdd(%16, %22)
-  %29 : Tensor = prim::AutogradAdd(%11, %23)
-  return (%28, %29);
+  %30 : Tensor = prim::AutogradAdd(%14, %22)
+  %31 : Tensor = prim::AutogradAdd(%9, %23)
+  return (%30, %31);
 }
 
 testDifferentiateWithRequiresGrad
@@ -145,11 +147,9 @@ graph(%0 : Float(*)
   %2 : Float(*) = aten::mul(%1, %1)
   %3 : int = prim::Constant[value=1]()
   %4 : Float(*) = aten::add(%2, %1, %3)
-  %26 : int[] = aten::size(%4)
   %6 : Float(*) = aten::add(%4, %0, %3)
   %7 : Float(*) = aten::mul(%6, %0)
   %11 : int[] = aten::size(%7)
-  %14 : int[] = aten::size(%1)
   %9 : Float(*) = aten::add(%7, %1, %3)
   return (%4, %9, %6, %11);
 }
@@ -158,30 +158,31 @@ graph(%0 : Float(*)
       %2 : Float(*)
       %3 : Float(*)
       %4 : int[]) {
-  %7 : int = prim::Constant[value=1]()
-  %5 : int[] = aten::size(%3)
-  %6 : int[] = aten::size(%2)
-  %8 : Tensor = prim::GradOf[name="aten::add"](%0)
+  %6 : int = prim::Constant[value=1]()
+  %5 : int[] = aten::size(%2)
+  %7 : Tensor = prim::GradOf[name="aten::add"](%0)
     block0() {
-      %9 : Tensor = prim::SumToSize(%0, %4)
-      -> (%9)
+      %8 : Tensor = prim::SumToSize(%0, %4)
+      -> (%8)
     }
-  %10 : Tensor, %11 : Tensor = prim::GradOf[name="aten::mul"](%8)
+  %9 : Tensor, %10 : Tensor = prim::GradOf[name="aten::mul"](%7)
     block0() {
-      %12 : Tensor = aten::mul(%8, %2)
-      %13 : Tensor = prim::SumToSize(%12, %5)
-      %14 : Tensor = aten::mul(%8, %3)
-      %15 : Tensor = prim::SumToSize(%14, %6)
-      -> (%13, %15)
+      %11 : Tensor = aten::mul(%7, %2)
+      %12 : int[] = aten::size(%3)
+      %13 : Tensor = aten::sum_to_size(%11, %12)
+      %14 : Tensor = aten::mul(%7, %3)
+      %15 : int[] = aten::size(%2)
+      %16 : Tensor = aten::sum_to_size(%14, %15)
+      -> (%13, %16)
     }
-  %16 : Tensor = prim::AutogradAdd(%1, %10)
-  %17 : Tensor = prim::GradOf[name="aten::add"](%16)
+  %17 : Tensor = prim::AutogradAdd(%1, %9)
+  %18 : Tensor = prim::GradOf[name="aten::add"](%17)
     block0() {
-      %18 : Tensor = aten::mul(%16, %7)
-      %19 : Tensor = prim::SumToSize(%18, %6)
-      -> (%19)
+      %19 : Tensor = aten::mul(%17, %6)
+      %20 : Tensor = prim::SumToSize(%19, %5)
+      -> (%20)
     }
-  %20 : Tensor = prim::AutogradAdd(%11, %17)
-  return (%20);
+  %21 : Tensor = prim::AutogradAdd(%10, %18)
+  return (%21);
 }
 
index d4bb1bb..f9b71fe 100644 (file)
@@ -56,6 +56,7 @@ torch_sources_no_python_default = [
     "torch/csrc/jit/import.cpp",
     "torch/csrc/jit/interpreter.cpp",
     "torch/csrc/jit/ir.cpp",
+    "torch/csrc/jit/symbolic_script.cpp",
     "torch/csrc/jit/operator.cpp",
     "torch/csrc/jit/passes/alias_analysis.cpp",
     "torch/csrc/jit/passes/batch_mm.cpp",
index d8ad5c0..5b3e832 100644 (file)
@@ -158,7 +158,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/node_hashing.cpp
   ${TORCH_SRC_DIR}/csrc/jit/ir.cpp
   ${TORCH_SRC_DIR}/csrc/jit/operator.cpp
-  ${TORCH_SRC_DIR}/csrc/jit/operator.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/symbolic_script.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/alias_analysis.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/batch_mm.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/canonicalize.cpp
index c9170c0..a7a344e 100644 (file)
@@ -2832,6 +2832,18 @@ Args:
         as :attr:`other`.
 """)
 
+add_docstr_all('sum_to_size',
+               r"""
+sum_to_size(*size) -> Tensor
+
+Sum ``this`` tensor to :attr:`size`.
+:attr:`size` must be broadcastable to ``this`` tensor size.
+Args:
+    other (:class:`torch.Tensor`): The result tensor has the same size
+        as :attr:`other`.
+""")
+
+
 add_docstr_all('zero_',
                r"""
 zero_() -> Tensor
index 12b0a7b..4f23150 100644 (file)
@@ -1,11 +1,14 @@
 #include <torch/csrc/jit/autodiff.h>
 
+#include "torch/csrc/jit/passes/lower_tuples.h"
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
+#include "torch/csrc/jit/symbolic_script.h"
 #include <torch/csrc/jit/symbolic_variable.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/utils/functional.h>
+#include "torch/csrc/jit/script/compiler.h"
 
 #include <torch/csrc/jit/assertions.h>
 
@@ -110,6 +113,11 @@ bool isDifferentiable(Node * n) {
   if (differentiable_ops.find(n))
     return true;
 
+  auto schema = n->maybeSchema();
+  if (schema && hasGradientInfoForSchema(*schema)) {
+    return true;
+  }
+
   if (n->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
     return n->get<std::vector<int64_t>>(attr::size) && n->is_constant(attr::implicit) &&
       n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
@@ -142,6 +150,72 @@ bool isDifferentiable(Graph & g) {
                      static_cast<bool(*)(Node*)>(isDifferentiable));
 }
 
+// TODO: Remove this after #15355.
+namespace {
+  std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
+    auto outputs = script::inlineCallTo(g, callee, inputs);
+    if (callee.outputs().size() == 1 && callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
+      auto tc = script::createTupleUnpack(outputs.at(0));
+      outputs = std::vector<Value*>(tc.begin(), tc.end());
+    }
+    return outputs;
+  }
+} //anonymous namespace
+
+
+// NB: Write gradient using torchscript
+// For example, node aten::mul() should be defined as follows
+// def forward(x, y):
+//     return x*y, (x, y)
+// def backward(ctx, grad_output):
+//     x, y = ctx
+//     return (y * grad_output).sum_to_size(x), (x * grad_output).sum_to_size(y)
+//
+// Here ctx is a tuple that carries all input/intermediate results needed in
+// backward from forward pass.
+// This python code is compiled into a GradientPair which includes a forward graph
+// and a backward graph. Forward graph will be used to replace the node in grad_desc.f,
+// and backward graph will be used to construct GradOf(node) in reverse_block.
+// Grad_values(a.k.a gradOutputs) propagated through node->owningGraph() in
+// **reversed** order, thus GradientPair.forward ahould be inserted **after**
+// the node being replaced, so that we don't traverse the graph infinite times.
+// The output of compiled forward graph is [real_outputs, ctx]
+// The input of compiled backward graph is [ctx, grad_values]
+// We run LowerSimpleTuples afterwards to elmininate all tuples generated in this process.
+// The original node and TupleConstruct nodes in forward graph will be cleaned up
+// later using EliminateDeadCode(block).
+// TupleUnPack node in backward graph will be removed in eliminateDeadcode(ReverseDetails)
+// defined in this file.
+static c10::optional<std::vector<Value*>> build_script_grad(
+        Node* node,
+        const ArrayRef<Value*>& grads) {
+  auto graph = node->owningGraph();
+
+  auto compiled_graphs = gradientInfoForSchema(node->schema());
+  if (!compiled_graphs) {
+    return c10::nullopt;
+  }
+  // Use forward graph to replace node in grad_desc.f
+  value_list new_outputs;
+  {
+    WithInsertPoint guard(node->next());
+    auto fw_graph = compiled_graphs->forward;
+    new_outputs = inlineUnpackedCallTo(*graph, *fw_graph, node->inputs());
+    for (size_t i = 0; i < node->outputs().size(); ++i) {
+      new_outputs.at(i)->setType(node->outputs()[i]->type());
+      new_outputs.at(i)->replaceAllUsesWith(node->outputs()[i]);
+    }
+  }
+
+  // Use backward graph to construct reverse_block
+  auto bw_graph = compiled_graphs->backward;
+  auto grad_vec = grads.vec();
+  auto it = grad_vec.begin();
+  grad_vec.insert(it, new_outputs.back());
+  ArrayRef<Value*> grad(grad_vec);
+  auto grad_inputs = inlineUnpackedCallTo(*graph, *bw_graph, grad);
+  return grad_inputs;
+};
 
 static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_values) {
   static const OperatorSet comparison_ops = {
@@ -543,6 +617,12 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     throw std::runtime_error(std::string("differentiation of ") + node->kind().toDisplayString() + " "
                              "is not supported, or it is missing necessary type information");
   }
+  // If AD is defined using torchscript, use it instead of symbolic
+  auto script_grads = build_script_grad(node, grad_values);
+  if (script_grads)
+    return *script_grads;
+  // Definition not found in torchscript, look up in the build_sym_grad
+  // TODO: migrate all to using torchscript
   auto sym_grads = build_sym_grad(fmap<SymbolicVariable>(grad_values));
   return fmap(sym_grads, [](const SymbolicVariable &v) { return v.value(); });
 }
@@ -647,6 +727,8 @@ static ReverseDetails addReverseInline(Gradient& grad_desc) {
     }
 
     value_list grad_inputs = linearGradientForNode(node, fmap(node->outputs(), get_grad));
+    LowerSimpleTuples(reverse_block);
+
     JIT_ASSERT(grad_inputs.size() == node->inputs().size());
     for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
       if (!inputs[i]->requires_grad()) continue;
@@ -670,6 +752,7 @@ static ReverseDetails addReverseInline(Gradient& grad_desc) {
     reverse_block->registerOutput(get_grad(input));
     grad_desc.df_output_vjps.push_back(i);
   }
+
   return ReverseDetails(std::move(grad_map), reverse_block);
 }
 
@@ -918,6 +1001,9 @@ Gradient differentiate(std::shared_ptr<Graph>& graph) {
   // Fills in df_input_vjps and df_output_vjps
   auto rev_info = addReverseInline(grad_desc);
   Optimize(grad_desc, rev_info);
+  // Clean up old nodes which has been replaced by forward graphs in torchscript
+  EliminateDeadCode(grad_desc.f->block());
+
   // Fills in f, df, f_real_outputs, df_input_captures,
   // modifies df_input_vjps (new vjps are added for temporaries)
   lambdaLiftReverse(grad_desc, rev_info);
index dbe7e64..5168702 100644 (file)
@@ -352,38 +352,6 @@ struct SchemaParser {
 } // namespace script
 
 namespace {
-
-std::string canonicalSchemaString(const FunctionSchema& schema) {
-  std::ostringstream out;
-
-  out << schema.name();
-  out << "(";
-
-  bool seen_kwarg_only = false;
-  for(size_t i = 0; i < schema.arguments().size(); ++i) {
-    if (i > 0) out << ", ";
-    if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
-      out << "*, ";
-      seen_kwarg_only = true;
-    }
-    const auto & arg = schema.arguments()[i];
-    out << arg.type()->str() << " " << arg.name();
-  }
-
-  out << ") -> ";
-  if (schema.returns().size() == 1) {
-    out << schema.returns().at(0).type()->str();
-  } else if (schema.returns().size() > 1) {
-    out << "(";
-    for (size_t i = 0; i < schema.returns().size(); ++i) {
-      if (i > 0) out << ", ";
-      out << schema.returns()[i].type()->str();
-    }
-    out << ")";
-  }
-  return out.str();
-}
-
 using OperatorMap = std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
 struct OperatorRegistry  {
 private:
@@ -484,6 +452,37 @@ FunctionSchema parseSchema(const std::string& schema) {
   return script::SchemaParser(schema).parseDeclarations().at(0);
 }
 
+std::string canonicalSchemaString(const FunctionSchema& schema) {
+  std::ostringstream out;
+
+  out << schema.name();
+  out << "(";
+
+  bool seen_kwarg_only = false;
+  for(size_t i = 0; i < schema.arguments().size(); ++i) {
+    if (i > 0) out << ", ";
+    if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
+      out << "*, ";
+      seen_kwarg_only = true;
+    }
+    const auto & arg = schema.arguments()[i];
+    out << arg.type()->str() << " " << arg.name();
+  }
+
+  out << ") -> ";
+  if (schema.returns().size() == 1) {
+    out << schema.returns().at(0).type()->str();
+  } else if (schema.returns().size() > 1) {
+    out << "(";
+    for (size_t i = 0; i < schema.returns().size(); ++i) {
+      if (i > 0) out << ", ";
+      out << schema.returns()[i].type()->str();
+    }
+    out << ")";
+  }
+  return out.str();
+}
+
 bool Operator::matches(const Node* node) const {
   // wrong name
   if (node->kind().toQualString() != schema().name()) {
index 310271f..5d66f5a 100644 (file)
@@ -80,6 +80,8 @@ private:
  OperationCreator op_creator_;
 };
 
+TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);
+
 TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name);
 std::shared_ptr<Operator> findOperatorFor(const Node* node);
 const Operator& getOperatorFor(const Node* node);
index b4dc733..0167d03 100644 (file)
@@ -205,7 +205,11 @@ class DeadCodeEliminator {
           sweep(block, true);
         }
       }
-      if (!marked_.count(node)) {
+      // NB: Checking hasUses() is required. AD graphs are not perfectly
+      // valid, as a node in grad_desc.f might be used in reverse_block.
+      // Reverse_block is inlined in grad_desc.f before it's separated
+      // to grad_desc.df.
+      if (!(marked_.count(node) || node->hasUses())) {
         it.destroyCurrent();
       }
     }
index d7a7de3..cde02d6 100644 (file)
@@ -161,7 +161,7 @@ void LowerAllTuples(std::shared_ptr<Graph>& graph) {
   EnsureNoTuples(graph->block());
 }
 
-static void LowerSimpleTuples(Block* block) {
+void LowerSimpleTuples(Block* block) {
   for(auto n : block->nodes()) {
     removeTupleNodes(n, /*must_remove_tuples*/false);
     for(auto b : n->blocks()) {
index b41e383..4e7c990 100644 (file)
@@ -13,5 +13,6 @@ TORCH_API void LowerSimpleTuples(std::shared_ptr<Graph>& graph);
 // but will not work on graphs whose inputs contain tuples.
 TORCH_API void LowerAllTuples(std::shared_ptr<Graph>& graph);
 
+TORCH_API void LowerSimpleTuples(Block* block);
 
 }}
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
new file mode 100644 (file)
index 0000000..9ae53d4
--- /dev/null
@@ -0,0 +1,59 @@
+#include <torch/csrc/jit/symbolic_script.h>
+
+namespace torch { namespace jit {
+  namespace {
+    std::mutex lock;
+    const std::unordered_map<std::string, std::string> symbolic_scripts({
+      {"aten::mul(Tensor self, Tensor other) -> Tensor",
+R"(
+def forward(self, other):
+    return self * other, (self, other)
+def backward(ctx, grad_output):
+    # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor]
+    self, other = ctx
+    return (grad_output * other).sum_to_size(self.size()), (grad_output * self).sum_to_size(other.size())
+)"},
+      });
+
+    // This map is a workaround to cache compiled gradient_pairs. Ideally this graph
+    // should be compiled only once and saved in Operator structure.
+    // This should be done along with merging into native_functions.yaml.
+    std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
+  } // anonymous namespace
+
+  c10::optional<GradientPair> gradientInfoForSchema(const FunctionSchema& schema) {
+    std::lock_guard<std::mutex> guard(lock);
+    auto cache_it = cached_gradient_pairs.find(&schema);
+    if (cache_it != cached_gradient_pairs.end()) {
+      return cache_it->second;
+    } else {
+      auto schema_str = canonicalSchemaString(schema);
+      auto sym_script_it = symbolic_scripts.find(schema_str);
+
+      if (sym_script_it != symbolic_scripts.end()) {
+        // Compile the python code to a script module
+        auto cu = std::make_shared<script::Module>();
+        script::defineMethodsInModule(cu, symbolic_scripts.at(schema_str), script::nativeResolver, nullptr);
+        auto fw_graph = cu->find_method("forward")->graph();
+        auto bw_graph = cu->find_method("backward")->graph();
+
+        GradientPair compiled_graphs{fw_graph, bw_graph};
+        cached_gradient_pairs.emplace_hint(cache_it, &schema, compiled_graphs);
+        return compiled_graphs;
+      }
+    }
+    return c10::nullopt;
+  }
+
+  bool hasGradientInfoForSchema(const FunctionSchema& schema) {
+    std::lock_guard<std::mutex> guard(lock);
+    auto cache_it = cached_gradient_pairs.find(&schema);
+    if (cache_it == cached_gradient_pairs.end()) {
+      auto schema_str = canonicalSchemaString(schema);
+      auto sym_script_it = symbolic_scripts.find(schema_str);
+      return !(sym_script_it == symbolic_scripts.end());
+    }
+    return true;
+  }
+}}
+
diff --git a/torch/csrc/jit/symbolic_script.h b/torch/csrc/jit/symbolic_script.h
new file mode 100644 (file)
index 0000000..45496ab
--- /dev/null
@@ -0,0 +1,18 @@
+#pragma once
+// This file is temporary until native_functions.yaml and derivatives.yaml are merged.
+// Ideally this should all go into native_functions.yaml
+
+#include <c10/util/Optional.h>
+#include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/module.h>
+#include <torch/csrc/jit/operator.h>
+
+namespace torch { namespace jit {
+  struct GradientPair {
+    std::shared_ptr<Graph> forward;
+    std::shared_ptr<Graph> backward;
+  };
+
+  TORCH_API c10::optional<GradientPair> gradientInfoForSchema(const FunctionSchema& schema);
+  TORCH_API bool hasGradientInfoForSchema(const FunctionSchema& schema);
+}}