Revert D13219647: [pytorch][PR] Support Embedding + EmbeddingBag in Script
authorEdward Yang <ezyang@fb.com>
Wed, 28 Nov 2018 21:36:40 +0000 (13:36 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 28 Nov 2018 21:38:58 +0000 (13:38 -0800)
Differential Revision:
D13219647

Original commit changeset: c90706aa6fbd

fbshipit-source-id: d189e717ba0773de43d633876bc3a688830a9303

aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/Embedding.cpp
aten/src/ATen/native/cuda/Embedding.cu
aten/src/ATen/native/native_functions.yaml
test/test_jit.py
tools/autograd/derivatives.yaml
torch/csrc/jit/script/init.cpp
torch/jit/__init__.py
torch/nn/functional.py
torch/nn/modules/sparse.py

index f588bb4..486d5cf 100644 (file)
@@ -309,7 +309,6 @@ _(aten, embedding_backward) \
 _(aten, embedding_bag) \
 _(aten, embedding_dense_backward) \
 _(aten, embedding_renorm) \
-_(aten, no_grad_embedding_renorm) \
 _(aten, embedding_sparse_backward) \
 _(aten, empty) \
 _(aten, empty_like) \
index b2ed5fa..72518fb 100644 (file)
@@ -178,13 +178,4 @@ Tensor & embedding_renorm_cpu_(
   return self;
 }
 
-// This is a workaround to not being able to call with.no_grad():
-// in script. No derivatives are set when calling no_grad_embedding_renorm_cpu_
-// TODO: remove when script supports set_grad_enabled
-Tensor & no_grad_embedding_renorm_cpu_(
-    Tensor & self, const Tensor & indices, double max_norm, double norm_type) {
-  return embedding_renorm_cpu_(self, indices, max_norm, norm_type);
-}
-
-
 }}  // namespace at::native
index d493faf..6073c5a 100644 (file)
@@ -52,7 +52,7 @@ __global__ void embedding_backward_feature_kernel
     if(batch_start + tid < n)
       indices_batch[tid] = (int)indices[batch_start + tid];
 
-    int batch_end = batch_start + blockDim.x*blockDim.y < n ?
+    int batch_end = batch_start + blockDim.x*blockDim.y < n ? 
                     batch_start + blockDim.x*blockDim.y : n;
 
     // Loop over the batch of <= 1024 loaded indices in chunks of blockDim.y = 32
@@ -62,7 +62,7 @@ __global__ void embedding_backward_feature_kernel
       // leaders are done with their accumulates before other warps start loading again.
       __syncthreads();
 
-      int n_this_chunk = (batch_end - chunk_start) < blockDim.y ?
+      int n_this_chunk = (batch_end - chunk_start) < blockDim.y ? 
                          (batch_end - chunk_start) : blockDim.y;
 
       int src_row = chunk_start + threadIdx.y;
@@ -385,12 +385,4 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
   return self;
 }
 
-// This is a workaround to not being able to call with.no_grad():
-// in script. No derivatives are set when calling no_grad_embedding_renorm_cuda_
-// TODO: remove when script supports set_grad_enabled
-Tensor & no_grad_embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
-                                double max_norm, double norm_type) {
-  return embedding_renorm_cuda_(self, indices, max_norm, norm_type);
-}
-
 }}  // namespace at::native
index f2bf9ca..e6de5a5 100644 (file)
     CPU: embedding_renorm_cpu_
     CUDA: embedding_renorm_cuda_
 
-- func: no_grad_embedding_renorm_(Tensor self, IndexTensor indices, double max_norm, double norm_type) -> Tensor
-  dispatch:
-    CPU: no_grad_embedding_renorm_cpu_
-    CUDA: no_grad_embedding_renorm_cuda_
-
 - func: embedding_sparse_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor
 
 # NOTE [ embedding_bag Native Functions ]
