Add addcmul, lerp to fuser, enable scalar->float specialization in symbolic script...
authorWanchao Liang <wanchaol@users.noreply.github.com>
Mon, 25 Mar 2019 18:02:17 +0000 (11:02 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 25 Mar 2019 18:05:45 +0000 (11:05 -0700)
Summary:
This PR did two things:

1. Enable scalar->float specialization in symbolic script, so AD formula that contains scalar in the schema, should write `float` instead.
2. add addcmul, lerp to AD and fuser.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18081

Differential Revision: D14490493

Pulled By: wanchaol

fbshipit-source-id: b3b86d960d5f051b30733bc908b19786111cdaa4

test/test_jit.py
torch/csrc/jit/autodiff.cpp
torch/csrc/jit/fuser/codegen.cpp
torch/csrc/jit/graph_executor.cpp
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/symbolic_script.cpp
torch/csrc/jit/symbolic_script.h

index 8606a6e..36d51b9 100644 (file)
@@ -12196,6 +12196,45 @@ class TestFuser(JitTestCase):
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @skipIfRocm
+    def test_addcmul_cuda(self):
+        t = torch.randn(1, 4, dtype=torch.float, device='cuda')
+        t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
+        t2 = torch.randn(1, 4, dtype=torch.float, device='cuda')
+
+        def foo(t, t1, t2):
+            return t.addcmul(t + 1, t2, value=0.1)
+
+        ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
+        graph = ge.graph_for(t, t1, t2)
+        self.assertAllFused(graph)
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_lerp_cuda(self):
+        start = torch.randn(4, 1, dtype=torch.float, device='cuda')
+        end = torch.randn(1, 4, dtype=torch.float, device='cuda')
+        weight = torch.tensor(0.5, dtype=torch.float, device='cuda')
+
+        # scalar weight overload
+        def foo_weight_scalar(start, end):
+            return torch.lerp(start + 1, end, 0.5)
+
+        # tensor weight overload
+        def foo_weight_tensor(start, end):
+            return torch.lerp(start + 1, end, weight)
+
+        ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end))
+        graph = ge_weight_scalar.graph_for(start, end)
+        self.assertAllFused(graph)
+
+        ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end))
+        graph = ge_weight_tensor.graph_for(start, end)
+        self.assertAllFused(graph)
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
     def test_concat_cuda(self):
         hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
         cx = torch.randn(3, 20, dtype=torch.float, device='cuda')
index cfe244f..61215a8 100644 (file)
@@ -129,7 +129,8 @@ bool isDifferentiable(Node* n) {
     return true;
 
   if (n->matches(
-          "aten::dropout(Tensor input, float p, bool train) -> Tensor", attr::train)) {
+          "aten::dropout(Tensor input, float p, bool train) -> Tensor",
+          attr::train)) {
     return n->get<bool>(attr::train).value();
   }
 
@@ -164,6 +165,14 @@ bool isDifferentiable(Node* n) {
         static_cast<bool (*)(Node*)>(isDifferentiable));
   }
 
+  // formulas are only defined with floating point scalars,
+  // so we fallback to autograd for other cases.
+  for (const Value* input : n->inputs()) {
+    if (input->type() == NumberType::get()) {
+      return false;
+    }
+  }
+
   return false;
 }
 
