From 416474a720504b2e4a139356ad34a226969e1c2f Mon Sep 17 00:00:00 2001
From: Christian Puhrsch
Date: Fri, 1 Mar 2019 19:12:08 -0800
Subject: [PATCH] Remove more usages of BoolTensor and IndexTensor from
 native_functions.yaml

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16468

Differential Revision: D14095405

Pulled By: cpuhrsch

fbshipit-source-id: ea4d6bb7a4e81c05fe9861190ddbf52201612bbf
---
 aten/src/ATen/native/native_functions.yaml | 28 +++++++++++++++++-----------
 tools/autograd/derivatives.yaml            |  2 ++
 2 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index bb6f8cc..1b3540b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -784,7 +784,8 @@
 - func: embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
   matches_jit_signature: True

-- func: embedding_backward(Tensor grad, IndexTensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor
+- func: embedding_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor
+  matches_jit_signature: True

 - func: embedding_dense_backward(Tensor grad, IndexTensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
   dispatch:
@@ -797,7 +798,8 @@
     CPU: embedding_renorm_cpu_
     CUDA: embedding_renorm_cuda_

-- func: embedding_sparse_backward(Tensor grad, IndexTensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
+- func: embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
+  matches_jit_signature: True

 # NOTE [ embedding_bag Native Functions ]
 # The `_embedding_bag.*` variants assume that input tensors except for `weight`,
@@ -808,16 +810,20 @@
 #       applying indices = indices.contiguous().
 # The backward functions apply a check that these input tensors are contiguous.

-- func: embedding_bag(Tensor weight, IndexTensor indices, IndexTensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor, Tensor, Tensor, Tensor)
+- func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor, Tensor, Tensor, Tensor)
+  matches_jit_signature: True

-- func: _embedding_bag(Tensor weight, IndexTensor indices, IndexTensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor, Tensor, Tensor, Tensor)
+- func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor, Tensor, Tensor, Tensor)
+  matches_jit_signature: True
   dispatch:
     CPU: _embedding_bag_cpu
     CUDA: _embedding_bag_cuda

-- func: _embedding_bag_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse) -> Tensor
+- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse) -> Tensor
+  matches_jit_signature: True

-- func: _embedding_bag_sparse_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int num_weights, bool scale_grad_by_freq, int mode) -> Tensor
+- func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, int num_weights, bool scale_grad_by_freq, int mode) -> Tensor
+  matches_jit_signature: True

 - func: _embedding_bag_dense_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode) -> Tensor
   dispatch:
@@ -2292,7 +2298,8 @@
 # we define both of these because 'where' does the broadcast and '_s_where' doesn't;
 # this allows us to implicitly calculate the broadcast derivative, while only dealing with the
 # _s_where derivative.
-- func: where(BoolTensor condition, Tensor self, Tensor other) -> Tensor
+- func: where(Tensor condition, Tensor self, Tensor other) -> Tensor
+  matches_jit_signature: True
   variants: function, method

 - func: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
@@ -2660,12 +2667,11 @@
 # the default would never make sense.
 - func: sparse_coo_tensor(int[] size, *, TensorOptions options) -> Tensor

-- func: sparse_coo_tensor(IndexTensor indices, Tensor values, *, TensorOptions options=[]) -> Tensor
-
-- func: sparse_coo_tensor(IndexTensor indices, Tensor values, int[] size, *, TensorOptions options=[]) -> Tensor
+- func: sparse_coo_tensor(Tensor indices, Tensor values, *, TensorOptions options=[]) -> Tensor

-- func: _sparse_coo_tensor_unsafe(IndexTensor indices, Tensor values, int[] size, *, TensorOptions options=[]) -> Tensor
+- func: sparse_coo_tensor(Tensor indices, Tensor values, int[] size, *, TensorOptions options=[]) -> Tensor

+- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, int[] size, *, TensorOptions options=[]) -> Tensor
 - func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, TensorOptions options) -> Tensor
   dispatch:
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index ee05b63..dc28a6d 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -935,6 +935,8 @@
   weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)

 - name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse)
+  indices: not_differentiable
+  offsets: not_differentiable
   weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse)

 - name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
-- 
2.7.4
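
Not part of the patch: a minimal Python-level sketch of the user-visible behavior behind the schemas touched above, assuming a PyTorch build that includes this change. `where` takes an ordinary boolean Tensor as `condition` (no separate BoolTensor schema type), and `embedding_bag` takes ordinary int64 `indices`/`offsets` tensors (formerly IndexTensor in the schema) that never receive gradients, matching the `not_differentiable` entries added to derivatives.yaml.

import torch

# torch.where: condition is a plain bool Tensor; gradients flow only
# through the self/other operands, never through condition.
cond = torch.tensor([True, False, True])
a = torch.ones(3, requires_grad=True)
b = torch.zeros(3)
torch.where(cond, a, b).sum().backward()
print(a.grad)                  # tensor([1., 0., 1.])

# embedding_bag: indices and offsets are plain int64 Tensors and are not
# differentiable; only the weight matrix accumulates a gradient.
bag = torch.nn.EmbeddingBag(10, 4, mode='sum')
indices = torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int64)
offsets = torch.tensor([0, 3], dtype=torch.int64)   # two bags: [0:3) and [3:6)
bag(indices, offsets).sum().backward()
print(bag.weight.grad.shape)   # torch.Size([10, 4])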