From b9291f55bb5e16e5acb9c8e9e2cb69bac56e8ab1 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Thu, 18 Apr 2019 17:52:33 -0700
Subject: [PATCH] pow scalar exponent / base autodiff, fusion (#19324)

Summary:
Fixes: #19253

Fixing pow(Tensor, float) is straightforward.
The breakage for pow(float, Tensor) is a bit more subtle to trigger, and
fixing needs `torch.log` (`math.log` didn't work) from the newly merged
#19115. (Thanks ngimel for pointing out this has landed.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/19324

Differential Revision: D15003531

Pulled By: ailzhang

fbshipit-source-id: 8b22138fa27a43806b82886fb3a7b557bbb5a865
---
 test/test_jit.py                      | 21 +++++++++++++++++++++
 torch/csrc/jit/passes/graph_fuser.cpp |  1 +
 torch/csrc/jit/symbolic_script.cpp    | 33 ++++++++++++++-------------------
 3 files changed, 36 insertions(+), 19 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index da4a20f..6e1e6c3 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -3292,6 +3292,27 @@ a")
         self.checkScript(func, (a, b), optimize=True)
         self.checkScript(func2, (a, b, c, d), optimize=True)
 
+    @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
+    def test_pow_scalar_backward_cuda(self):
+        # see that scalar exponent works with cuda base (#19253)
+
+        for dtype in [torch.float, torch.double]:
+            @torch.jit.script
+            def func(a, b):
+                # type: (Tensor, float) -> Tensor
+                return (a * 2) ** b
+
+            a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype)
+            func(a, 1).backward()
+
+            @torch.jit.script
+            def func(a, b):
+                # type: (float, Tensor) -> Tensor
+                return a ** (b * 2 + 1)
+
+            a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype)
+            func(2, a).backward()
+
     def test_triple(self):
         def func(x):
             return 3. * x
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index 20a3373..f0d9115 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -67,6 +67,7 @@ bool isSimpleMap(Node* node) {
       "aten::neg(Tensor self) -> Tensor",
       "aten::pow(Tensor self, Tensor exponent) -> Tensor",
       "aten::pow(Tensor self, Scalar exponent) -> Tensor",
+      "aten::pow(Scalar self, Tensor exponent) -> Tensor",
       "aten::rand_like(Tensor self) -> Tensor",
       "aten::reciprocal(Tensor self) -> Tensor",
       "aten::relu(Tensor self) -> Tensor",
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index 3433a9d..87843f6 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -385,7 +385,7 @@ const std::vector<std::string> functions = {
                     tensor1,
                     tensor2,
                     *,
-                    value: float = 1.0):
+                    value: number = 1.0):
             def backward(grad_output):
                 grad = grad_output * value
                 grad_tensor1 = (grad * tensor2)._grad_sum_to_size(tensor1.size())
@@ -448,10 +448,10 @@
 
         def lerp_0(self,
                    end,
-                   weight: float):
+                   weight: number):
             def backward(grad_output):
-                grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self.size())
-                grad_end = (grad_output * weight)._grad_sum_to_size(end.size())
+                grad_self = (grad_output * (1 - float(weight)))._grad_sum_to_size(self.size())
+                grad_end = (grad_output * float(weight))._grad_sum_to_size(end.size())
                 return grad_self, grad_end, None
 
             return torch.lerp(self, end, weight), backward
@@ -578,9 +578,12 @@
             return torch.ones_like(self), backward
 
         def pow_0(self,
-                  exponent: float):
+                  exponent: number):
             def backward(grad_output):
-                grad_self = torch.where(torch.tensor(exponent == 0.0), torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))
+                if float(exponent) == 0.0:
+                    grad_self = torch.zeros_like(self)
+                else:
+                    grad_self = grad_output * exponent * torch.pow(self, float(exponent) - 1)
                 return grad_self, None
 
             return torch.pow(self, exponent), backward
@@ -594,16 +597,16 @@
             return torch.pow(self, exponent), backward
 
-        def pow_2(self: float,
+        def pow_2(self: number,
                   exponent):
             def backward(grad_output):
-                grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(torch.tensor(self))
+                grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(float(self))
                 return None, grad_exponent
 
             return torch.pow(self, exponent), backward
 
         def rsub_0(self,
                    other,
-                   alpha: float = 1.0):
+                   alpha: number = 1.0):
             self_size = self.size()
             other_size = other.size()
             def backward(grad_output):
@@ -614,8 +617,8 @@
             return torch.rsub(self, other, alpha), backward
 
         def rsub_1(self,
-                   other: float,
-                   alpha: float = 1.0):
+                   other: number,
+                   alpha: number = 1.0):
             self_size = self.size()
             def backward(grad_output):
                 grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
@@ -1373,14 +1376,6 @@ c10::optional<GradientPair> gradientInfoForSchema(
     return cache_it->second;
   } else {
     auto schema_str = canonicalSchemaString(schema);
-    // Specialize Scalar to float for the arg type of the node schema
-    // this is used to:
-    // 1. define scalar type as float in TorchScript autodiff formula
-    // 2. to make sure the input of any graph node does not contain scalar type
-    //    in its argument, all scalar arg should already be passed with float
-    //    value since scalar/int aren't differentiable either way.
-    //
-    c10::ReplaceAll(schema_str, "Scalar", "float");
     // For debugging AD change:
     // std::cout << "Looking for " << schema_str << std::endl;
 
-- 
2.7.4
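
Note (illustration only, not part of the patch): the backward formulas that the
symbolic_script.cpp changes encode for the scalar cases are
d(self ** exponent)/d(self) = exponent * self ** (exponent - 1) for a scalar
exponent, and d(self ** exponent)/d(exponent) = self ** exponent * log(self)
for a scalar base, which is why the scripted formula needs a log of the base.
A minimal eager-mode sketch that checks both formulas numerically (the variable
names are made up for this example and do not appear in the patch):

    import math
    import torch

    # scalar exponent: d/dbase base**k = k * base**(k - 1)
    base = torch.rand(4, dtype=torch.double, requires_grad=True)
    k = 3.0
    (base ** k).sum().backward()
    assert torch.allclose(base.grad, k * base.detach() ** (k - 1))

    # scalar base: d/dexp a**exp = a**exp * ln(a)
    a = 2.0
    exp = torch.rand(4, dtype=torch.double, requires_grad=True)
    (a ** exp).sum().backward()
    assert torch.allclose(exp.grad, (a ** exp.detach()) * math.log(a))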