#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
bool size_average,
int n_classes,
int64_t ignore_index) {
- CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 & threadIdx.z == 0);
+ CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0);
int t = static_cast<int>(*target);
if (t != static_cast<int>(ignore_index)) {
*total_weight = static_cast<scalar_t>(total_weight_acc);
if (size_average && nframe == 0) {
// Mean reduction on empty tensors produces NaN
- *output = std::numeric_limits<double>::quiet_NaN();
+ *output = std::numeric_limits<scalar_t>::quiet_NaN();
} else if (size_average && total_weight_acc != 0) {
*output = static_cast<scalar_t>(output_acc / total_weight_acc);
} else {
auto weight_ = weight.defined() ? weight.contiguous() : weight;
- if (reduction == Reduction::None & n_dims == 2) {
+ if (reduction == Reduction::None && n_dims == 2) {
output.resize_({batch_size});
if (batch_size == 0) {
// This guards from unnecessary operations and launching CUDA kernel with
target.scalar_type(),
"nll_loss_forward_reduce_cuda_kernel_2d_index",
[&] {
- nll_loss_forward_reduce_cuda_kernel_2d<scalar_t, float, index_t>
+ using accscalar_t = at::acc_type<scalar_t, /*is_cuda*/true>;
+ nll_loss_forward_reduce_cuda_kernel_2d<scalar_t, accscalar_t, index_t>
<<<1,
NLL_LOSS_THREADS,
0,
from torch._six import inf
import collections.abc
-from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Dict
from torch.testing import \
(make_non_contiguous, floating_types, floating_types_and, complex_types,
return sample_inputs
+def sample_inputs_nll_loss(op_info, device, dtype, requires_grad, **kwargs):
+ batch_size, num_classes = shape = (2, 3)
+
+ input_shape_and_kwargs: List[Tuple[Tuple[int, ...], Dict[str, Any]]] = [
+ ((*shape, 1), dict()),
+ ((*shape, 1, 2), dict()),
+ ((*shape, 1, 2, 3), dict()),
+ (shape, dict(weight=make_tensor((num_classes,), device=device, dtype=dtype).abs())),
+ (shape, dict(ignore_index=num_classes // 2)),
+ (shape, dict(reduction="sum")),
+ (shape, dict(reduction="mean")),
+ ]
+
+ sample_inputs = []
+ for input_shape, kwargs in input_shape_and_kwargs:
+ input = make_tensor(input_shape, device=device, dtype=dtype, requires_grad=requires_grad)
+
+ target = make_tensor(
+ (batch_size, *input_shape[2:]),
+ low=0,
+ high=num_classes,
+ device=device,
+ dtype=torch.long,
+ requires_grad=requires_grad
+ )
+
+ sample_inputs.append(SampleInput(input, args=(target,), kwargs=kwargs))
+
+ return sample_inputs
+
foreach_unary_op_db: List[OpInfo] = [
ForeachFuncInfo('exp'),
ForeachFuncInfo('acos'),
SkipInfo('TestReductions', 'test_dim_none_keepdim'),
),
),
+ OpInfo(
+ "nn.functional.nll_loss",
+ ref=_NOTHING,
+ dtypesIfCPU=floating_types_and(torch.bfloat16),
+ dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+ supports_out=False,
+ sample_inputs_func=sample_inputs_nll_loss,
+ skips=(
+ SkipInfo(
+ "TestJit",
+ "test_variant_consistency_jit",
+ dtypes=(torch.float32,),
+ ),
+ ),
+ ),
]
# Common operator groupings