pow scalar exponent / base autodiff, fusion (#19324)
author Thomas Viehmann <tv.code@beamnet.de>
Fri, 19 Apr 2019 00:52:33 +0000 (17:52 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 19 Apr 2019 00:58:35 +0000 (17:58 -0700)
Summary:
Fixes: #19253

Fixing `pow(Tensor, float)` is straightforward.
The breakage for `pow(float, Tensor)` is a bit more subtle to trigger, and fixing it needs `torch.log` (`math.log` didn't work) from the newly merged #19115. (Thanks, ngimel, for pointing out that this has landed.)
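
For reference, a minimal eager-mode sketch of the two gradient formulas involved; the function and argument names below are illustrative only, not the ones used in symbolic_script.cpp:

    import torch

    def pow_tensor_scalar_grad(grad_output, base, exponent):
        # d(base**exponent)/d(base) = exponent * base**(exponent - 1),
        # taken as zero when exponent == 0 (cf. pow_0 in the diff below)
        if float(exponent) == 0.0:
            return torch.zeros_like(base)
        return grad_output * exponent * torch.pow(base, float(exponent) - 1)

    def pow_scalar_tensor_grad(grad_output, base, exponent):
        # d(base**exponent)/d(exponent) = base**exponent * log(base)
        # (cf. pow_2 in the diff below; the TorchScript formula uses torch.log
        # from #19115 because math.log was not usable there)
        return grad_output * torch.pow(base, exponent) * torch.log(torch.tensor(float(base)))
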
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19324

Differential Revision: D15003531

Pulled By: ailzhang

fbshipit-source-id: 8b22138fa27a43806b82886fb3a7b557bbb5a865

test/test_jit.py
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/symbolic_script.cpp

test/test_jit.py
index da4a20f..6e1e6c3 100644 (file)
@@ -3292,6 +3292,27 @@ a")
         self.checkScript(func, (a, b), optimize=True)
         self.checkScript(func2, (a, b, c, d), optimize=True)
 
+    @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
+    def test_pow_scalar_backward_cuda(self):
+        # see that scalar exponent works with cuda base (#19253)
+
+        for dtype in [torch.float, torch.double]:
+            @torch.jit.script
+            def func(a, b):
+                # type: (Tensor, float) -> Tensor
+                return (a * 2) ** b
+
+            a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype)
+            func(a, 1).backward()
+
+            @torch.jit.script
+            def func(a, b):
+                # type: (float, Tensor) -> Tensor
+                return a ** (b * 2 + 1)
+
+            a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype)
+            func(2, a).backward()
+
     def test_triple(self):
         def func(x):
             return 3. * x
torch/csrc/jit/passes/graph_fuser.cpp
index 20a3373..f0d9115 100644 (file)
@@ -67,6 +67,7 @@ bool isSimpleMap(Node* node) {
       "aten::neg(Tensor self) -> Tensor",
       "aten::pow(Tensor self, Tensor exponent) -> Tensor",
       "aten::pow(Tensor self, Scalar exponent) -> Tensor",
+      "aten::pow(Scalar self, Tensor exponent) -> Tensor",
       "aten::rand_like(Tensor self) -> Tensor",
       "aten::reciprocal(Tensor self) -> Tensor",
       "aten::relu(Tensor self) -> Tensor",
torch/csrc/jit/symbolic_script.cpp
index 3433a9d..87843f6 100644 (file)
@@ -385,7 +385,7 @@ const std::vector<std::string> functions = {
                     tensor1,
                     tensor2,
                     *,
-                    value: float = 1.0):
+                    value: number = 1.0):
             def backward(grad_output):
                 grad = grad_output * value
                 grad_tensor1 = (grad * tensor2)._grad_sum_to_size(tensor1.size())
@@ -448,10 +448,10 @@ const std::vector<std::string> functions = {
 
         def lerp_0(self,
                    end,
-                   weight: float):
+                   weight: number):
             def backward(grad_output):
-                grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self.size())
-                grad_end = (grad_output * weight)._grad_sum_to_size(end.size())
+                grad_self = (grad_output * (1 - float(weight)))._grad_sum_to_size(self.size())
+                grad_end = (grad_output * float(weight))._grad_sum_to_size(end.size())
                 return grad_self, grad_end, None
             return torch.lerp(self, end, weight), backward
 
@@ -578,9 +578,12 @@ const std::vector<std::string> functions = {
             return torch.ones_like(self), backward
 
         def pow_0(self,
-                  exponent: float):
+                  exponent: number):
             def backward(grad_output):
-                grad_self = torch.where(torch.tensor(exponent == 0.0), torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))
+                if float(exponent) == 0.0:
+                    grad_self = torch.zeros_like(self)
+                else:
+                    grad_self = grad_output * exponent * torch.pow(self, float(exponent) - 1)
                 return grad_self, None
 
             return torch.pow(self, exponent), backward
@@ -594,16 +597,16 @@ const std::vector<std::string> functions = {
 
             return torch.pow(self, exponent), backward
 
-        def pow_2(self: float,
+        def pow_2(self: number,
                   exponent):
             def backward(grad_output):
-                grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(torch.tensor(self))
+                grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(float(self))
                 return None, grad_exponent
 
             return torch.pow(self, exponent), backward
 
         def rsub_0(self, other,
-                   alpha: float = 1.0):
+                   alpha: number = 1.0):
             self_size = self.size()
             other_size = other.size()
             def backward(grad_output):
@@ -614,8 +617,8 @@ const std::vector<std::string> functions = {
             return torch.rsub(self, other, alpha), backward
 
         def rsub_1(self,
-                   other: float,
-                   alpha: float = 1.0):
+                   other: number,
+                   alpha: number = 1.0):
             self_size = self.size()
             def backward(grad_output):
                 grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
@@ -1373,14 +1376,6 @@ c10::optional<GradientPair> gradientInfoForSchema(
     return cache_it->second;
   } else {
     auto schema_str = canonicalSchemaString(schema);
-    // Specialize Scalar to float for the arg type of the node schema
-    // this is used to:
-    // 1. define scalar type as float in TorchScript autodiff formula
-    // 2. to make sure the input of any graph node does not contain scalar type
-    //    in its argument, all scalar arg should already be passed with float
-    //    value since scalar/int aren't differentiable either way.
-    //
-    c10::ReplaceAll(schema_str, "Scalar", "float");
     // For debugging AD change:
     // std::cout << "Looking for " << schema_str << std::endl;