Summary:
Fixes https://github.com/pytorch/pytorch/issues/24608
Fixes https://github.com/pytorch/pytorch/issues/24607
With the following benchmark, the backward pass runs a little slower (roughly 3% with reduction="none" and 15-20% with "sum"/"mean"). This is strange, since the implementation should be exactly the same.
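One way to chase the remaining gap is to compare per-kernel CUDA times with the autograd profiler. The snippet below is a minimal sketch, not part of the benchmark itself; the shape and the `reduction="sum"` choice are just examples:

```python
import torch
import torch.nn.functional as F

# Hypothetical spot-check: profile a single backward pass and compare the
# per-kernel CUDA time between master and this PR.
input = torch.randn(128, 3, 256, 256, device="cuda", requires_grad=True)
target = torch.rand(128, 256, 256, device="cuda").mul(3).floor().long()
loss = F.nll_loss(input, target, reduction="sum")
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    loss.backward()
print(prof.key_averages().table(sort_by="cuda_time_total"))
```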
<details>
<summary>Benchmark script</summary>
```python
from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
torch.manual_seed(0)
MS_PER_SECOND = 1000
def _time():
    torch.cuda.synchronize()
    return time.perf_counter() * MS_PER_SECOND
device = "cuda"
C = 3
n_runs = 30
reductions = ["none", "sum", "mean"]
Ns = [128, 256, 512]
Hs = [128, 256, 512]
for reduction, N, H in product(reductions, Ns, Hs):
    total_fwd_time = 0
    total_back_time = 0
    if reduction == "none":
        grad_out = torch.randn(N, H, H, device=device)
    else:
        grad_out = torch.randn(1)[0]
    for _ in range(n_runs):
        input = torch.randn(N, C, H, H, device=device, requires_grad=True)
        target = torch.rand(N, H, H, device=device).mul(3).floor().long()
        # forward
        start = _time()
        result = F.nll_loss(input, target, reduction=reduction)
        total_fwd_time += _time() - start
    result = F.nll_loss(input, target, reduction=reduction)
    for _ in range(n_runs):
        # backward
        start = _time()
        result.backward(grad_out, retain_graph=True)
        total_back_time += _time() - start
    fwd_avg = total_fwd_time / n_runs
    bwd_avg = total_back_time / n_runs
    print(
        f"input size({N}, {C}, {H}, {H}), reduction: {reduction}, fwd: {fwd_avg:.2f} (ms), back: {bwd_avg:.2f} (ms)"
    )
```
</details>
<details>
<summary>master results</summary>
```
input size(128, 3, 128, 128), reduction: none, fwd: 0.34 (ms), back: 0.57 (ms)
input size(128, 3, 256, 256), reduction: none, fwd: 2.56 (ms), back: 3.85 (ms)
input size(128, 3, 512, 512), reduction: none, fwd: 14.54 (ms), back: 16.62 (ms)
input size(256, 3, 128, 128), reduction: none, fwd: 1.26 (ms), back: 1.78 (ms)
input size(256, 3, 256, 256), reduction: none, fwd: 7.07 (ms), back: 8.22 (ms)
input size(256, 3, 512, 512), reduction: none, fwd: 29.38 (ms), back: 33.29 (ms)
input size(512, 3, 128, 128), reduction: none, fwd: 3.41 (ms), back: 4.05 (ms)
input size(512, 3, 256, 256), reduction: none, fwd: 14.32 (ms), back: 16.46 (ms)
input size(512, 3, 512, 512), reduction: none, fwd: 59.20 (ms), back: 66.68 (ms)
input size(128, 3, 128, 128), reduction: sum, fwd: 0.08 (ms), back: 0.21 (ms)
input size(128, 3, 256, 256), reduction: sum, fwd: 0.21 (ms), back: 0.73 (ms)
input size(128, 3, 512, 512), reduction: sum, fwd: 0.82 (ms), back: 2.86 (ms)
input size(256, 3, 128, 128), reduction: sum, fwd: 0.12 (ms), back: 0.39 (ms)
input size(256, 3, 256, 256), reduction: sum, fwd: 0.42 (ms), back: 1.45 (ms)
input size(256, 3, 512, 512), reduction: sum, fwd: 1.53 (ms), back: 5.66 (ms)
input size(512, 3, 128, 128), reduction: sum, fwd: 0.21 (ms), back: 0.74 (ms)
input size(512, 3, 256, 256), reduction: sum, fwd: 0.78 (ms), back: 2.86 (ms)
input size(512, 3, 512, 512), reduction: sum, fwd: 2.98 (ms), back: 11.23 (ms)
input size(128, 3, 128, 128), reduction: mean, fwd: 0.07 (ms), back: 0.21 (ms)
input size(128, 3, 256, 256), reduction: mean, fwd: 0.21 (ms), back: 0.73 (ms)
input size(128, 3, 512, 512), reduction: mean, fwd: 0.82 (ms), back: 2.86 (ms)
input size(256, 3, 128, 128), reduction: mean, fwd: 0.13 (ms), back: 0.39 (ms)
input size(256, 3, 256, 256), reduction: mean, fwd: 0.42 (ms), back: 1.45 (ms)
input size(256, 3, 512, 512), reduction: mean, fwd: 1.54 (ms), back: 5.65 (ms)
input size(512, 3, 128, 128), reduction: mean, fwd: 0.22 (ms), back: 0.74 (ms)
input size(512, 3, 256, 256), reduction: mean, fwd: 0.78 (ms), back: 2.87 (ms)
input size(512, 3, 512, 512), reduction: mean, fwd: 2.98 (ms), back: 11.23 (ms)
```
</details>
<details>
<summary>PR results</summary>
```
input size(128, 3, 128, 128), reduction: none, fwd: 0.33 (ms), back: 0.59 (ms)
input size(128, 3, 256, 256), reduction: none, fwd: 2.51 (ms), back: 3.92 (ms)
input size(128, 3, 512, 512), reduction: none, fwd: 14.52 (ms), back: 17.05 (ms)
input size(256, 3, 128, 128), reduction: none, fwd: 1.23 (ms), back: 1.85 (ms)
input size(256, 3, 256, 256), reduction: none, fwd: 7.07 (ms), back: 8.45 (ms)
input size(256, 3, 512, 512), reduction: none, fwd: 29.39 (ms), back: 34.21 (ms)
input size(512, 3, 128, 128), reduction: none, fwd: 3.40 (ms), back: 4.18 (ms)
input size(512, 3, 256, 256), reduction: none, fwd: 14.33 (ms), back: 16.90 (ms)
input size(512, 3, 512, 512), reduction: none, fwd: 59.04 (ms), back: 68.36 (ms)
input size(128, 3, 128, 128), reduction: sum, fwd: 0.07 (ms), back: 0.25 (ms)
input size(128, 3, 256, 256), reduction: sum, fwd: 0.21 (ms), back: 0.86 (ms)
input size(128, 3, 512, 512), reduction: sum, fwd: 0.82 (ms), back: 3.33 (ms)
input size(256, 3, 128, 128), reduction: sum, fwd: 0.12 (ms), back: 0.46 (ms)
input size(256, 3, 256, 256), reduction: sum, fwd: 0.42 (ms), back: 1.70 (ms)
input size(256, 3, 512, 512), reduction: sum, fwd: 1.53 (ms), back: 6.58 (ms)
input size(512, 3, 128, 128), reduction: sum, fwd: 0.21 (ms), back: 0.87 (ms)
input size(512, 3, 256, 256), reduction: sum, fwd: 0.78 (ms), back: 3.34 (ms)
input size(512, 3, 512, 512), reduction: sum, fwd: 2.98 (ms), back: 13.07 (ms)
input size(128, 3, 128, 128), reduction: mean, fwd: 0.07 (ms), back: 0.26 (ms)
input size(128, 3, 256, 256), reduction: mean, fwd: 0.21 (ms), back: 0.86 (ms)
input size(128, 3, 512, 512), reduction: mean, fwd: 0.82 (ms), back: 3.34 (ms)
input size(256, 3, 128, 128), reduction: mean, fwd: 0.12 (ms), back: 0.46 (ms)
input size(256, 3, 256, 256), reduction: mean, fwd: 0.42 (ms), back: 1.72 (ms)
input size(256, 3, 512, 512), reduction: mean, fwd: 1.53 (ms), back: 6.60 (ms)
input size(512, 3, 128, 128), reduction: mean, fwd: 0.21 (ms), back: 0.87 (ms)
input size(512, 3, 256, 256), reduction: mean, fwd: 0.78 (ms), back: 3.33 (ms)
input size(512, 3, 512, 512), reduction: mean, fwd: 2.98 (ms), back: 13.07 (ms)
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62826
Reviewed By: bdhirsh
Differential Revision: D30282279
Pulled By: ngimel
fbshipit-source-id: 4aa0ff3f8af0632957417931d332ec486a12b52d
"aten/src/THCUNN/SoftMarginCriterion.cu.cc",
"aten/src/THCUNN/SoftPlus.cu.cc",
"aten/src/THCUNN/SoftShrink.cu.cc",
- "aten/src/THCUNN/SpatialClassNLLCriterion.cu.cc",
"aten/src/THCUNN/SpatialConvolutionMM.cu.cc",
"aten/src/THCUNN/Tanh.cu.cc",
],
"aten/src/ATen/native/cuda/NaiveConvolutionTranspose2d.cu.cc",
"aten/src/ATen/native/cuda/NaiveConvolutionTranspose3d.cu.cc",
"aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu.cc",
+ "aten/src/ATen/native/cuda/NLLLoss2d.cu.cc",
"aten/src/ATen/native/cuda/Normalization.cu.cc",
"aten/src/ATen/native/cuda/PointwiseOpsKernel.cu.cc",
"aten/src/ATen/native/cuda/PowKernel.cu.cc",
Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper);
Tensor _th_potri(const Tensor & self, bool upper);
Tensor & _th_copy_ignoring_overlaps_(Tensor & self, const Tensor & src);
-std::tuple<Tensor &,Tensor &> _thnn_nll_loss2d_forward_out(const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, Tensor & output, Tensor & total_weight);
-std::tuple<Tensor,Tensor> _thnn_nll_loss2d_forward(const Tensor & self, const Tensor & target, const optional<Tensor> & weight, int64_t reduction, int64_t ignore_index);
-Tensor & _thnn_nll_loss2d_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, const Tensor & total_weight, Tensor & grad_input);
-Tensor _thnn_nll_loss2d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, const optional<Tensor> & weight, int64_t reduction, int64_t ignore_index, const Tensor & total_weight);
Tensor _thnn_rrelu_with_noise_backward(const Tensor & grad_output, const Tensor & self, const Tensor & noise, const Scalar& lower, const Scalar& upper, bool training);
std::tuple<Tensor &,Tensor &,Tensor &> _thnn_conv2d_forward_out(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const c10::optional<Tensor>& bias_opt, IntArrayRef stride, IntArrayRef padding, Tensor & output, Tensor & columns, Tensor & ones);
std::tuple<Tensor,Tensor,Tensor> _thnn_conv2d_forward(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const optional<Tensor> & bias, IntArrayRef stride, IntArrayRef padding);
}
return self;
}
-std::tuple<Tensor &,Tensor &> _thnn_nll_loss2d_forward_out(const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, Tensor & output, Tensor & total_weight) {
- // See [Note: hacky wrapper removal for optional tensor]
- c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
- const Tensor& weight = *weight_maybe_owned;
-
- const OptionalDeviceGuard device_guard(device_of(self));
- auto dispatch_scalar_type = infer_scalar_type(self);
-
- switch (dispatch_scalar_type) {
- case ScalarType::Double: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward_out", true, DeviceType::CUDA, dispatch_scalar_type);
- auto output_ = checked_dense_tensor_unwrap(output, "output", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaDoubleSpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- case ScalarType::Float: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward_out", true, DeviceType::CUDA, dispatch_scalar_type);
- auto output_ = checked_dense_tensor_unwrap(output, "output", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaSpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- case ScalarType::Half: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward_out", true, DeviceType::CUDA, dispatch_scalar_type);
- auto output_ = checked_dense_tensor_unwrap(output, "output", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaHalfSpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- case ScalarType::BFloat16: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward_out", true, DeviceType::CUDA, dispatch_scalar_type);
- auto output_ = checked_dense_tensor_unwrap(output, "output", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaBFloat16SpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- default:
- AT_ERROR("_thnn_nll_loss2d_forward_out not supported on CUDAType for ", dispatch_scalar_type);
- }
- return std::tuple<Tensor &, Tensor &>(output, total_weight);
-}
-std::tuple<Tensor,Tensor> _thnn_nll_loss2d_forward(const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index) {
- // See [Note: hacky wrapper removal for optional tensor]
- c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
- const Tensor& weight = *weight_maybe_owned;
-
- const OptionalDeviceGuard device_guard(device_of(self));
- auto dispatch_scalar_type = infer_scalar_type(self);
- auto output_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
- auto output = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(output_));
- auto total_weight_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
- auto total_weight = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(total_weight_));
- switch (dispatch_scalar_type) {
- case ScalarType::Double: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward", true, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaDoubleSpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- case ScalarType::Float: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward", true, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaSpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- case ScalarType::Half: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward", true, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaHalfSpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- case ScalarType::BFloat16: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward", true, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaBFloat16SpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- default:
- AT_ERROR("_thnn_nll_loss2d_forward not supported on CUDAType for ", dispatch_scalar_type);
- }
- return std::tuple<Tensor, Tensor>(output, total_weight);
-}
-Tensor & _thnn_nll_loss2d_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, const Tensor & total_weight, Tensor & grad_input) {
- // See [Note: hacky wrapper removal for optional tensor]
- c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
- const Tensor& weight = *weight_maybe_owned;
-
- const OptionalDeviceGuard device_guard(device_of(self));
- auto dispatch_scalar_type = infer_scalar_type(self);
-
- switch (dispatch_scalar_type) {
- case ScalarType::Double: {
- auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type);
- auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaDoubleSpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- case ScalarType::Float: {
- auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type);
- auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaSpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- case ScalarType::Half: {
- auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type);
- auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaHalfSpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- case ScalarType::BFloat16: {
- auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type);
- auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaBFloat16SpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- default:
- AT_ERROR("_thnn_nll_loss2d_backward_out not supported on CUDAType for ", dispatch_scalar_type);
- }
- return grad_input;
-}
-Tensor _thnn_nll_loss2d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, const Tensor & total_weight) {
- // See [Note: hacky wrapper removal for optional tensor]
- c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
- const Tensor& weight = *weight_maybe_owned;
-
- const OptionalDeviceGuard device_guard(device_of(self));
- auto dispatch_scalar_type = infer_scalar_type(self);
- auto grad_input_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
- auto grad_input = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(grad_input_));
- switch (dispatch_scalar_type) {
- case ScalarType::Double: {
- auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward", true, DeviceType::CUDA, dispatch_scalar_type);
- auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaDoubleSpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- case ScalarType::Float: {
- auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward", true, DeviceType::CUDA, dispatch_scalar_type);
- auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaSpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- case ScalarType::Half: {
- auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward", true, DeviceType::CUDA, dispatch_scalar_type);
- auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaHalfSpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- case ScalarType::BFloat16: {
- auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type);
- auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, ScalarType::Long);
- auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward", true, DeviceType::CUDA, dispatch_scalar_type);
- auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type);
- THNN_CudaBFloat16SpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
- break;
- }
- default:
- AT_ERROR("_thnn_nll_loss2d_backward not supported on CUDAType for ", dispatch_scalar_type);
- }
- return grad_input;
-}
std::tuple<Tensor &,Tensor &,Tensor &> _thnn_conv2d_forward_out(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const c10::optional<Tensor>& bias_opt, IntArrayRef stride, IntArrayRef padding, Tensor & output, Tensor & columns, Tensor & ones) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
--- /dev/null
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/Dispatch.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/core/TensorAccessor.h>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <THC/THCAtomics.cuh>
+#include <c10/cuda/CUDAException.h>
+#include <c10/macros/Macros.h>
+#include <ATen/native/Resize.h>
+#include <ATen/native/cuda/block_reduce.cuh>
+
+namespace at {
+namespace native {
+
+namespace {
+
+// Returns a contiguous tensor if the source tensor
+// is defined. Otherwise returns the undefined
+// source tensor unmodified.
+inline Tensor optional_contiguous(const Tensor& source) {
+ return source.defined() ? source.contiguous() : source;
+}
+
+// Returns the address of the first element of a tensor
+// or nullptr if the tensor is undefined.
+template <typename scalar_t>
+inline scalar_t* optional_data(const Tensor& source) {
+ return source.defined() ? source.data_ptr<scalar_t>() : nullptr;
+}
+
+using at::cuda::detail::CUDA_NUM_THREADS;
+using at::cuda::detail::GET_BLOCKS;
+
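+// Forward, reduction='none': one thread per output element. Each thread
+// decodes its flat index into (b, h, w), reads the target class, and writes
+// -weight[class] * input[b][class][h][w], or 0 when the class equals
+// ignore_index.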
+template <typename scalar_t>
+C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
+__global__ void nll_loss2d_forward_no_reduce_kernel(
+ int64_t n_threads,
+ PackedTensorAccessor64<scalar_t, 4> input,
+ PackedTensorAccessor64<int64_t, 3> target,
+ PackedTensorAccessor64<scalar_t, 3> output,
+ scalar_t* weight,
+ int64_t ignore_index
+) {
+ int64_t batch_size = input.size(0);
+ int64_t H = input.size(2);
+ int64_t W = input.size(3);
+
+ CUDA_KERNEL_LOOP(index, n_threads) {
+ const int64_t b = index % batch_size;
+ const int64_t h = (index / batch_size) % H;
+ const int64_t w = (index / (batch_size * H)) % W;
+
+ int64_t cur_target = target[b][h][w];
+ if (cur_target == ignore_index) {
+ output[b][h][w] = static_cast<scalar_t>(0);
+ continue;
+ }
+ scalar_t value = input[b][cur_target][h][w];
+ scalar_t cur_weight = weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1);
+ output[b][h][w] = -value * cur_weight;
+ }
+}
+
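+// Forward, reduction='sum'/'mean': each sample is tiled over
+// blocks_per_sample blocks, and every thread strides over the sample's
+// H*W plane accumulating the negative weighted log-probs and the hit
+// weights in accscalar_t. Partials are combined with a block reduction and
+// thread 0 atomically adds them into the scalar output and total_weight;
+// these atomics are why the reduced forward is flagged nondeterministic.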
+template <typename scalar_t, typename accscalar_t>
+C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
+__global__ void nll_loss2d_forward_kernel(
+ scalar_t* output,
+ scalar_t* total_weight,
+ scalar_t* input,
+ int64_t* target,
+ scalar_t* weight,
+ bool size_average,
+ int batch_size,
+ int n_classes,
+ int map_nelem,
+ int blocks_per_sample,
+ int64_t ignore_index) {
+
+ scalar_t cur_weight;
+ accscalar_t input_sum = 0;
+ accscalar_t acc_weight = 0;
+
+ int sample = blockIdx.x / blocks_per_sample;
+ int toffset = sample * map_nelem;
+ int ioffset = sample * map_nelem * n_classes;
+ int step = blockDim.x * blocks_per_sample;
+ for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
+ i < map_nelem;
+ i += step) {
+ int t = target[toffset + i];
+ if (t != ignore_index) {
+ CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
+ cur_weight = weight != nullptr ? weight[t] : static_cast<scalar_t>(1);
+ input_sum -= input[ioffset + i + map_nelem * t] * cur_weight;
+ acc_weight += cur_weight;
+ }
+ }
+
+ __shared__ accscalar_t acc_weight_smem[CUDA_NUM_THREADS];
+ __shared__ accscalar_t input_sum_smem[CUDA_NUM_THREADS];
+ auto acc_weight_ = cuda_utils::BlockReduceSum(acc_weight, acc_weight_smem);
+ auto input_sum_ = cuda_utils::BlockReduceSum(input_sum, input_sum_smem);
+
+ if (threadIdx.x == 0) {
+ gpuAtomicAdd(total_weight, static_cast<scalar_t>(acc_weight_));
+ gpuAtomicAdd(output, static_cast<scalar_t>(input_sum_));
+ }
+}
+
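+// Finalizes the 'mean' reduction on a single thread: divides the
+// accumulated sum by the accumulated weight, and writes NaN for an empty
+// input so the mean of an empty tensor stays NaN.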
+template <typename scalar_t>
+C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
+__global__ void nll_loss2d_forward_size_average_kernel(
+ scalar_t* output,
+ scalar_t* total_weight,
+ int n_elements
+) {
+ if (n_elements == 0) {
+ // Mean reduction on empty tensors produces NaN
+ *output = std::numeric_limits<double>::quiet_NaN();
+ }
+ if (*total_weight != 0) {
+ *output /= *total_weight;
+ }
+}
+
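+// Backward, reduction='none': one thread per (b, h, w) element writes
+// -weight[class] * grad_output[b][h][w] into grad_input at the target
+// class; ignored targets are skipped (grad_input is zero-initialized).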
+template <typename scalar_t>
+C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
+__global__ void nll_loss2d_backward_no_reduce_kernel(
+ int64_t n_threads,
+ PackedTensorAccessor64<int64_t, 3> target,
+ PackedTensorAccessor64<scalar_t, 3> grad_output,
+ PackedTensorAccessor64<scalar_t, 4> grad_input,
+ scalar_t* weight,
+ int64_t ignore_index
+) {
+ int64_t batch_size = target.size(0);
+ int64_t H = target.size(1);
+ int64_t W = target.size(2);
+
+ CUDA_KERNEL_LOOP(index, n_threads) {
+ const int64_t b = index % batch_size;
+ const int64_t h = (index / batch_size) % H;
+ const int64_t w = (index / (batch_size * H)) % W;
+
+ int64_t cur_target = target[b][h][w];
+ if (cur_target == ignore_index) {
+ continue;
+ }
+ scalar_t value = -(weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1));
+ grad_input[b][cur_target][h][w] = value * grad_output[b][h][w];
+ }
+}
+
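+// Backward, reduction='sum'/'mean': same block-per-sample tiling as the
+// reduced forward. Each visited element scatters
+// -weight[class] * norm * grad_output[0] into grad_input at its target
+// class, where norm is 1/total_weight for 'mean' and 1 for 'sum'.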
+template <typename scalar_t>
+C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
+__global__ void nll_loss2d_backward_kernel(
+ scalar_t* grad_input,
+ scalar_t* grad_output,
+ int64_t* target,
+ scalar_t* weight,
+ scalar_t* total_weight,
+ bool size_average,
+ int batch_size,
+ int n_classes,
+ int map_nelem,
+ int blocks_per_sample,
+ int64_t ignore_index
+) {
+ if (*total_weight <= 0) {
+ return;
+ }
+
+ scalar_t norm = size_average ? (static_cast<scalar_t>(1) / *total_weight) : static_cast<scalar_t>(1);
+
+ int sample = blockIdx.x / blocks_per_sample;
+ int step = blockDim.x * blocks_per_sample;
+ int toffset = sample * map_nelem;
+ int ioffset = sample * map_nelem * n_classes;
+ for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
+ i < map_nelem;
+ i += step) {
+ int t = (int)target[toffset + i];
+ if (t != ignore_index) {
+ CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
+ grad_input[ioffset + i + map_nelem * t] = -(weight != nullptr ? weight[t] : static_cast<scalar_t>(1)) * norm * grad_output[0];
+ }
+ }
+}
+
+void check_inputs_nll_loss2d(
+ const Tensor& input,
+ const Tensor& target,
+ const Tensor& weight) {
+ TORCH_CHECK(
+ target.dim() == 3,
+ "only batches of spatial targets supported (3D tensors)"
+      " but got targets of size: ",
+ target.sizes());
+ TORCH_CHECK(
+ input.dim() == 4,
+ "only batches of spatial inputs supported (4D tensors), "
+ "but got input of size: ",
+ input.sizes());
+ TORCH_CHECK(
+ !weight.defined() || weight.numel() == input.size(1),
+ "weight tensor should be defined either for all or no classes");
+
+ TORCH_CHECK(
+ input.size(0) == target.size(0) && input.size(2) == target.size(1) &&
+ input.size(3) == target.size(2),
+ "input and target batch or spatial sizes don't match: target ",
+ target.sizes(),
+ ", input ",
+ input.sizes());
+}
+
+void nll_loss2d_forward_out_cuda_template(
+ Tensor& output,
+ Tensor& total_weight,
+ const Tensor& input,
+ const Tensor& target,
+ const c10::optional<Tensor>& weight_opt,
+ int64_t reduction,
+ int64_t ignore_index) {
+ // See Note [Writing Nondeterministic Operations]
+ // Nondeterministic because of atomicAdd usage in 'sum' or 'mean' reductions.
+ if (reduction != at::Reduction::None) {
+ at::globalContext().alertNotDeterministic("nll_loss2d_forward_out_cuda_template");
+ }
+
+ // See [Note: hacky wrapper removal for optional tensor]
+ c10::MaybeOwned<Tensor> weight_maybe_owned =
+ at::borrow_from_optional_tensor(weight_opt);
+ const Tensor& weight = *weight_maybe_owned;
+
+ check_inputs_nll_loss2d(input, target, weight);
+ total_weight.resize_({});
+
+ if (reduction == at::Reduction::None) {
+ int64_t batch_size = input.size(0);
+ int64_t H = input.size(2);
+ int64_t W = input.size(3);
+ int64_t count = batch_size * H * W;
+
+ resize_output(output, {batch_size, H, W});
+ if (count == 0) {
+ // This guards from unnecessary operations and launching CUDA kernel with
+ // 0 blocks.
+ return;
+ }
+ auto weight_ = optional_contiguous(weight);
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::Half,
+ at::ScalarType::BFloat16,
+ input.scalar_type(),
+ "nll_loss2d_forward_no_reduce_kernel",
+ [&] {
+ nll_loss2d_forward_no_reduce_kernel<scalar_t>
+ <<<GET_BLOCKS(count),
+ CUDA_NUM_THREADS,
+ 0,
+ at::cuda::getCurrentCUDAStream()>>>(
+ count,
+ input.packed_accessor<scalar_t, 4>(),
+ target.packed_accessor<int64_t, 3>(),
+ output.packed_accessor<scalar_t, 3>(),
+ optional_data<scalar_t>(weight_),
+ ignore_index);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
+ return;
+ }
+
+ // produce scalar outputs for the reduction case
+ resize_output(output, {});
+
+ auto input_ = input.contiguous();
+ auto weight_ = optional_contiguous(weight);
+ auto target_ = target.contiguous();
+
+ output.fill_(0);
+ total_weight.fill_(0);
+
+ auto batch_size = target.size(0);
+ auto target_numel = target.numel();
+ if (batch_size != 0 && target_numel != 0) {
+    // This guards from unnecessary operations and launching CUDA kernel
+    // with 0 blocks.
+ int64_t map_nelem = target_numel / batch_size;
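+    // GET_BLOCKS(map_nelem) would give one thread per element; dividing by
+    // 128 caps the grid so each thread covers ~128 elements via the
+    // kernel's strided loop (with at least one block per sample).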
+ int blocks_per_sample = GET_BLOCKS(map_nelem) / 128;
+ blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
+ int total_blocks = blocks_per_sample * batch_size;
+
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::Half,
+ at::ScalarType::BFloat16,
+ input.scalar_type(),
+ "nll_loss2d_forward_kernel",
+ [&] {
+ using accscalar_t = acc_type<scalar_t, true>;
+ nll_loss2d_forward_kernel<scalar_t, accscalar_t>
+ <<<total_blocks,
+ CUDA_NUM_THREADS,
+ 0,
+ at::cuda::getCurrentCUDAStream()>>>(
+ output.data_ptr<scalar_t>(),
+ total_weight.data_ptr<scalar_t>(),
+ input_.data_ptr<scalar_t>(),
+ target_.data_ptr<int64_t>(),
+ optional_data<scalar_t>(weight_),
+ reduction == at::Reduction::Mean,
+ input_.size(0),
+ input_.size(1),
+ input_.size(2) * input_.size(3),
+ blocks_per_sample,
+ ignore_index);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
+ }
+ if (reduction == at::Reduction::Mean) {
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::Half,
+ at::ScalarType::BFloat16,
+ input.scalar_type(),
+ "nll_loss2d_forward_size_average_kernel",
+ [&] {
+ nll_loss2d_forward_size_average_kernel<scalar_t>
+ <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
+ output.data_ptr<scalar_t>(),
+ total_weight.data_ptr<scalar_t>(),
+ input_.numel());
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
+ }
+}
+
+void nll_loss2d_backward_out_cuda_template(
+ Tensor& grad_input,
+ const Tensor& grad_output,
+ const Tensor& input,
+ const Tensor& target,
+ const c10::optional<Tensor>& weight_opt,
+ int64_t reduction,
+ int64_t ignore_index,
+ const Tensor& total_weight) {
+ // See [Note: hacky wrapper removal for optional tensor]
+ c10::MaybeOwned<Tensor> weight_maybe_owned =
+ at::borrow_from_optional_tensor(weight_opt);
+ const Tensor& weight = *weight_maybe_owned;
+
+ check_inputs_nll_loss2d(input, target, weight);
+ grad_input.resize_as_(input);
+ grad_input.zero_();
+ TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
+ TORCH_CHECK(
+ total_weight.numel() == 1,
+ "expected total_weight to be a single element tensor, got: ",
+ total_weight.sizes(),
+ " (",
+ total_weight.numel(),
+ " elements)");
+
+ if (reduction == at::Reduction::None) {
+ TORCH_CHECK(
+ grad_output.dim() == 3,
+ "grad_output must have same dimension as target (3) but got dimension: ",
+ grad_output.sizes());
+ TORCH_CHECK(
+ grad_output.size(0) == target.size(0) &&
+ grad_output.size(1) == target.size(1) &&
+ grad_output.size(2) == target.size(2),
+ "grad_output sizes don't match target sizes: target ",
+ target.sizes(),
+ ", grad_output ",
+        grad_output.sizes());
+ int64_t batch_size = input.size(0);
+ int64_t H = input.size(2);
+ int64_t W = input.size(3);
+ int64_t count = batch_size * H * W;
+
+ if (count == 0) {
+ // This guards from unnecessary operations and launching CUDA kernel with
+ // 0 blocks.
+ return;
+ }
+ auto weight_ = optional_contiguous(weight);
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::Half,
+ at::ScalarType::BFloat16,
+ input.scalar_type(),
+ "nll_loss2d_backward_no_reduce_kernel",
+ [&] {
+ nll_loss2d_backward_no_reduce_kernel<scalar_t>
+ <<<GET_BLOCKS(count),
+ CUDA_NUM_THREADS,
+ 0,
+ at::cuda::getCurrentCUDAStream()>>>(
+ count,
+ target.packed_accessor<int64_t, 3>(),
+ grad_output.packed_accessor<scalar_t, 3>(),
+ grad_input.packed_accessor<scalar_t, 4>(),
+ optional_data<scalar_t>(weight_),
+ ignore_index);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
+ return;
+ }
+
+ int64_t batch_size = target.size(0);
+ auto target_numel = target.numel();
+ if (batch_size != 0 && target_numel != 0) {
+    // This guards from unnecessary operations and launching CUDA kernel
+    // with 0 blocks.
+ auto target_ = target.contiguous();
+ auto weight_ = optional_contiguous(weight);
+
+ int64_t map_nelem = target_numel / batch_size;
+ int blocks_per_sample = GET_BLOCKS(map_nelem) / 128;
+ blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
+ int total_blocks = blocks_per_sample * batch_size;
+
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::Half,
+ at::ScalarType::BFloat16,
+ input.scalar_type(),
+ "nll_loss2d_backward_kernel",
+ [&] {
+ nll_loss2d_backward_kernel<scalar_t>
+ <<<total_blocks,
+ CUDA_NUM_THREADS,
+ 0,
+ at::cuda::getCurrentCUDAStream()>>>(
+ grad_input.data_ptr<scalar_t>(),
+ grad_output.data_ptr<scalar_t>(),
+ target_.data_ptr<int64_t>(),
+ optional_data<scalar_t>(weight_),
+ total_weight.data_ptr<scalar_t>(),
+ reduction == at::Reduction::Mean,
+ input.size(0),
+ input.size(1),
+ map_nelem,
+ blocks_per_sample,
+ ignore_index);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
+ }
+}
+} // namespace
+
+std::tuple<Tensor&, Tensor&> nll_loss2d_forward_out_cuda(
+ const Tensor& self,
+ const Tensor& target,
+ const c10::optional<Tensor>& weight_opt,
+ int64_t reduction,
+ int64_t ignore_index,
+ Tensor& output,
+ Tensor& total_weight) {
+ nll_loss2d_forward_out_cuda_template(
+ output, total_weight, self, target, weight_opt, reduction, ignore_index);
+ return std::tuple<Tensor&, Tensor&>(output, total_weight);
+}
+
+std::tuple<Tensor, Tensor> nll_loss2d_forward_cuda(
+ const Tensor& self,
+ const Tensor& target,
+ const c10::optional<Tensor>& weight_opt,
+ int64_t reduction,
+ int64_t ignore_index) {
+ auto output = at::empty({0}, self.options());
+ auto total_weight = at::empty({0}, self.options());
+ nll_loss2d_forward_out_cuda_template(
+ output, total_weight, self, target, weight_opt, reduction, ignore_index);
+ return std::make_tuple(output, total_weight);
+}
+
+Tensor& nll_loss2d_backward_out_cuda(
+ const Tensor& grad_output,
+ const Tensor& self,
+ const Tensor& target,
+ const c10::optional<Tensor>& weight_opt,
+ int64_t reduction,
+ int64_t ignore_index,
+ const Tensor& total_weight,
+ Tensor& grad_input) {
+ nll_loss2d_backward_out_cuda_template(
+ grad_input,
+ grad_output,
+ self,
+ target,
+ weight_opt,
+ reduction,
+ ignore_index,
+ total_weight);
+ return grad_input;
+}
+
+Tensor nll_loss2d_backward_cuda(
+ const Tensor& grad_output,
+ const Tensor& self,
+ const Tensor& target,
+ const c10::optional<Tensor>& weight_opt,
+ int64_t reduction,
+ int64_t ignore_index,
+ const Tensor& total_weight) {
+ auto grad_input = at::empty_like(self);
+ nll_loss2d_backward_out_cuda_template(
+ grad_input,
+ grad_output,
+ self,
+ target,
+ weight_opt,
+ reduction,
+ ignore_index,
+ total_weight);
+ return grad_input;
+}
+
+} // namespace native
+} // namespace at
python_module: nn
dispatch:
CPU: nll_loss2d_forward_out_cpu
- CUDA: legacy::cuda::_thnn_nll_loss2d_forward_out
+ CUDA: nll_loss2d_forward_out_cuda
- func: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)
python_module: nn
dispatch:
CPU: nll_loss2d_forward_cpu
- CUDA: legacy::cuda::_thnn_nll_loss2d_forward
+ CUDA: nll_loss2d_forward_cuda
- func: nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn
dispatch:
CPU: nll_loss2d_backward_out_cpu
- CUDA: legacy::cuda::_thnn_nll_loss2d_backward_out
+ CUDA: nll_loss2d_backward_out_cuda
- func: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor
python_module: nn
dispatch:
CPU: nll_loss2d_backward_cpu
- CUDA: legacy::cuda::_thnn_nll_loss2d_backward
+ CUDA: nll_loss2d_backward_cuda
- func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
-${CMAKE_CURRENT_SOURCE_DIR}/SpatialClassNLLCriterion.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialConvolutionMM.cu
PARENT_SCOPE)
+++ /dev/null
-#include <limits>
-
-#include <THCUNN/THCUNN.h>
-#include <TH/THHalf.h>
-#include <THC/THCNumerics.cuh>
-#include <THC/THCAtomics.cuh>
-#include <THCUNN/common.h>
-#include <THC/THCDeviceTensor.cuh>
-#include <THC/THCDeviceTensorUtils.cuh>
-#include <THC/THCDeviceUtils.cuh>
-#include <THC/THCApply.cuh>
-#include <c10/macros/Macros.h>
-#include <ATen/cuda/detail/KernelUtils.h>
-
-#include <thrust/functional.h>
-
-template <typename Dtype>
-C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
-__global__ void SpatialClassNLLCriterion_updateOutput_no_reduce_kernel(
- int64_t nthreads,
- THCDeviceTensor<Dtype, 4> input,
- THCDeviceTensor<THCIndex_t, 3> target,
- THCDeviceTensor<Dtype, 3> output,
- Dtype* weights,
- int64_t ignore_index) {
- int64_t batch_size = input.getSize(0);
- int64_t H = input.getSize(2);
- int64_t W = input.getSize(3);
-
- CUDA_KERNEL_LOOP(index, nthreads) {
- const int64_t b = index % batch_size;
- const int64_t h = (index / batch_size) % H;
- const int64_t w = (index / (batch_size * H)) % W;
-
- int64_t cur_target = target[b][h][w];
- if (cur_target == ignore_index) {
- output[b][h][w] = ScalarConvert<int, Dtype>::to(0);
- continue;
- }
- Dtype value = input[b][cur_target][h][w];
- Dtype weight =
- weights ? weights[cur_target] : ScalarConvert<int, Dtype>::to(1);
- output[b][h][w] = -value * weight;
- }
-}
-
-template <typename Dtype>
-C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
-__global__ void SpatialClassNLLCriterion_updateGradInput_no_reduce_kernel(
- int64_t nthreads,
- THCDeviceTensor<THCIndex_t, 3> target,
- THCDeviceTensor<Dtype, 3> gradOutput,
- THCDeviceTensor<Dtype, 4> gradInput,
- Dtype* weights,
- int64_t ignore_index) {
- int64_t batch_size = target.getSize(0);
- int64_t H = target.getSize(1);
- int64_t W = target.getSize(2);
-
- CUDA_KERNEL_LOOP(index, nthreads) {
- const int64_t b = index % batch_size;
- const int64_t h = (index / batch_size) % H;
- const int64_t w = (index / (batch_size * H)) % W;
-
- int64_t cur_target = target[b][h][w];
- if (cur_target == ignore_index) {
- continue;
- }
- Dtype value =
- -(weights ? weights[cur_target] : ScalarConvert<int, Dtype>::to(1));
- gradInput[b][cur_target][h][w] = value * gradOutput[b][h][w];
- }
-}
-
-template <typename T, typename AccumT>
-C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
-__global__ void cunn_SpatialClassNLLCriterion_updateOutput_kernel(
- T* output,
- T* total_weight,
- T* input,
- THCIndex_t* target,
- T* weights,
- int size_average,
- int batch_size,
- int n_classes,
- int map_nelem,
- int blocks_per_sample,
- int64_t ignore_index) {
- __shared__ AccumT partial_sums[CUDA_NUM_THREADS];
-
- int i, t;
- T cur_weight;
- AccumT input_sum = 0;
- AccumT acc_weight = 0;
-
- int sample = blockIdx.x / blocks_per_sample;
- int toffset = sample * map_nelem;
- int ioffset = sample * map_nelem * n_classes;
- int step = blockDim.x * blocks_per_sample;
- for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
- i < map_nelem;
- i += step) {
- t = target[toffset + i];
- if (t != ignore_index) {
- CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
- cur_weight = weights ? weights[t] : ScalarConvert<int, T>::to(1);
- input_sum -= input[ioffset + i + map_nelem * t] * cur_weight;
- acc_weight += cur_weight;
- }
- }
-
- input_sum = reduceBlock(partial_sums, blockDim.x, input_sum, thrust::plus<AccumT>(), AccumT(0));
- __syncthreads();
- acc_weight = reduceBlock(partial_sums, blockDim.x, acc_weight, thrust::plus<AccumT>(), AccumT(0));
-
- if (threadIdx.x == 0) {
- gpuAtomicAdd(total_weight, ScalarConvert<AccumT, T>::to(acc_weight));
- gpuAtomicAdd(output, ScalarConvert<AccumT, T>::to(input_sum));
- }
-}
-
-template<typename T>
-__global__ void cunn_SpatialClassNLLCriterion_sizeAverage_kernel(
- T *output,
- T *total_weight,
- int nElement)
-{
- if (nElement == 0) {
- // Mean reduction on empty tensors produces NaN
- *output = std::numeric_limits<double>::quiet_NaN();
- }
- if (*total_weight != 0) {
- *output = THCNumerics<T>::div(*output, *total_weight);
- }
-}
-
-template <typename T>
-C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
-__global__ void cunn_SpatialClassNLLCriterion_updateGradInput_kernel(
- T* gradInput,
- T* gradOutput,
- THCIndex_t* target,
- T* weights,
- T* total_weight,
- int size_average,
- int batch_size,
- int n_classes,
- int map_nelem,
- int blocks_per_sample,
- int64_t ignore_index) {
- if (*total_weight <= 0)
- return;
-
- int i, t;
- T norm = size_average ? (ScalarConvert<int, T>::to(1) / *total_weight) : ScalarConvert<int, T>::to(1);
-
- int sample = blockIdx.x / blocks_per_sample;
- int step = blockDim.x * blocks_per_sample;
- int toffset = sample * map_nelem;
- int ioffset = sample * map_nelem * n_classes;
- for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
- i < map_nelem;
- i += step) {
- t = (int)target[toffset + i];
- if (t != ignore_index) {
- CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
- gradInput[ioffset + i + map_nelem * t] = -(weights ? weights[t] : ScalarConvert<int, T>::to(1)) * norm * gradOutput[0];
- }
- }
-}
-
-#include <THCUNN/generic/SpatialClassNLLCriterion.cu>
-#include <THC/THCGenerateFloatTypes.h>
-
-#include <THCUNN/generic/SpatialClassNLLCriterion.cu>
-#include <THC/THCGenerateBFloat16Type.h>
+++ /dev/null
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THCUNN/generic/SpatialClassNLLCriterion.cu"
-#else
-
-void THNN_(SpatialClassNLLCriterion_shapeCheck)(
- THCState *state,
- THCTensor *input,
- THCIndexTensor *target,
- THCTensor *weights)
-{
- TORCH_CHECK(target->dim() == 3, 1,
- "only batches of spatial targets supported (3D tensors)" \
- " but got targets of size: : ", target->sizes());
- TORCH_CHECK(input->dim() == 4, 2,
- "only batches of spatial inputs supported (4D tensors), " \
- "but got input of size: ", input->sizes());
- if (THCTensor_(size)(state, input, 0) != THCIndexTensor_(size)(state, target, 0) ||
- THCTensor_(size)(state, input, 2) != THCIndexTensor_(size)(state, target, 1) ||
- THCTensor_(size)(state, input, 3) != THCIndexTensor_(size)(state, target, 2)) {
- THCDescBuff input_size = THCTensor_(sizeDesc)(state, input);
- THCDescBuff target_size = THCIndexTensor_(sizeDesc)(state, target);
- THError("input and target batch or spatial sizes don't match: target %s, input %s",
- target_size.str, input_size.str);
- }
-
- if (weights && THCTensor_(nElement)(state, weights) != THCTensor_(size)(state, input, 1)) {
- THError("weight tensor should be defined either for all or no classes");
- }
-}
-
-static void THNN_(SpatialClassNLLCriterion_gradOutput_no_reduce_shapeCheck)(
- THCState *state,
- THCTensor *gradOutput,
- THCIndexTensor *target)
-{
- TORCH_CHECK(THCTensor_(nDimensionLegacyNoScalars)(state, gradOutput) == 3, 2,
- "gradOutput must have same dimension as target (3) but got dimension: ", gradOutput->sizes());
- if (THCTensor_(size)(state, gradOutput, 0) != THCIndexTensor_(size)(state, target, 0) ||
- THCTensor_(size)(state, gradOutput, 1) != THCIndexTensor_(size)(state, target, 1) ||
- THCTensor_(size)(state, gradOutput, 2) != THCIndexTensor_(size)(state, target, 2)) {
- THCDescBuff gradOutput_size = THCTensor_(sizeDesc)(state, gradOutput);
- THCDescBuff target_size = THCIndexTensor_(sizeDesc)(state, target);
- THError("gradOutput sizes don't match target sizes: target %s, gradOutput %s",
- target_size.str, gradOutput_size.str);
- }
-}
-
-void THNN_(SpatialClassNLLCriterion_updateOutput)(
- THCState *state,
- THCTensor *input,
- THCIndexTensor *target,
- THCTensor *output,
- int64_t reduction,
- THCTensor *weights,
- THCTensor *total_weight,
- int64_t ignore_index)
-{
- // See Note [Writing Nondeterministic Operations]
- // Nondeterministic because of atomicAdd usage
- at::globalContext().alertNotDeterministic("SpatialClassNLLCriterion_updateOutput");
- THNN_(SpatialClassNLLCriterion_shapeCheck)(state, input, target, weights);
- THCTensor_(resize0d)(state, output);
- THCTensor_(resize0d)(state, total_weight);
-
- if (weights)
- THCUNN_assertSameGPU(state, 5, input, target, weights, output, total_weight);
- else
- THCUNN_assertSameGPU(state, 4, input, target, output, total_weight);
-
- if (reduction == at::Reduction::None) {
- int64_t batch_size = THCTensor_(size)(state, input, 0);
- int64_t H = THCTensor_(size)(state, input, 2);
- int64_t W = THCTensor_(size)(state, input, 3);
- int64_t count = batch_size * H * W;
-
- THCTensor_(resize3d)(state, output, batch_size, H, W);
-
- if (count == 0) {
- // This guards from unnecessary operations and launching CUDA kernel with 0 blocks.
- return;
- }
- if (weights) {
- weights = THCTensor_(newContiguous)(state, weights);
- }
-
- SpatialClassNLLCriterion_updateOutput_no_reduce_kernel<scalar_t>
- <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream()>>>(
- count,
- toDeviceTensor<scalar_t, 4>(state, input),
- toDeviceTensor<THCIndex_t, 3>(state, target),
- toDeviceTensor<scalar_t, 3>(state, output),
- weights ? THCTensor_(data)(state, weights) : NULL,
- ignore_index);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
-
- if (weights) {
- THCTensor_(free)(state, weights);
- }
- return;
- }
-
- input = THCTensor_(newContiguous)(state, input);
- weights = weights ? THCTensor_(newContiguous)(state, weights) : NULL;
- target = THCIndexTensor_(newContiguous)(state, target);
-
- scalar_t *input_data = THCTensor_(data)(state, input);
- scalar_t *weights_data = weights ? THCTensor_(data)(state, weights) : NULL;
- THCIndex_t *target_data = THCIndexTensor_(data)(state, target);
- scalar_t *output_data = THCTensor_(data)(state, output);
- scalar_t *total_weight_data = THCTensor_(data)(state, total_weight);
- THCTensor_(fill)(state, output, ScalarConvert<int, scalar_t>::to(0));
- THCTensor_(fill)(state, total_weight, ScalarConvert<int, scalar_t>::to(0));
-
- THCIndex_t batch_size = THCIndexTensor_(size)(state, target, 0);
- if (batch_size != 0) { // This guards from unnecessary operations and launching CUDA kernel with 0 blocks.
- THCIndex_t map_nelem = THCIndexTensor_(nElement)(state, target) / batch_size;
- int blocks_per_sample = GET_BLOCKS(map_nelem) / 128;
- blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
- int total_blocks = blocks_per_sample * batch_size;
-
- cunn_SpatialClassNLLCriterion_updateOutput_kernel<scalar_t, accreal>
- <<<total_blocks, CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream()>>>(
- output_data,
- total_weight_data,
- input_data,
- target_data,
- weights_data,
- reduction == at::Reduction::Mean,
- THCTensor_(size)(state, input, 0),
- THCTensor_(size)(state, input, 1),
- THCTensor_(size)(state, input, 2) * THCTensor_(size)(state, input, 3),
- blocks_per_sample,
- ignore_index
- );
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- }
- if (reduction == at::Reduction::Mean) {
- cunn_SpatialClassNLLCriterion_sizeAverage_kernel<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(
- output_data, total_weight_data, THCTensor_(nElement)(state, input)
- );
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- }
-
- if (weights)
- THCTensor_(free)(state, weights);
- THCIndexTensor_(free)(state, target);
- THCTensor_(free)(state, input);
-}
-
-void THNN_(SpatialClassNLLCriterion_updateGradInput)(
- THCState *state,
- THCTensor *input,
- THCIndexTensor *target,
- THCTensor *gradOutput,
- THCTensor *gradInput,
- int64_t reduction,
- THCTensor *weights,
- THCTensor *total_weight,
- int64_t ignore_index)
-{
- THNN_(SpatialClassNLLCriterion_shapeCheck)(state, input, target, weights);
- THCTensor_(resizeAs)(state, gradInput, input);
- THCTensor_(zero)(state, gradInput);
- THArgCheck(THCTensor_(isContiguous)(state, gradInput), 4,
- "gradInput must be contiguous");
-
- if (weights)
- THCUNN_assertSameGPU(state, 5, weights, input, target, gradInput, total_weight);
- else
- THCUNN_assertSameGPU(state, 4, input, target, gradInput, total_weight);
-
- if (reduction == at::Reduction::None) {
- THNN_(SpatialClassNLLCriterion_gradOutput_no_reduce_shapeCheck)(
- state,
- gradOutput,
- target);
-
- int64_t batch_size = THCTensor_(size)(state, input, 0);
- int64_t H = THCTensor_(size)(state, input, 2);
- int64_t W = THCTensor_(size)(state, input, 3);
- int64_t count = batch_size * H * W;
-
- if (count == 0) {
- // This guards from unnecessary operations and launching CUDA kernel with 0 blocks.
- return;
- }
- if (weights) {
- weights = THCTensor_(newContiguous)(state, weights);
- }
-
- SpatialClassNLLCriterion_updateGradInput_no_reduce_kernel<scalar_t>
- <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream()>>>(
- count,
- toDeviceTensor<THCIndex_t, 3>(state, target),
- toDeviceTensor<scalar_t, 3>(state, gradOutput),
- toDeviceTensor<scalar_t, 4>(state, gradInput),
- weights ? THCTensor_(data)(state, weights) : NULL,
- ignore_index);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
-
- if (weights) {
- THCTensor_(free)(state, weights);
- }
- return;
- }
-
- input = THCTensor_(newContiguous)(state, input);
- weights = weights ? THCTensor_(newContiguous)(state, weights) : NULL;
- target = THCIndexTensor_(newContiguous)(state, target);
-
- scalar_t *gradOutput_data = THCTensor_(data)(state, gradOutput);
- scalar_t *weights_data = weights ? THCTensor_(data)(state, weights) : NULL;
- scalar_t *gradInput_data = THCTensor_(data)(state, gradInput);
- THCIndex_t *target_data = THCIndexTensor_(data)(state, target);
- scalar_t *total_weight_data = THCTensor_(data)(state, total_weight);
-
- THCIndex_t batch_size = THCIndexTensor_(size)(state, target, 0);
- if (batch_size != 0) { // This guards from unnecessary operations and launching CUDA kernel with 0 blocks.
- THCIndex_t map_nelem = THCIndexTensor_(nElement)(state, target) / batch_size;
- int blocks_per_sample = GET_BLOCKS(map_nelem) / 128;
- blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
- int total_blocks = blocks_per_sample * batch_size;
-
- cunn_SpatialClassNLLCriterion_updateGradInput_kernel
- <<<total_blocks, CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream()>>>(
- gradInput_data,
- gradOutput_data,
- target_data,
- weights_data,
- total_weight_data,
- reduction == at::Reduction::Mean,
- THCTensor_(size)(state, input, 0),
- THCTensor_(size)(state, input, 1),
- THCTensor_(size)(state, input, 2) *THCTensor_(size)(state, input, 3),
- blocks_per_sample,
- ignore_index
- );
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- }
-
- if (weights)
- THCTensor_(free)(state, weights);
- THCIndexTensor_(free)(state, target);
- THCTensor_(free)(state, input);
-}
-
-#endif
THCTensor* weights, // [OPTIONAL]
accreal margin);
-TORCH_CUDA_CU_API void THNN_(SpatialClassNLLCriterion_updateOutput)(
- THCState* state,
- THCTensor* input,
- THCIndexTensor* target,
- THCTensor* output,
- int64_t reduction,
- THCTensor* weights, // [OPTIONAL]
- THCTensor* total_weight,
- int64_t ignore_index);
-
-TORCH_CUDA_CU_API void THNN_(SpatialClassNLLCriterion_updateGradInput)(
- THCState* state,
- THCTensor* input,
- THCIndexTensor* target,
- THCTensor* gradOutput,
- THCTensor* gradInput,
- int64_t reduction,
- THCTensor* weights, // [OPTIONAL]
- THCTensor* total_weight,
- int64_t ignore_index);
-
TORCH_CUDA_CU_API void THNN_(SpatialConvolutionMM_updateOutput)(
THCState* state,
THCTensor* input,
input = torch.randn(2, 3, 5, 5, device=device)
target = torch.rand(2, 5, 5, device=device).mul(3).floor().long()
- @expectedAlertNondeterministic('SpatialClassNLLCriterion_updateOutput', 'cuda')
+ @expectedAlertNondeterministic('nll_loss2d_forward_out_cuda_template', 'cuda')
def forward_func(slf, device):
module(input, target)