index 6c0d49b..540a13e 100644 (file)
@@ -9376,6 +9376,8 @@ EXCLUDE_SCRIPT = {
     # argument has custom behavior
     'test_nn_fractional_max_pool2d',
     'test_nn_max_unpool3d',
+    'test_nn_embedding',
+    'test_nn_embedding_bag',
     'test_nn_batch_norm',
 
     # aten op has additional cudnn argument
@@ -9900,7 +9902,6 @@ S = 5
 #     constructor arguments,
 #     args (tuple represents shape of a tensor arg),
 #     use_as_constant (should the submodule be listed in __constants__?)
-#     test variant name(will be used at test name suffix),
 # )
 nn_module_tests = [
     ('AlphaDropout', (), ((S,),)),
@@ -9923,18 +9924,9 @@ nn_module_tests = [
     ('Tanh', (), ((S,),)),
     ('Tanhshrink', (), ((S,),)),
     ('Threshold', (2., 2.), ((S,),)),
-    ('Embedding', (4, 3), (torch.empty(2, 3, dtype=torch.long).random_(4)),),
-    ('EmbeddingBag', (4, 3), (torch.empty(2, 3, dtype=torch.long).random_(4)),),
-    ('EmbeddingBag', (4, 3, None, 2., False, 'sum'), torch.empty(2, 3, dtype=torch.long).random_(4), False, 'sum'),
-    ('EmbeddingBag', (4, 3, None, 2., False, 'max'), torch.empty(2, 3, dtype=torch.long).random_(4), False, 'max'),
     ('Sequential', (torch.nn.Sigmoid(), torch.nn.Threshold(1., 2.)), ((S,),), True),
 ]
 
-# module cannot be exported /imported currently
-EXCLUDE_MODULE_EXPORT_IMPORT = {
-    'EmbeddingBag',
-}
-
 # NB: JIT script tests for all nn functional interfaces, script mode does
 # not support in_place operations yet, so no inplace operation tests added.
 # removed all the deprecated functions
@@ -10208,7 +10200,7 @@ def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=()
 
 
 def add_nn_module_test(module_name, constructor_args, call_args,
-                       use_as_constant=False, variant=None, skipTestIf=()):
+                       use_as_constant=False, skipTestIf=()):
     def do_test(self):
         nn_module = getattr(torch.nn, module_name)
 
@@ -10233,20 +10225,14 @@ def add_nn_module_test(module_name, constructor_args, call_args,
                 def __init__(self):
                     super(TheModule, self).__init__()
                     self.submodule = nn_module(*constructor_args)
-            # module cannot be imported / exported
-            if module_name in EXCLUDE_MODULE_EXPORT_IMPORT:
-                with self.disableModuleHook():
-                    module = TheModule()
-                    module.define(script)
-                    create_script_module.last_graph = module.graph
-                    mod = module(*args)
-            else:
-                module = TheModule()
-                module.define(script)
-                self.assertExportImportModule(module, tensors)
-                create_script_module.last_graph = module.graph
-                mod = module(*args)
-            return mod
+
+            module = TheModule()
+            module.define(script)
+
+            # Check there are no Python ops by exporting
+            self.assertExportImportModule(module, tensors)
+            create_script_module.last_graph = module.graph
+            return module(*args)
 
         # Construct a normal nn module to stay consistent with create_script_module
         # and make use of a single global rng_state in module initialization
@@ -10261,8 +10247,6 @@ def add_nn_module_test(module_name, constructor_args, call_args,
         check_against_reference(self, create_script_module, create_nn_module, f_args_variable)
 
     test_name = 'test_nn_{}'.format(module_name)
-    if variant is not None:
-        test_name = test_name + "_" + variant
     post_add_test(test_name, skipTestIf, do_test)
 
 
index 1ea59ec..ad83012 100644 (file)
 - name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
   self: not_implemented("embedding_renorm")
 
-- name: no_grad_embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
-  output_differentiability: [False, False, False, False]
-
 - name: kl_div(Tensor self, Tensor target, int64_t reduction)
   self: kl_div_backward(grad, self, target, reduction)
   target: kl_div_target_backward(grad, self, target, reduction)
index b25edab..7ae3226 100644 (file)
@@ -337,10 +337,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
       return toSimple(g.insertConstant(py::cast<int64_t>(obj), loc));
     } else if (py::isinstance<py::float_>(obj)) {
       return toSimple(g.insertConstant(py::cast<float>(obj), loc));
-    } else if (py::isinstance<py::str>(obj)) {
-      return toSimple(g.insertConstant(py::cast<std::string>(obj), loc));
-    } else if (obj.is(py::none())) {
-      return toSimple(g.insertConstant(IValue(), loc));
     } else if (THPDevice_Check(obj.ptr())) {
       auto device = reinterpret_cast<THPDevice*>(obj.ptr());
       std::vector<int64_t> v = {static_cast<int64_t>(device->device.type()),
index b7fccec..9df2933 100644 (file)
@@ -862,7 +862,7 @@ class OrderedBufferDict(OrderedDictWrapper):
 # in addition, tuples and lists of these base types are also considered constants
 # If you edit this list, then you also need to edit the handlers in
 # ConstantValue in jit/script/init.cpp
-_constant_types = (bool, float, int, str, type(None), types.FunctionType, torch.device, torch.layout, torch.dtype)
+_constant_types = (bool, float, int, types.FunctionType, torch.device, torch.layout, torch.dtype)
 
 
 def _get_valid_constant(attr, v):
index b8a5c8b..a45751f 100644 (file)
@@ -1328,10 +1328,8 @@ def bilinear(input1, input2, weight, bias=None):
     return torch.bilinear(input1, input2, weight, bias)
 
 
-@torch._jit_internal.weak_script
-def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
+def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2,
               scale_grad_by_freq=False, sparse=False):
-    # type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor
     r"""A simple lookup table that looks up embeddings in a fixed dictionary and size.
 
     This module is often used to retrieve word embeddings using indices.
@@ -1390,32 +1388,25 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
                  [ 0.6262,  0.2438,  0.7471]]])
     """
     if padding_idx is not None:
-        padding_idx = torch.jit._unwrap_optional(padding_idx)
         if padding_idx > 0:
             assert padding_idx < weight.size(0), 'Padding_idx must be within num_embeddings'
         elif padding_idx < 0:
             assert padding_idx >= -weight.size(0), 'Padding_idx must be within num_embeddings'
             padding_idx = weight.size(0) + padding_idx
-    else:
+    elif padding_idx is None:
         padding_idx = -1
     if max_norm is not None:
-        max_norm = torch.jit._unwrap_optional(max_norm)
         # `embedding_renorm_` will call .contiguous() on input anyways, so we
         # call it here and take advantage of the improved locality in the
         # `embedding` call below too.
         input = input.contiguous()
-        # XXX: equivalent to
-        # with torch.no_grad():
-        #   torch.nembedding_renorm_
-        # remove once script supports set_grad_enabled
-        torch.no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
+        with torch.no_grad():
+            torch.embedding_renorm_(weight, input, max_norm, norm_type)
     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
 
 
-@torch._jit_internal.weak_script
 def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
                   scale_grad_by_freq=False, mode='mean', sparse=False):
-    # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool) -> Tensor
     r"""Computes sums, means or maxes of 'bags' of embeddings, without instantiating the
     intermediate embeddings.
 
@@ -1500,27 +1491,26 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
     elif input.dim() == 1:
         if offsets is None:
             raise ValueError("offsets has to be a 1D Tensor but got None")
-        offsets = torch.jit._unwrap_optional(offsets)
         if offsets.dim() != 1:
             raise ValueError("offsets has to be a 1D Tensor")
-        if int(offsets[0]) != 0:
+        if offsets[0].item() != 0:
             raise ValueError("offsets[0] has to be 0, i.e., the first sequence "
                              "in the mini-batch has to start from position 0. "
                              "However, got {}".format(offsets[0].item()))
-        if int(offsets[-1]) > input.size(0):
+        if offsets[-1].item() > input.size(0):
             raise ValueError("offsets[-1] can not be greater than input's length"
                              " ({}), but got offsets[-1] of {}"
                              .format(input.size(0), offsets[-1].item()))
     else:
         raise ValueError("input has to be 1D or 2D Tensor,"
                          " but got Tensor of dimension {}".format(input.dim()))
-    offsets = torch.jit._unwrap_optional(offsets)  # TODO remove when exception control flow logic
+
     if mode == 'sum':
-        mode_enum = 0
+        mode = 0
     elif mode == 'mean':
-        mode_enum = 1
+        mode = 1
     elif mode == 'max':
-        mode_enum = 2
+        mode = 2
 
         if scale_grad_by_freq:
             raise ValueError("max mode does not support scaling the gradient by the frequency")
@@ -1529,23 +1519,18 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
             raise ValueError("max mode does not support sparse weights")
 
     else:
-        mode_enum = -1  # TODO when exception control flow logic
         raise ValueError("mode has to be one of sum, mean or max")
 
     if max_norm is not None:
-        max_norm = torch.jit._unwrap_optional(max_norm)
-        # XXX: equivalent to
-        # with torch.no_grad():
-        #   torch.nembedding_renorm_
-        # remove once script supports set_grad_enabled
-        torch.no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
+        with torch.no_grad():
+            torch.embedding_renorm_(weight, input, max_norm, norm_type)
 
     ret, _, _, _ = torch.embedding_bag(
         weight,
         input,
         offsets,
         scale_grad_by_freq,
-        mode_enum,
+        mode,
         sparse)
     return ret
 
index a88d94b..2ee4103 100644 (file)
@@ -4,10 +4,8 @@ from torch.nn.parameter import Parameter
 from .module import Module
 from .. import functional as F
 from .. import init
-from torch._jit_internal import weak_module, weak_script, weak_script_method
 
 
-@weak_module
 class Embedding(Module):
     r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
 
@@ -77,11 +75,9 @@ class Embedding(Module):
                  [ 0.0000,  0.0000,  0.0000],
                  [-0.1655,  0.9897,  0.0635]]])
     """
-    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
-                     'norm_type', 'scale_grad_by_freq', 'sparse', '_weight']
 
     def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
-                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
+                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                  sparse=False, _weight=None):
         super(Embedding, self).__init__()
         self.num_embeddings = num_embeddings
