From 6c9b312fd48c5a3e00719be881aa067b436d05e5 Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Mon, 25 Mar 2019 11:02:17 -0700
Subject: [PATCH] Add addcmul, lerp to fuser, enable scalar->float
 specialization in symbolic script (#18081)

Summary:
This PR does two things:
1. Enable Scalar->float specialization in symbolic script: an AD formula whose
   schema contains a Scalar argument should declare that argument as `float`.
2. Add addcmul and lerp to AD and the fuser.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18081

Differential Revision: D14490493

Pulled By: wanchaol

fbshipit-source-id: b3b86d960d5f051b30733bc908b19786111cdaa4
---
 test/test_jit.py                      | 39 ++++++++++++++++++++++++++++++++++
 torch/csrc/jit/autodiff.cpp           | 14 +++++++++---
 torch/csrc/jit/fuser/codegen.cpp      |  2 ++
 torch/csrc/jit/graph_executor.cpp     |  3 +--
 torch/csrc/jit/passes/graph_fuser.cpp |  3 +++
 torch/csrc/jit/symbolic_script.cpp    | 40 +++++++++++++++++++++++++++++++++++
 torch/csrc/jit/symbolic_script.h      |  1 +
 7 files changed, 97 insertions(+), 5 deletions(-)
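Illustrative note (editorial, not part of the original commit): the sketch
below only demonstrates the two points above; the function name `combined`,
the tensor shapes, and the scalar values are assumptions, not code from this
patch. With the Scalar->float specialization, a schema such as
"aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor" is looked up as
the float variant, which matches the `lerp_0(self, end, weight: float)` formula
added in symbolic_script.cpp below.

    # Hypothetical usage sketch: a scripted function mixing addcmul and lerp
    # with Scalar arguments. On CUDA the graph fuser can now fuse these
    # pointwise ops, and the symbolic-script formulas supply the gradients.
    import torch

    def combined(t, t1, t2):
        # value=0.1 and the 0.5 weight are Scalars in the ATen schema; they
        # are specialized to float when the autodiff formula is looked up.
        return torch.lerp(t.addcmul(t1, t2, value=0.1), t2, 0.5)

    scripted = torch.jit.script(combined)

    if torch.cuda.is_available():
        t = torch.randn(1, 4, device='cuda', requires_grad=True)
        t1 = torch.randn(4, 1, device='cuda', requires_grad=True)
        t2 = torch.randn(1, 4, device='cuda', requires_grad=True)
        scripted(t, t1, t2).sum().backward()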
diff --git a/test/test_jit.py b/test/test_jit.py
index 8606a6e..36d51b9 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -12196,6 +12196,45 @@ class TestFuser(JitTestCase):
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @skipIfRocm
+    def test_addcmul_cuda(self):
+        t = torch.randn(1, 4, dtype=torch.float, device='cuda')
+        t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
+        t2 = torch.randn(1, 4, dtype=torch.float, device='cuda')
+
+        def foo(t, t1, t2):
+            return t.addcmul(t + 1, t2, value=0.1)
+
+        ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
+        graph = ge.graph_for(t, t1, t2)
+        self.assertAllFused(graph)
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_lerp_cuda(self):
+        start = torch.randn(4, 1, dtype=torch.float, device='cuda')
+        end = torch.randn(1, 4, dtype=torch.float, device='cuda')
+        weight = torch.tensor(0.5, dtype=torch.float, device='cuda')
+
+        # scalar weight overload
+        def foo_weight_scalar(start, end):
+            return torch.lerp(start + 1, end, 0.5)
+
+        # tensor weight overload
+        def foo_weight_tensor(start, end):
+            return torch.lerp(start + 1, end, weight)
+
+        ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end))
+        graph = ge_weight_scalar.graph_for(start, end)
+        self.assertAllFused(graph)
+
+        ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end))
+        graph = ge_weight_tensor.graph_for(start, end)
+        self.assertAllFused(graph)
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
     def test_concat_cuda(self):
         hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
         cx = torch.randn(3, 20, dtype=torch.float, device='cuda')
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index cfe244f..61215a8 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -129,7 +129,8 @@ bool isDifferentiable(Node* n) {
     return true;

   if (n->matches(
-          "aten::dropout(Tensor input, float p, bool train) -> Tensor", attr::train)) {
+          "aten::dropout(Tensor input, float p, bool train) -> Tensor",
+          attr::train)) {
     return n->get<bool>(attr::train).value();
   }

@@ -164,6 +165,14 @@ bool isDifferentiable(Node* n) {
         static_cast<bool (*)(Node*)>(isDifferentiable));
   }

+  // formulas are only defined with floating point scalars,
+  // so we fall back to autograd for other cases.
+  for (const Value* input : n->inputs()) {
+    if (input->type() == NumberType::get()) {
+      return false;
+    }
+  }
+
   return false;
 }

@@ -204,7 +213,6 @@ static c10::optional<std::vector<Value*>> build_script_grad(
     Node* node,
     const ArrayRef<Value*>& grads) {
   auto graph = node->owningGraph();
-
   auto compiled_graphs = gradientInfoForSchema(node->schema());
   if (!compiled_graphs) {
     return c10::nullopt;
@@ -228,7 +236,7 @@ static c10::optional<std::vector<Value*>> build_script_grad(
   auto bw_graph = compiled_graphs->backward;
   auto grad_vec = grads.vec();
   if (needTrimGrad(node)) {
-    grad_vec.erase(grad_vec.begin()+1, grad_vec.end());
+    grad_vec.erase(grad_vec.begin() + 1, grad_vec.end());
   }
   auto it = grad_vec.begin();
   grad_vec.insert(it, new_outputs.back());
diff --git a/torch/csrc/jit/fuser/codegen.cpp b/torch/csrc/jit/fuser/codegen.cpp
index fd5beef..80535b3 100644
--- a/torch/csrc/jit/fuser/codegen.cpp
+++ b/torch/csrc/jit/fuser/codegen.cpp
@@ -208,6 +208,7 @@ static std::string encodeRHS(const Node* n) {
       {aten::__or__, "${0} || ${1}"},
       {aten::__rshift__, "${0} >> ${1}"},
       {aten::__xor__, "${0} ^ ${1}"},
+      {aten::addcmul, "${cast_0} + ${cast_3} * ${cast_1} * ${cast_2}"},
       {aten::div, "${cast_0} / ${cast_1}"},
       {aten::eq, "${0} == ${1}"},
       {aten::fmod, "fmodf(${cast_0}, ${cast_1})"},
@@ -215,6 +216,7 @@ static std::string encodeRHS(const Node* n) {
       {aten::gt, "${0} > ${1}"},
       {aten::le, "(${0} <= ${1})"},
       {aten::lt, "${0} < ${1}"},
+      {aten::lerp, "${cast_0} + ${cast_2} * (${cast_1} - ${cast_0})"},
       {aten::type_as, "(${cast_0})"},
       {aten::mul, "${cast_0} * ${cast_1}"},
       {aten::ne, "${0} != ${1}"},
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 15def35..f7cd332 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -8,8 +8,6 @@
 #include
 #include
 #include
-#include
-#include
 #include
 #include
 #include
@@ -27,6 +25,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index 5923582..ad1a835 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -59,6 +59,8 @@ bool isSimpleMap(Node* node) {
       "aten::log10(Tensor self) -> Tensor",
       "aten::log1p(Tensor self) -> Tensor",
       "aten::log2(Tensor self) -> Tensor",
+      "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
+      "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
      "aten::max(Tensor self, Tensor other) -> Tensor",
      "aten::min(Tensor self, Tensor other) -> Tensor",
      "aten::mul(Tensor self, Tensor other) -> Tensor",
@@ -98,6 +100,7 @@ bool isSimpleMap(Node* node) {
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",

+      "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor",
      "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",

      "aten::type_as(Tensor self, Tensor other) -> Tensor",
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index 98e5a74..cce3552 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -383,6 +383,18 @@ const std::vector<std::string> functions = {
             return torch.matmul(self, other), backward
         )",
         R"(
+        def addcmul(self,
+                    tensor1,
+                    tensor2,
+                    *,
+                    value: float = 1.0):
+            def backward(grad_output):
+                grad = grad_output * value
+                grad_tensor1 = (grad * tensor2)._grad_sum_to_size(tensor1.size())
+                grad_tensor2 = (grad * tensor1)._grad_sum_to_size(tensor2.size())
+                return grad_output._grad_sum_to_size(self.size()), grad_tensor1, grad_tensor2, None
+            return torch.addcmul(self, tensor1, tensor2, value=value), backward
+
         def _dim_arange(like,
                         dim: int):
             def backward(grad_output):
@@ -439,6 +451,24 @@ const std::vector<std::string> functions = {

             return torch.full_like(self, fill_value), backward

+        def lerp_0(self,
+                   end,
+                   weight: float):
+            def backward(grad_output):
+                grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self.size())
+                grad_end = (grad_output * weight)._grad_sum_to_size(end.size())
+                return grad_self, grad_end, None
+            return torch.lerp(self, end, weight), backward
+
+        def lerp_1(self,
+                   end,
+                   weight):
+            def backward(grad_output):
+                grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self.size())
+                grad_end = (grad_output * weight)._grad_sum_to_size(end.size())
+                return grad_self, grad_end, None
+            return torch.lerp(self, end, weight), backward
+
         def mul(self, other):
             def backward(grad_output):
                 # self & other are used in backward. No need to pass in their size
@@ -889,6 +919,7 @@ std::string overloadedSchemaString(const FunctionSchema& schema) {
         schema_name.length(),
         schema_name.substr(0, pos));
   }
+
   return schema_string;
 }

@@ -969,6 +1000,15 @@ c10::optional<GradientPair> gradientInfoForSchema(
     return cache_it->second;
   } else {
     auto schema_str = canonicalSchemaString(schema);
+    // Specialize Scalar to float for the arg type of the node schema.
+    // This is used to:
+    // 1. define the scalar type as float in TorchScript autodiff formulas
+    // 2. make sure the inputs of a graph node do not contain a Scalar type in
+    //    their arguments; all scalar args should already be passed as float
+    //    values, since scalar/int aren't differentiable either way.
+    //
+    c10::ReplaceAll(schema_str, "Scalar", "float");
+
     auto sym_script_it = schema_to_graphs.find(schema_str);

     if (sym_script_it != schema_to_graphs.end()) {
diff --git a/torch/csrc/jit/symbolic_script.h b/torch/csrc/jit/symbolic_script.h
index bc8284c..6c17eb5 100644
--- a/torch/csrc/jit/symbolic_script.h
+++ b/torch/csrc/jit/symbolic_script.h
@@ -3,6 +3,7 @@
 // merged. Ideally this should all go into native_functions.yaml

 #include
+#include
 #include
 #include
 #include
-- 
2.7.4