Add at::one_hot (#15208)
authorGao, Xiang <qasdfgtyuiop@gmail.com>
Thu, 20 Dec 2018 22:09:09 +0000 (14:09 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 22:24:58 +0000 (14:24 -0800)
Summary: Closes: https://github.com/pytorch/pytorch/issues/15060

Differential Revision: D13528014

Pulled By: ezyang

fbshipit-source-id: 5a18689a4c5638d92f9390c91517f741e5396293

aten/src/ATen/native/Onehot.cpp [new file with mode: 0644]
aten/src/ATen/native/native_functions.yaml
docs/source/nn.rst
test/test_nn.py
test/test_torch.py
torch/distributions/one_hot_categorical.py
torch/nn/functional.py

diff --git a/aten/src/ATen/native/Onehot.cpp b/aten/src/ATen/native/Onehot.cpp
new file mode 100644 (file)
index 0000000..d02f97e
--- /dev/null
@@ -0,0 +1,35 @@
+#include <ATen/ATen.h>
+
+namespace at { namespace native {
+
+Tensor one_hot(const Tensor &self, int64_t num_classes) {
+    AT_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
+    auto shape = self.sizes().vec();
+
+    // empty tensor could be converted to one hot representation,
+    // but shape inference is not possible.
+    if (self.numel() == 0) {
+        if (num_classes <= 0) {
+            AT_ERROR("Can not infer total number of classes from empty tensor.");
+        } else {
+            shape.push_back(num_classes);
+            return at::empty(shape, self.options());
+        }
+    }
+
+    // non-empty tensor
+    AT_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
+    if (num_classes == -1) {
+        num_classes = self.max().item().toLong() + 1;
+    } else {
+        AT_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
+    }
+
+    shape.push_back(num_classes);
+    Tensor ret = at::zeros(shape, self.options());
+    ret.scatter_(-1, self.unsqueeze(-1), 1);
+    return ret;
+}
+
+} // namespace native
+} // namespace at
index 6ffb10b..1a92486 100644 (file)
   variants: method
   device_guard: false
 
+- func: one_hot(IndexTensor self, int64_t num_classes=-1) -> Tensor
+  python_module: nn
+  variants: function
+
 - func: flip(Tensor self, IntList dims) -> Tensor
   variants: function, method
   dispatch:
index 6efa977..f552ebb 100644 (file)
@@ -794,6 +794,11 @@ Utilities
 
 .. autofunction:: torch.nn.utils.remove_spectral_norm
 
+:hidden:`one_hot`
+~~~~~~~~~~~~~~~~~
+
+.. autofunction:: torch.nn.utils.one_hot
+
 
 .. currentmodule:: torch.nn.utils.rnn
 
index d62742f..58a0ca8 100644 (file)
@@ -2725,6 +2725,64 @@ class TestNN(NNTestCase):
         self.assertRaisesRegex(RuntimeError, expected_err_msg,
                                lambda: F.pad(torch.randn(1, 1, 2), (2, 1), mode='reflect'))
 
+    @staticmethod
+    def _test_one_hot(self, use_cuda=False):
+        device = torch.device('cuda' if use_cuda else 'cpu')
+        with self.assertRaises(RuntimeError):
+            torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
+
+        with self.assertRaises(RuntimeError):
+            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)
+
+        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
+        expected = torch.tensor([[0, 0, 0, 1, 0],
+                                 [0, 0, 0, 0, 1],
+                                 [0, 1, 0, 0, 0],
+                                 [1, 0, 0, 0, 0]], device=device)
+        self.assertEqual(t, expected)
+
+        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
+        expected = torch.tensor([[0, 0, 0, 1, 0],
+                                 [0, 0, 0, 0, 1],
+                                 [0, 1, 0, 0, 0],
+                                 [1, 0, 0, 0, 0]], device=device)
+        self.assertEqual(t, expected)
+
+        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
+        expected = torch.tensor([[0, 0, 0, 1, 0, 0],
+                                 [0, 0, 0, 0, 1, 0],
+                                 [0, 1, 0, 0, 0, 0],
+                                 [1, 0, 0, 0, 0, 0]], device=device)
+        self.assertEqual(t, expected)
+
+        t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
+        expected = torch.tensor([[[0, 0, 0, 1, 0],
+                                  [0, 0, 0, 0, 1]],
+                                 [[0, 1, 0, 0, 0],
+                                  [1, 0, 0, 0, 0]]], device=device)
+        self.assertEqual(t, expected)
+
+        t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
+        expected = torch.tensor([0, 0, 0, 0, 1], device=device)
+        self.assertEqual(t, expected)
+
+        t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
+        expected = torch.empty([4, 0, 100])
+        self.assertEqual(t, expected)
+
+        with self.assertRaises(RuntimeError):
+            torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))
+
+        with self.assertRaises(RuntimeError):
+            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)
+
+    def test_one_hot(self):
+        self._test_one_hot(self)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_one_hot_cuda(self):
+        self._test_one_hot(self, use_cuda=True)
+
     def test_pad_scalar_error(self):
         inputs = torch.tensor(0., requires_grad=True)
         self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))