@@ -111,7 +107,6 @@ class Embedding(Module):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)
 
-    @weak_script_method
     def forward(self, input):
         return F.embedding(
             input, self.weight, self.padding_idx, self.max_norm,
@@ -166,7 +161,6 @@ class Embedding(Module):
         return embedding
 
 
-@weak_module
 class EmbeddingBag(Module):
     r"""Computes sums or means of 'bags' of embeddings, without instantiating the
     intermediate embeddings.
@@ -229,11 +223,9 @@ class EmbeddingBag(Module):
         tensor([[-0.8861, -5.4350, -0.0523],
                 [ 1.1306, -2.5798, -1.0044]])
     """
-    __constants__ = ['num_embeddings, embedding_dim', 'max_norm', 'norm_type',
-                      'scale_grad_by_freq', 'mode', 'sparse']
 
     def __init__(self, num_embeddings, embedding_dim,
-                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
+                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                  mode='mean', sparse=False):
         super(EmbeddingBag, self).__init__()
         self.num_embeddings = num_embeddings
@@ -250,9 +242,7 @@ class EmbeddingBag(Module):
     def reset_parameters(self):
         init.normal_(self.weight)
 
-    @weak_script_method
     def forward(self, input, offsets=None):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
         return F.embedding_bag(input, self.weight, offsets,
                                self.max_norm, self.norm_type,
                                self.scale_grad_by_freq, self.mode, self.sparse)