From 7749804099b8a64aea4bf91e298a20976f9b10ad Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Wed, 28 Nov 2018 10:50:26 -0800
Subject: [PATCH] Support Embedding + EmbeddingBag in Script (#14415)

Summary:
Add support for Embedding and EmbeddingBag in script. Both functions require
`with torch.no_grad():`, which we have no plans to support in the near future.
To work around this, I added an embedding_renorm function without derivatives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14415

Reviewed By: wanchaol

Differential Revision: D13219647

Pulled By: eellison

fbshipit-source-id: c90706aa6fbd48686eb10f3efdb65844be7b8717
---
 aten/src/ATen/core/aten_interned_strings.h |  1 +
 aten/src/ATen/native/Embedding.cpp         |  9 +++++++
 aten/src/ATen/native/cuda/Embedding.cu     | 12 +++++++--
 aten/src/ATen/native/native_functions.yaml |  5 ++++
 test/test_jit.py                           | 38 +++++++++++++++++++--------
 tools/autograd/derivatives.yaml            |  3 +++
 torch/csrc/jit/script/init.cpp             |  4 +++
 torch/jit/__init__.py                      |  2 +-
 torch/nn/functional.py                     | 41 ++++++++++++++++++++----------
 torch/nn/modules/sparse.py                 | 14 ++++++++--
 10 files changed, 100 insertions(+), 29 deletions(-)

diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 486d5cf..f588bb4 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -309,6 +309,7 @@ _(aten, embedding_backward) \
 _(aten, embedding_bag) \
 _(aten, embedding_dense_backward) \
 _(aten, embedding_renorm) \
+_(aten, no_grad_embedding_renorm) \
 _(aten, embedding_sparse_backward) \
 _(aten, empty) \
 _(aten, empty_like) \
diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp
index 72518fb..b2ed5fa 100644
--- a/aten/src/ATen/native/Embedding.cpp
+++ b/aten/src/ATen/native/Embedding.cpp
@@ -178,4 +178,13 @@ Tensor & embedding_renorm_cpu_(
   return self;
 }
 
+// This is a workaround for not being able to use `with torch.no_grad():`
+// in script. No derivatives are set when calling no_grad_embedding_renorm_cpu_.
+// TODO: remove when script supports set_grad_enabled
+Tensor & no_grad_embedding_renorm_cpu_(
+    Tensor & self, const Tensor & indices, double max_norm, double norm_type) {
+  return embedding_renorm_cpu_(self, indices, max_norm, norm_type);
+}
+
+
 }}  // namespace at::native
diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu
index 6073c5a..d493faf 100644
--- a/aten/src/ATen/native/cuda/Embedding.cu
+++ b/aten/src/ATen/native/cuda/Embedding.cu
@@ -52,7 +52,7 @@ __global__ void embedding_backward_feature_kernel
       if(batch_start + tid < n)
         indices_batch[tid] = (int)indices[batch_start + tid];
 
-      int batch_end = batch_start + blockDim.x*blockDim.y < n ? 
+      int batch_end = batch_start + blockDim.x*blockDim.y < n ?
                       batch_start + blockDim.x*blockDim.y : n;
 
       // Loop over the batch of <= 1024 loaded indices in chunks of blockDim.y = 32
@@ -62,7 +62,7 @@ __global__ void embedding_backward_feature_kernel
       // leaders are done with their accumulates before other warps start loading again.
       __syncthreads();
 
-      int n_this_chunk = (batch_end - chunk_start) < blockDim.y ? 
+      int n_this_chunk = (batch_end - chunk_start) < blockDim.y ?
                          (batch_end - chunk_start) : blockDim.y;
 
       int src_row = chunk_start + threadIdx.y;
@@ -385,4 +385,12 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
   return self;
 }
 
+// This is a workaround for not being able to use `with torch.no_grad():` in script.
+// No derivatives are set when calling no_grad_embedding_renorm_cuda_.
+// TODO: remove when script supports set_grad_enabled
+Tensor & no_grad_embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
+    double max_norm, double norm_type) {
+  return embedding_renorm_cuda_(self, indices, max_norm, norm_type);
+}
+
 }}  // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e6de5a5..f2bf9ca 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -628,6 +628,11 @@
     CPU: embedding_renorm_cpu_
     CUDA: embedding_renorm_cuda_
 
+- func: no_grad_embedding_renorm_(Tensor self, IndexTensor indices, double max_norm, double norm_type) -> Tensor
+  dispatch:
+    CPU: no_grad_embedding_renorm_cpu_
+    CUDA: no_grad_embedding_renorm_cuda_
+
 - func: embedding_sparse_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor
 
 # NOTE [ embedding_bag Native Functions ]
diff --git a/test/test_jit.py b/test/test_jit.py
index 540a13e..6c0d49b 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -9376,8 +9376,6 @@ EXCLUDE_SCRIPT = {
     # argument has custom behavior
     'test_nn_fractional_max_pool2d',
     'test_nn_max_unpool3d',
-    'test_nn_embedding',
-    'test_nn_embedding_bag',
     'test_nn_batch_norm',
 
     # aten op has additional cudnn argument
@@ -9902,6 +9900,7 @@ S = 5
 #   constructor arguments,
 #   args (tuple represents shape of a tensor arg),
 #   use_as_constant (should the submodule be listed in __constants__?)
+#   test variant name (will be used as the test name suffix),
 # )
 nn_module_tests = [
     ('AlphaDropout', (), ((S,),)),
@@ -9924,9 +9923,18 @@ nn_module_tests = [
     ('Tanh', (), ((S,),)),
     ('Tanhshrink', (), ((S,),)),
     ('Threshold', (2., 2.), ((S,),)),
+    ('Embedding', (4, 3), (torch.empty(2, 3, dtype=torch.long).random_(4)),),
+    ('EmbeddingBag', (4, 3), (torch.empty(2, 3, dtype=torch.long).random_(4)),),
+    ('EmbeddingBag', (4, 3, None, 2., False, 'sum'), torch.empty(2, 3, dtype=torch.long).random_(4), False, 'sum'),
+    ('EmbeddingBag', (4, 3, None, 2., False, 'max'), torch.empty(2, 3, dtype=torch.long).random_(4), False, 'max'),
     ('Sequential', (torch.nn.Sigmoid(), torch.nn.Threshold(1., 2.)), ((S,),), True),
 ]
 
+# module cannot be exported / imported currently
+EXCLUDE_MODULE_EXPORT_IMPORT = {
+    'EmbeddingBag',
+}
+
 # NB: JIT script tests for all nn functional interfaces, script mode does
 # not support in_place operations yet, so no inplace operation tests added.
 # removed all the deprecated functions
@@ -10200,7 +10208,7 @@ def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=()
 
 
 def add_nn_module_test(module_name, constructor_args, call_args,
-                       use_as_constant=False, skipTestIf=()):
+                       use_as_constant=False, variant=None, skipTestIf=()):
     def do_test(self):
         nn_module = getattr(torch.nn, module_name)
 
@@ -10225,14 +10233,20 @@ def add_nn_module_test(module_name, constructor_args, call_args,
             def __init__(self):
                 super(TheModule, self).__init__()
                 self.submodule = nn_module(*constructor_args)
-
-        module = TheModule()
-        module.define(script)
-
-        # Check there are no Python ops by exporting
-        self.assertExportImportModule(module, tensors)
-        create_script_module.last_graph = module.graph
-        return module(*args)
+        # module cannot be imported / exported
+        if module_name in EXCLUDE_MODULE_EXPORT_IMPORT:
+            with self.disableModuleHook():
+                module = TheModule()
+                module.define(script)
+                create_script_module.last_graph = module.graph
+                mod = module(*args)
+        else:
+            module = TheModule()
+            module.define(script)
+            self.assertExportImportModule(module, tensors)
+            create_script_module.last_graph = module.graph
+            mod = module(*args)
+        return mod
 
     # Construct a normal nn module to stay consistent with create_script_module
     # and make use of a single global rng_state in module initialization
@@ -10247,6 +10261,8 @@ def add_nn_module_test(module_name, constructor_args, call_args,
     check_against_reference(self, create_script_module, create_nn_module, f_args_variable)
 
     test_name = 'test_nn_{}'.format(module_name)
+    if variant is not None:
+        test_name = test_name + "_" + variant
     post_add_test(test_name, skipTestIf, do_test)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index ad83012..1ea59ec 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -898,6 +898,9 @@
 - name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
   self: not_implemented("embedding_renorm")
 
+- name: no_grad_embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
+  output_differentiability: [False, False, False, False]
+
 - name: kl_div(Tensor self, Tensor target, int64_t reduction)
   self: kl_div_backward(grad, self, target, reduction)
   target: kl_div_target_backward(grad, self, target, reduction)
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 7ae3226..b25edab 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -337,6 +337,10 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     return toSimple(g.insertConstant(py::cast<int64_t>(obj), loc));
   } else if (py::isinstance<py::float_>(obj)) {
     return toSimple(g.insertConstant(py::cast<double>(obj), loc));
+  } else if (py::isinstance<py::str>(obj)) {
+    return toSimple(g.insertConstant(py::cast<std::string>(obj), loc));
+  } else if (obj.is(py::none())) {
+    return toSimple(g.insertConstant(IValue(), loc));
   } else if (THPDevice_Check(obj.ptr())) {
     auto device = reinterpret_cast<THPDevice*>(obj.ptr());
     std::vector<int64_t> v = {static_cast<int64_t>(device->device.type()),
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 9df2933..b7fccec 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -862,7 +862,7 @@ class OrderedBufferDict(OrderedDictWrapper):
 # in addition, tuples and lists of these base types are also considered constants
 # If you edit this list, then you also need to edit the handlers in
 # ConstantValue in jit/script/init.cpp
-_constant_types = (bool, float, int, types.FunctionType, torch.device, torch.layout, torch.dtype)
+_constant_types = (bool, float, int, str, type(None), types.FunctionType, torch.device, torch.layout, torch.dtype)
 
 
 def _get_valid_constant(attr, v):
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index a45751f..b8a5c8b 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1328,8 +1328,10 @@ def bilinear(input1, input2, weight, bias=None):
     return torch.bilinear(input1, input2, weight, bias)
 
 
-def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2,
+@torch._jit_internal.weak_script
+def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
               scale_grad_by_freq=False, sparse=False):
+    # type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor
     r"""A simple lookup table that looks up embeddings in a fixed dictionary and size.
 
     This module is often used to retrieve word embeddings using indices.
@@ -1388,25 +1390,32 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2,
             [ 0.6262,  0.2438,  0.7471]]])
     """
     if padding_idx is not None:
+        padding_idx = torch.jit._unwrap_optional(padding_idx)
         if padding_idx > 0:
             assert padding_idx < weight.size(0), 'Padding_idx must be within num_embeddings'
         elif padding_idx < 0:
             assert padding_idx >= -weight.size(0), 'Padding_idx must be within num_embeddings'
             padding_idx = weight.size(0) + padding_idx
-    elif padding_idx is None:
+    else:
         padding_idx = -1
     if max_norm is not None:
+        max_norm = torch.jit._unwrap_optional(max_norm)
         # `embedding_renorm_` will call .contiguous() on input anyways, so we
         # call it here and take advantage of the improved locality in the
         # `embedding` call below too.
        input = input.contiguous()
-        with torch.no_grad():
-            torch.embedding_renorm_(weight, input, max_norm, norm_type)
+        # XXX: equivalent to
+        # with torch.no_grad():
+        #     torch.embedding_renorm_
+        # remove once script supports set_grad_enabled
+        torch.no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
 
 
+@torch._jit_internal.weak_script
 def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
                   scale_grad_by_freq=False, mode='mean', sparse=False):
+    # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool) -> Tensor
     r"""Computes sums, means or maxes of 'bags' of embeddings, without instantiating the
     intermediate embeddings.
@@ -1491,26 +1500,27 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
     elif input.dim() == 1:
         if offsets is None:
             raise ValueError("offsets has to be a 1D Tensor but got None")
+        offsets = torch.jit._unwrap_optional(offsets)
         if offsets.dim() != 1:
             raise ValueError("offsets has to be a 1D Tensor")
-        if offsets[0].item() != 0:
+        if int(offsets[0]) != 0:
             raise ValueError("offsets[0] has to be 0, i.e., the first sequence "
                              "in the mini-batch has to start from position 0. "
                              "However, got {}".format(offsets[0].item()))
-        if offsets[-1].item() > input.size(0):
+        if int(offsets[-1]) > input.size(0):
             raise ValueError("offsets[-1] can not be greater than input's length"
                              " ({}), but got offsets[-1] of {}"
                              .format(input.size(0), offsets[-1].item()))
     else:
         raise ValueError("input has to be 1D or 2D Tensor,"
                          " but got Tensor of dimension {}".format(input.dim()))
-
+    offsets = torch.jit._unwrap_optional(offsets)  # TODO: remove when exception control flow logic is supported
     if mode == 'sum':
-        mode = 0
+        mode_enum = 0
     elif mode == 'mean':
-        mode = 1
+        mode_enum = 1
     elif mode == 'max':
-        mode = 2
+        mode_enum = 2
 
         if scale_grad_by_freq:
             raise ValueError("max mode does not support scaling the gradient by the frequency")
@@ -1519,18 +1529,23 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
             raise ValueError("max mode does not support sparse weights")
 
     else:
+        mode_enum = -1  # TODO: remove when exception control flow logic is supported
         raise ValueError("mode has to be one of sum, mean or max")
 
     if max_norm is not None:
-        with torch.no_grad():
-            torch.embedding_renorm_(weight, input, max_norm, norm_type)
+        max_norm = torch.jit._unwrap_optional(max_norm)
+        # XXX: equivalent to
+        # with torch.no_grad():
+        #     torch.embedding_renorm_
+        # remove once script supports set_grad_enabled
+        torch.no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
 
     ret, _, _, _ = torch.embedding_bag(
         weight,
         input,
         offsets,
         scale_grad_by_freq,
-        mode,
+        mode_enum,
         sparse)
     return ret
diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py
index 2ee4103..a88d94b 100644
--- a/torch/nn/modules/sparse.py
+++ b/torch/nn/modules/sparse.py
@@ -4,8 +4,10 @@ from torch.nn.parameter import Parameter
 from .module import Module
 from .. import functional as F
 from .. import init
+from torch._jit_internal import weak_module, weak_script, weak_script_method
 
 
+@weak_module
 class Embedding(Module):
     r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
 
@@ -75,9 +77,11 @@ class Embedding(Module):
             [ 0.0000,  0.0000,  0.0000],
             [-0.1655,  0.9897,  0.0635]]])
     """
+    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
+                     'norm_type', 'scale_grad_by_freq', 'sparse', '_weight']
 
     def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
-                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
+                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
                  sparse=False, _weight=None):
         super(Embedding, self).__init__()
         self.num_embeddings = num_embeddings
@@ -107,6 +111,7 @@ class Embedding(Module):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)
 
+    @weak_script_method
     def forward(self, input):
         return F.embedding(
             input, self.weight, self.padding_idx, self.max_norm,
@@ -161,6 +166,7 @@ class Embedding(Module):
         return embedding
 
 
+@weak_module
 class EmbeddingBag(Module):
     r"""Computes sums or means of 'bags' of embeddings, without instantiating the
     intermediate embeddings.
@@ -223,9 +229,11 @@ class EmbeddingBag(Module):
         tensor([[-0.8861, -5.4350, -0.0523],
                 [ 1.1306, -2.5798, -1.0044]])
     """
+    __constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type',
+                     'scale_grad_by_freq', 'mode', 'sparse']
 
     def __init__(self, num_embeddings, embedding_dim,
-                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
+                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
                  mode='mean', sparse=False):
         super(EmbeddingBag, self).__init__()
         self.num_embeddings = num_embeddings
@@ -242,7 +250,9 @@ class EmbeddingBag(Module):
     def reset_parameters(self):
         init.normal_(self.weight)
 
+    @weak_script_method
     def forward(self, input, offsets=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
         return F.embedding_bag(input, self.weight, offsets,
                                self.max_norm, self.norm_type,
                                self.scale_grad_by_freq, self.mode, self.sparse)
-- 
2.7.4
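
For context, a rough usage sketch (not part of the patch): once nn.Embedding is a weak module and the renorm step routes through no_grad_embedding_renorm_ instead of `with torch.no_grad():`, the module can be compiled as a submodule of a ScriptModule. The TextEncoder class, its sizes, and the max_norm value below are illustrative assumptions written against the torch.jit.ScriptModule / script_method API of this release; they are not code from this PR.

    import torch
    import torch.nn as nn


    class TextEncoder(torch.jit.ScriptModule):
        def __init__(self):
            super(TextEncoder, self).__init__()
            # nn.Embedding is weak-scripted, so it can live inside a ScriptModule;
            # max_norm exercises the new no-grad renorm path under the hood.
            self.embed = nn.Embedding(100, 16, max_norm=1.0)

        @torch.jit.script_method
        def forward(self, tokens):
            # type: (Tensor) -> Tensor
            # Look up a (batch, seq) tensor of token ids -> (batch, seq, 16).
            return self.embed(tokens)


    encoder = TextEncoder()
    out = encoder(torch.randint(0, 100, (2, 5), dtype=torch.long))
    print(out.shape)  # torch.Size([2, 5, 16])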