index 2d16c2d..33598ff 100644 (file)
@@ -425,7 +425,7 @@ class _TestTorchMixin(object):
         def compare_reference(input, dtype):
             input = torch.tensor(input, dtype=dtype)
             res1 = torchfn(input.clone())
-            res2 = input.clone().apply_(lambda x: mathfn(x))
+            res2 = input.clone().apply_(mathfn)
             torch.testing.assert_allclose(res1, res2)
 
         # compare against the reference math function
@@ -1034,21 +1034,21 @@ class _TestTorchMixin(object):
     def test_reduction_empty(self):
         fns_to_test = [
             # name, function, identity
-            ('max', lambda *args, **kwargs: torch.max(*args, **kwargs), None),
+            ('max', torch.max, None),
             ('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs), None),
-            ('argmax', lambda *args, **kwargs: torch.argmax(*args, **kwargs), None),
-            ('min', lambda *args, **kwargs: torch.min(*args, **kwargs), None),
-            ('argmin', lambda *args, **kwargs: torch.argmin(*args, **kwargs), None),
-            ('mode', lambda *args, **kwargs: torch.mode(*args, **kwargs), None),
-            ('median', lambda *args, **kwargs: torch.median(*args, **kwargs), None),
-
-            ('prod', lambda *args, **kwargs: torch.prod(*args, **kwargs), 1),
-            ('sum', lambda *args, **kwargs: torch.sum(*args, **kwargs), 0),
-            ('norm', lambda *args, **kwargs: torch.norm(*args, p=2, **kwargs), 0),
-            ('mean', lambda *args, **kwargs: torch.mean(*args, **kwargs), nan),
-            ('var', lambda *args, **kwargs: torch.var(*args, **kwargs), nan),
-            ('std', lambda *args, **kwargs: torch.std(*args, **kwargs), nan),
-            ('logsumexp', lambda *args, **kwargs: torch.logsumexp(*args, **kwargs), -inf),
+            ('argmax', torch.argmax, None),
+            ('min', torch.min, None),
+            ('argmin', torch.argmin, None),
+            ('mode', torch.mode, None),
+            ('median', torch.median, None),
+
+            ('prod', torch.prod, 1),
+            ('sum', torch.sum, 0),
+            ('norm', torch.norm, 0),
+            ('mean', torch.mean, nan),
+            ('var', torch.var, nan),
+            ('std', torch.std, nan),
+            ('logsumexp', torch.logsumexp, -inf),
         ]
 
         devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
@@ -5268,7 +5268,7 @@ class _TestTorchMixin(object):
             for i in range(o3.size(1)):
                 for j in range(k.size(1)):
                     o32[i].add(torch.xcorr2(x[i + j - 1], k[j]))
-        self._test_conv_corr_eq(lambda x, k: torch.xcorr3(x, k), reference)
+        self._test_conv_corr_eq(torch.xcorr3, reference)
 
     @unittest.skip("Not implemented yet")
     def test_xcorr3_xcorr2_eq_full(self):
@@ -5284,7 +5284,7 @@ class _TestTorchMixin(object):
             for i in range(o3.size(1)):
                 for j in range(k.size(1)):
                     o32[i].add(torch.conv2(x[i + j - 1], k[k.size(1) - j + 1]))
-        self._test_conv_corr_eq(lambda x, k: torch.conv3(x, k), reference)
+        self._test_conv_corr_eq(torch.conv3, reference)
 
     @unittest.skip("Not implemented yet")
     def test_fconv3_fconv2_eq(self):
index 5165ed6..bd23f23 100644 (file)
@@ -76,15 +76,9 @@ class OneHotCategorical(Distribution):
     def sample(self, sample_shape=torch.Size()):
         sample_shape = torch.Size(sample_shape)
         probs = self._categorical.probs
+        num_events = self._categorical._num_events
         indices = self._categorical.sample(sample_shape)
-        if torch._C._get_tracing_state():
-            # [JIT WORKAROUND] lack of support for .scatter_()
-            eye = torch.eye(self.event_shape[-1], dtype=self._param.dtype, device=self._param.device)
-            return eye[indices]
-        one_hot = probs.new_zeros(self._extended_shape(sample_shape))
-        if indices.dim() < one_hot.dim():
-            indices = indices.unsqueeze(-1)
-        return one_hot.scatter_(-1, indices, 1.)
+        return torch.nn.functional.one_hot(indices, num_events).to(probs)
 
     def log_prob(self, value):
         if self._validate_args:
index 975839f..912d65d 100644 (file)
@@ -2784,6 +2784,55 @@ Example::
 """)
 
 
+one_hot = _add_docstr(torch._C._nn.one_hot, r"""
+one_hot(tensor, num_classes=0) -> LongTensor
+
+Takes LongTensor with index values of shape ``(*)`` and returns a tensor
+of shape ``(*, num_classes)`` that have zeros everywhere except where the
+index of last dimension matches the corresponding value of the input tensor,
+in which case it will be 1.
+
+See also `One-hot on Wikipedia`_ .
+
+.. _One-hot on Wikipedia:
+    https://en.wikipedia.org/wiki/One-hot
+
+Arguments:
+    tensor (LongTensor): class values of any shape.
+    num_classes (int):  Total number of classes. If set to -1, the number
+        of classes will be inferred as one greater than the largest class
+        value in the input tensor.
+
+Returns:
+    Tensor: LongTensor that has one more dimension with 1 values at the
+        index of last dimension indicated by the input, and 0 everywhere
+        else.
+
+Examples::
+    >>> torch.one_hot(torch.arange(0, 5) % 3)
+    tensor([[1, 0, 0],
+            [0, 1, 0],
+            [0, 0, 1],
+            [1, 0, 0],
+            [0, 1, 0]])
+    >>> torch.one_hot(torch.arange(0, 5) % 3, num_classes=5)
+    tensor([[1, 0, 0, 0, 0],
+            [0, 1, 0, 0, 0],
+            [0, 0, 1, 0, 0],
+            [1, 0, 0, 0, 0],
+            [0, 1, 0, 0, 0]])
+    >>> torch.one_hot(torch.arange(0, 6).view(3,2) % 3)
+    tensor([[[1, 0, 0],
+             [0, 1, 0]],
+
+            [[0, 0, 1],
+             [1, 0, 0]],
+
+            [[0, 1, 0],
+             [0, 0, 1]]])
+""")
+
+
 @weak_script
 def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
                         reduce=None, reduction="mean"):