--- /dev/null
+#include <ATen/ATen.h>
+
+namespace at { namespace native {
+
+Tensor one_hot(const Tensor &self, int64_t num_classes) {
+ AT_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
+ auto shape = self.sizes().vec();
+
+ // empty tensor could be converted to one hot representation,
+ // but shape inference is not possible.
+ if (self.numel() == 0) {
+ if (num_classes <= 0) {
+ AT_ERROR("Can not infer total number of classes from empty tensor.");
+ } else {
+ shape.push_back(num_classes);
+ return at::empty(shape, self.options());
+ }
+ }
+
+ // non-empty tensor
+ AT_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
+ if (num_classes == -1) {
+ num_classes = self.max().item().toLong() + 1;
+ } else {
+ AT_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
+ }
+
+ shape.push_back(num_classes);
+ Tensor ret = at::zeros(shape, self.options());
+ ret.scatter_(-1, self.unsqueeze(-1), 1);
+ return ret;
+}
+
+} // namespace native
+} // namespace at
variants: method
device_guard: false
+- func: one_hot(IndexTensor self, int64_t num_classes=-1) -> Tensor
+ python_module: nn
+ variants: function
+
- func: flip(Tensor self, IntList dims) -> Tensor
variants: function, method
dispatch:
.. autofunction:: torch.nn.utils.remove_spectral_norm
+:hidden:`one_hot`
+~~~~~~~~~~~~~~~~~
+
+.. autofunction:: torch.nn.utils.one_hot
+
.. currentmodule:: torch.nn.utils.rnn
self.assertRaisesRegex(RuntimeError, expected_err_msg,
lambda: F.pad(torch.randn(1, 1, 2), (2, 1), mode='reflect'))
+ @staticmethod
+ def _test_one_hot(self, use_cuda=False):
+ device = torch.device('cuda' if use_cuda else 'cpu')
+ with self.assertRaises(RuntimeError):
+ torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
+
+ with self.assertRaises(RuntimeError):
+ torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)
+
+ t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
+ expected = torch.tensor([[0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [0, 1, 0, 0, 0],
+ [1, 0, 0, 0, 0]], device=device)
+ self.assertEqual(t, expected)
+
+ t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
+ expected = torch.tensor([[0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [0, 1, 0, 0, 0],
+ [1, 0, 0, 0, 0]], device=device)
+ self.assertEqual(t, expected)
+
+ t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
+ expected = torch.tensor([[0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 1, 0],
+ [0, 1, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0]], device=device)
+ self.assertEqual(t, expected)
+
+ t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
+ expected = torch.tensor([[[0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1]],
+ [[0, 1, 0, 0, 0],
+ [1, 0, 0, 0, 0]]], device=device)
+ self.assertEqual(t, expected)
+
+ t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
+ expected = torch.tensor([0, 0, 0, 0, 1], device=device)
+ self.assertEqual(t, expected)
+
+ t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
+ expected = torch.empty([4, 0, 100])
+ self.assertEqual(t, expected)
+
+ with self.assertRaises(RuntimeError):
+ torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))
+
+ with self.assertRaises(RuntimeError):
+ torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)
+
+ def test_one_hot(self):
+ self._test_one_hot(self)
+
+ @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+ def test_one_hot_cuda(self):
+ self._test_one_hot(self, use_cuda=True)
+
def test_pad_scalar_error(self):
inputs = torch.tensor(0., requires_grad=True)
self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))
def compare_reference(input, dtype):
input = torch.tensor(input, dtype=dtype)
res1 = torchfn(input.clone())
- res2 = input.clone().apply_(lambda x: mathfn(x))
+ res2 = input.clone().apply_(mathfn)
torch.testing.assert_allclose(res1, res2)
# compare against the reference math function
def test_reduction_empty(self):
fns_to_test = [
# name, function, identity
- ('max', lambda *args, **kwargs: torch.max(*args, **kwargs), None),
+ ('max', torch.max, None),
('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs), None),
- ('argmax', lambda *args, **kwargs: torch.argmax(*args, **kwargs), None),
- ('min', lambda *args, **kwargs: torch.min(*args, **kwargs), None),
- ('argmin', lambda *args, **kwargs: torch.argmin(*args, **kwargs), None),
- ('mode', lambda *args, **kwargs: torch.mode(*args, **kwargs), None),
- ('median', lambda *args, **kwargs: torch.median(*args, **kwargs), None),
-
- ('prod', lambda *args, **kwargs: torch.prod(*args, **kwargs), 1),
- ('sum', lambda *args, **kwargs: torch.sum(*args, **kwargs), 0),
- ('norm', lambda *args, **kwargs: torch.norm(*args, p=2, **kwargs), 0),
- ('mean', lambda *args, **kwargs: torch.mean(*args, **kwargs), nan),
- ('var', lambda *args, **kwargs: torch.var(*args, **kwargs), nan),
- ('std', lambda *args, **kwargs: torch.std(*args, **kwargs), nan),
- ('logsumexp', lambda *args, **kwargs: torch.logsumexp(*args, **kwargs), -inf),
+ ('argmax', torch.argmax, None),
+ ('min', torch.min, None),
+ ('argmin', torch.argmin, None),
+ ('mode', torch.mode, None),
+ ('median', torch.median, None),
+
+ ('prod', torch.prod, 1),
+ ('sum', torch.sum, 0),
+ ('norm', torch.norm, 0),
+ ('mean', torch.mean, nan),
+ ('var', torch.var, nan),
+ ('std', torch.std, nan),
+ ('logsumexp', torch.logsumexp, -inf),
]
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
for i in range(o3.size(1)):
for j in range(k.size(1)):
o32[i].add(torch.xcorr2(x[i + j - 1], k[j]))
- self._test_conv_corr_eq(lambda x, k: torch.xcorr3(x, k), reference)
+ self._test_conv_corr_eq(torch.xcorr3, reference)
@unittest.skip("Not implemented yet")
def test_xcorr3_xcorr2_eq_full(self):
for i in range(o3.size(1)):
for j in range(k.size(1)):
o32[i].add(torch.conv2(x[i + j - 1], k[k.size(1) - j + 1]))
- self._test_conv_corr_eq(lambda x, k: torch.conv3(x, k), reference)
+ self._test_conv_corr_eq(torch.conv3, reference)
@unittest.skip("Not implemented yet")
def test_fconv3_fconv2_eq(self):
def sample(self, sample_shape=torch.Size()):
sample_shape = torch.Size(sample_shape)
probs = self._categorical.probs
+ num_events = self._categorical._num_events
indices = self._categorical.sample(sample_shape)
- if torch._C._get_tracing_state():
- # [JIT WORKAROUND] lack of support for .scatter_()
- eye = torch.eye(self.event_shape[-1], dtype=self._param.dtype, device=self._param.device)
- return eye[indices]
- one_hot = probs.new_zeros(self._extended_shape(sample_shape))
- if indices.dim() < one_hot.dim():
- indices = indices.unsqueeze(-1)
- return one_hot.scatter_(-1, indices, 1.)
+ return torch.nn.functional.one_hot(indices, num_events).to(probs)
def log_prob(self, value):
if self._validate_args:
""")
+one_hot = _add_docstr(torch._C._nn.one_hot, r"""
+one_hot(tensor, num_classes=0) -> LongTensor
+
+Takes LongTensor with index values of shape ``(*)`` and returns a tensor
+of shape ``(*, num_classes)`` that have zeros everywhere except where the
+index of last dimension matches the corresponding value of the input tensor,
+in which case it will be 1.
+
+See also `One-hot on Wikipedia`_ .
+
+.. _One-hot on Wikipedia:
+ https://en.wikipedia.org/wiki/One-hot
+
+Arguments:
+ tensor (LongTensor): class values of any shape.
+ num_classes (int): Total number of classes. If set to -1, the number
+ of classes will be inferred as one greater than the largest class
+ value in the input tensor.
+
+Returns:
+ Tensor: LongTensor that has one more dimension with 1 values at the
+ index of last dimension indicated by the input, and 0 everywhere
+ else.
+
+Examples::
+ >>> torch.one_hot(torch.arange(0, 5) % 3)
+ tensor([[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ >>> torch.one_hot(torch.arange(0, 5) % 3, num_classes=5)
+ tensor([[1, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 0],
+ [1, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0]])
+ >>> torch.one_hot(torch.arange(0, 6).view(3,2) % 3)
+ tensor([[[1, 0, 0],
+ [0, 1, 0]],
+
+ [[0, 0, 1],
+ [1, 0, 0]],
+
+ [[0, 1, 0],
+ [0, 0, 1]]])
+""")
+
+
@weak_script
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
reduce=None, reduction="mean"):