@@ -204,7 +213,6 @@ static c10::optional<std::vector<Value*>> build_script_grad(
     Node* node,
     const ArrayRef<Value*>& grads) {
   auto graph = node->owningGraph();
-
   auto compiled_graphs = gradientInfoForSchema(node->schema());
   if (!compiled_graphs) {
     return c10::nullopt;
@@ -228,7 +236,7 @@ static c10::optional<std::vector<Value*>> build_script_grad(
   auto bw_graph = compiled_graphs->backward;
   auto grad_vec = grads.vec();
   if (needTrimGrad(node)) {
-    grad_vec.erase(grad_vec.begin()+1, grad_vec.end());
+    grad_vec.erase(grad_vec.begin() + 1, grad_vec.end());
   }
   auto it = grad_vec.begin();
   grad_vec.insert(it, new_outputs.back());
index fd5beef..80535b3 100644 (file)
@@ -208,6 +208,7 @@ static std::string encodeRHS(const Node* n) {
       {aten::__or__, "${0} || ${1}"},
       {aten::__rshift__, "${0} >> ${1}"},
       {aten::__xor__, "${0} ^ ${1}"},
+      {aten::addcmul, "${cast_0} + ${cast_3} * ${cast_1} * ${cast_2}"},
       {aten::div, "${cast_0} / ${cast_1}"},
       {aten::eq, "${0} == ${1}"},
       {aten::fmod, "fmodf(${cast_0}, ${cast_1})"},
@@ -215,6 +216,7 @@ static std::string encodeRHS(const Node* n) {
       {aten::gt, "${0} > ${1}"},
       {aten::le, "(${0} <= ${1})"},
       {aten::lt, "${0} < ${1}"},
+      {aten::lerp, "${cast_0} + ${cast_2} * (${cast_1} - ${cast_0})"},
       {aten::type_as, "(${cast_0})"},
       {aten::mul, "${cast_0} * ${cast_1}"},
       {aten::ne, "${0} != ${1}"},
index 15def35..f7cd332 100644 (file)
@@ -8,8 +8,6 @@
 #include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/interpreter.h>
 #include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/resource_guard.h>
-#include <ATen/core/ivalue.h>
 #include <torch/csrc/jit/passes/batch_mm.h>
 #include <torch/csrc/jit/passes/canonicalize_ops.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
@@ -27,6 +25,7 @@
 #include <torch/csrc/jit/passes/requires_grad_analysis.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/passes/specialize_autogradzero.h>
+#include <torch/csrc/jit/resource_guard.h>
 #include <torch/csrc/jit/symbolic_variable.h>
 #include <torch/csrc/jit/tracer.h>
 
index 5923582..ad1a835 100644 (file)
@@ -59,6 +59,8 @@ bool isSimpleMap(Node* node) {
       "aten::log10(Tensor self) -> Tensor",
       "aten::log1p(Tensor self) -> Tensor",
       "aten::log2(Tensor self) -> Tensor",
+      "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
+      "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
       "aten::max(Tensor self, Tensor other) -> Tensor",
       "aten::min(Tensor self, Tensor other) -> Tensor",
       "aten::mul(Tensor self, Tensor other) -> Tensor",
@@ -98,6 +100,7 @@ bool isSimpleMap(Node* node) {
       "aten::lt(Tensor self, Tensor other) -> Tensor",
       "aten::lt(Tensor self, Scalar other) -> Tensor",
 
+      "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor",
       "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
 
       "aten::type_as(Tensor self, Tensor other) -> Tensor",
index 98e5a74..cce3552 100644 (file)
@@ -383,6 +383,18 @@ const std::vector<std::string> functions = {
             return torch.matmul(self, other), backward
     )",
     R"(
+        def addcmul(self,
+                    tensor1,
+                    tensor2,
+                    *,
+                    value: float = 1.0):
+            def backward(grad_output):
+                grad = grad_output * value
+                grad_tensor1 = (grad * tensor2)._grad_sum_to_size(tensor1.size())
+                grad_tensor2 = (grad * tensor1)._grad_sum_to_size(tensor2.size())
+                return grad_output._grad_sum_to_size(self.size()), grad_tensor1, grad_tensor2, None
+            return torch.addcmul(self, tensor1, tensor2, value=value), backward
+
         def _dim_arange(like,
                         dim: int):
             def backward(grad_output):
@@ -439,6 +451,24 @@ const std::vector<std::string> functions = {
 
             return torch.full_like(self, fill_value), backward
 
+        def lerp_0(self,
+                   end,
+                   weight: float):
+            def backward(grad_output):
+                grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self.size())
+                grad_end = (grad_output * weight)._grad_sum_to_size(end.size())
+                return grad_self, grad_end, None
+            return torch.lerp(self, end, weight), backward
+
+        def lerp_1(self,
+                   end,
+                   weight):
+            def backward(grad_output):
+                grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self.size())
+                grad_end = (grad_output * weight)._grad_sum_to_size(end.size())
+                return grad_self, grad_end, None
+            return torch.lerp(self, end, weight), backward
+
         def mul(self, other):
             def backward(grad_output):
                 # self & other are used in backward. No need to pass in their size
@@ -889,6 +919,7 @@ std::string overloadedSchemaString(const FunctionSchema& schema) {
         schema_name.length(),
         schema_name.substr(0, pos));
   }
+
   return schema_string;
 }
 
@@ -969,6 +1000,15 @@ c10::optional<GradientPair> gradientInfoForSchema(
     return cache_it->second;
   } else {
     auto schema_str = canonicalSchemaString(schema);
+    // Specialize Scalar to float for the arg type of the node schema
+    // this is used to:
+    // 1. define scalar type as float in TorchScript autodiff formula
+    // 2. to make sure the input of any graph node does not contain scalar type
+    //    in its argument, all scalar arg should already be passed with float
+    //    value since scalar/int aren't differentiable either way.
+    //
+    c10::ReplaceAll(schema_str, "Scalar", "float");
+
     auto sym_script_it = schema_to_graphs.find(schema_str);
 
     if (sym_script_it != schema_to_graphs.end()) {
index bc8284c..6c17eb5 100644 (file)
@@ -3,6 +3,7 @@
 // merged. Ideally this should all go into native_functions.yaml
 
 #include <c10/util/Optional.h>
+#include <c10/util/StringUtil.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/script/module.h>