From 68e90a398e11e64425de285bbc3f8bb87c1d21fa Mon Sep 17 00:00:00 2001
From: Ailing Zhang
Date: Mon, 25 Feb 2019 16:22:16 -0800
Subject: [PATCH] Temporarily disable select/topk/kthvalue AD (#17470)

Summary:
Temporarily disable them for perf consideration. Will figure out a way to
do `torch.zeros(sizes, grad.options())` in torchscript before enabling
these.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17470

Differential Revision: D14210313

Pulled By: ailzhang

fbshipit-source-id: efaf44df1192ae42f4fe75998ff0073234bb4204
---
 torch/csrc/jit/symbolic_script.cpp | 59 +++++++++++++++++++-------------------
 1 file changed, 30 insertions(+), 29 deletions(-)

diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index cff83d6..b309ccb 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -166,30 +166,30 @@ const std::vector<std::string> functions = {
             # FIXME: torchscript: torch.zeros(sizes, grad.options())
             return torch.zeros(sizes).to(grad).scatter_(dim, indices, grad)
 
-        def topk(self,
-                 k: int,
-                 dim: int = -1,
-                 largest: bool = True,
-                 sorted: bool = True):
-            result0, result1 = torch.topk(self, k, dim, largest, sorted)
-            self_size = self.size()
-            def backward(grad_output):
-                grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, True)
-                return grad_self, None, None, None, None
+        # def topk(self,
+        #          k: int,
+        #          dim: int = -1,
+        #          largest: bool = True,
+        #          sorted: bool = True):
+        #     result0, result1 = torch.topk(self, k, dim, largest, sorted)
+        #     self_size = self.size()
+        #     def backward(grad_output):
+        #         grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, True)
+        #         return grad_self, None, None, None, None
 
-            return result0, result1, backward
+        #     return result0, result1, backward
 
-        def kthvalue(self,
-                     k: int,
-                     dim: int,
-                     keepdim: bool):
-            result0, result1 = torch.kthvalue(self, k, dim, keepdim)
-            self_size = self.size()
-            def backward(grad_output):
-                grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, keepdim)
-                return grad_self, None, None, None
+        # def kthvalue(self,
+        #              k: int,
+        #              dim: int,
+        #              keepdim: bool):
+        #     result0, result1 = torch.kthvalue(self, k, dim, keepdim)
+        #     self_size = self.size()
+        #     def backward(grad_output):
+        #         grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, keepdim)
+        #         return grad_self, None, None, None
 
-            return result0, result1, backward
+        #     return result0, result1, backward
 
         def AD_mm_backward_self(grad, mat2):
             return grad.mm(mat2.t())
@@ -232,15 +232,16 @@ const std::vector<std::string> functions = {
             grad_input.select(dim, index).copy_(grad)
             return grad_input
 
-        def select(self,
-                   dim: int,
-                   index: int):
-            self_size = self.size()
-            def backward(grad_output):
-                grad_self = AD_select_backward(grad_output, self_size, dim, index)
-                return grad_self, None, None
+        # TODO: fix torch.zeros(sizes, grad.options()) before enabling select, topk, kthvalue
+        # def select(self,
+        #            dim: int,
+        #            index: int):
+        #     self_size = self.size()
+        #     def backward(grad_output):
+        #         grad_self = AD_select_backward(grad_output, self_size, dim, index)
+        #         return grad_self, None, None
 
-            return torch.select(self, dim, index), backward
+        #     return torch.select(self, dim, index), backward
 
         def AD_slice_backward(grad,
                               input_sizes: List[int],
--
2.7.4
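
Note: the FIXME and TODO comments in the patch point at the same TorchScript gap.
`torch.zeros(sizes, grad.options())` was not expressible in TorchScript at the
time, so these AD formulas had to fall back to `torch.zeros(sizes).to(grad)`,
which allocates with the default dtype/device first and then converts — the
perf concern the commit message cites. As a rough illustration, below is a
minimal plain-Python sketch of that workaround applied to the disabled `select`
backward. The name `ad_select_backward` and its `torch.zeros(...)` allocation
line are assumptions for illustration (only the tail of `AD_select_backward`
is visible in the hunk; the allocation is inferred from the `scatter_` pattern
shown above); the real formula lives inside a TorchScript string in
symbolic_script.cpp.

    import torch
    from typing import List

    def ad_select_backward(grad: torch.Tensor,
                           input_sizes: List[int],
                           dim: int,
                           index: int) -> torch.Tensor:
        # Scriptable workaround: allocate with default options, then use
        # .to(grad) to match grad's dtype and device. The extra allocation
        # and conversion is the overhead the commit avoids by disabling the
        # formula; the desired one-step form would be
        # torch.zeros(input_sizes, grad.options()).
        grad_input = torch.zeros(input_sizes).to(grad)
        # Write the incoming gradient into the selected slice; every other
        # slice keeps its zero gradient.
        grad_input.select(dim, index).copy_(grad)
        return grad_input

    # Usage: gradient of x.select(0, 1) for a 3x4 input — only row 1 of the
    # result is nonzero.
    x = torch.randn(3, 4)
    grad_output = torch.ones(4)
    print(ad_select_backward(grad_output, list(x.size()), 0, 1))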