From 07b00fc3249d1c6f842eae80ead63e5e4bce62aa Mon Sep 17 00:00:00 2001
From: "Thomas J. Fan"
Date: Thu, 12 Aug 2021 18:05:29 -0700
Subject: [PATCH] ENH Migrate nll_loss2d from THC to ATen (#62826)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/24608
Fixes https://github.com/pytorch/pytorch/issues/24607

With the benchmark below, the backward pass runs slightly slower than on master. This is surprising, since the ported implementation should be essentially identical to the THC one.
Benchmark script

```python
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

torch.manual_seed(0)

MS_PER_SECOND = 1000

def _time():
    torch.cuda.synchronize()
    return time.perf_counter() * MS_PER_SECOND

device = "cuda"
C = 3
n_runs = 30
reductions = ["none", "sum", "mean"]
Ns = [128, 256, 512]
Hs = [128, 256, 512]

for reduction, N, H in product(reductions, Ns, Hs):
    total_fwd_time = 0
    total_back_time = 0
    if reduction == "none":
        grad_out = torch.randn(N, H, H, device=device)
    else:
        grad_out = torch.randn(1)[0]

    for _ in range(n_runs):
        input = torch.randn(N, C, H, H, device=device, requires_grad=True)
        target = torch.rand(N, H, H, device=device).mul(3).floor().long()

        # forward
        start = _time()
        result = F.nll_loss(input, target, reduction=reduction)
        total_fwd_time += _time() - start

    result = F.nll_loss(input, target, reduction=reduction)
    for _ in range(n_runs):
        # backward
        start = _time()
        result.backward(grad_out, retain_graph=True)
        total_back_time += _time() - start

    fwd_avg = total_fwd_time / n_runs
    bwd_avg = total_back_time / n_runs
    print(
        f"input size({N}, {C}, {H}, {H}), reduction: {reduction}, fwd: {fwd_avg:.2f} (ms), back: {bwd_avg:.2f} (ms)"
    )
```
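As an aside, the same forward measurement could also be taken with `torch.utils.benchmark`, which inserts the needed CUDA synchronization and warmup itself. A minimal sketch, not part of this PR, with illustrative sizes:

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

device = "cuda"
N, C, H = 128, 3, 128  # illustrative sizes, not the full sweep above
input = torch.randn(N, C, H, H, device=device, requires_grad=True)
target = torch.rand(N, H, H, device=device).mul(3).floor().long()

timer = benchmark.Timer(
    stmt="F.nll_loss(input, target, reduction='none')",
    globals={"F": F, "input": input, "target": target},
)
# blocked_autorange() picks the run count automatically and reports times in
# seconds; convert the median to milliseconds to compare with the table below.
print(timer.blocked_autorange().median * 1e3)
```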
master results

```
input size(128, 3, 128, 128), reduction: none, fwd: 0.34 (ms), back: 0.57 (ms)
input size(128, 3, 256, 256), reduction: none, fwd: 2.56 (ms), back: 3.85 (ms)
input size(128, 3, 512, 512), reduction: none, fwd: 14.54 (ms), back: 16.62 (ms)
input size(256, 3, 128, 128), reduction: none, fwd: 1.26 (ms), back: 1.78 (ms)
input size(256, 3, 256, 256), reduction: none, fwd: 7.07 (ms), back: 8.22 (ms)
input size(256, 3, 512, 512), reduction: none, fwd: 29.38 (ms), back: 33.29 (ms)
input size(512, 3, 128, 128), reduction: none, fwd: 3.41 (ms), back: 4.05 (ms)
input size(512, 3, 256, 256), reduction: none, fwd: 14.32 (ms), back: 16.46 (ms)
input size(512, 3, 512, 512), reduction: none, fwd: 59.20 (ms), back: 66.68 (ms)
input size(128, 3, 128, 128), reduction: sum, fwd: 0.08 (ms), back: 0.21 (ms)
input size(128, 3, 256, 256), reduction: sum, fwd: 0.21 (ms), back: 0.73 (ms)
input size(128, 3, 512, 512), reduction: sum, fwd: 0.82 (ms), back: 2.86 (ms)
input size(256, 3, 128, 128), reduction: sum, fwd: 0.12 (ms), back: 0.39 (ms)
input size(256, 3, 256, 256), reduction: sum, fwd: 0.42 (ms), back: 1.45 (ms)
input size(256, 3, 512, 512), reduction: sum, fwd: 1.53 (ms), back: 5.66 (ms)
input size(512, 3, 128, 128), reduction: sum, fwd: 0.21 (ms), back: 0.74 (ms)
input size(512, 3, 256, 256), reduction: sum, fwd: 0.78 (ms), back: 2.86 (ms)
input size(512, 3, 512, 512), reduction: sum, fwd: 2.98 (ms), back: 11.23 (ms)
input size(128, 3, 128, 128), reduction: mean, fwd: 0.07 (ms), back: 0.21 (ms)
input size(128, 3, 256, 256), reduction: mean, fwd: 0.21 (ms), back: 0.73 (ms)
input size(128, 3, 512, 512), reduction: mean, fwd: 0.82 (ms), back: 2.86 (ms)
input size(256, 3, 128, 128), reduction: mean, fwd: 0.13 (ms), back: 0.39 (ms)
input size(256, 3, 256, 256), reduction: mean, fwd: 0.42 (ms), back: 1.45 (ms)
input size(256, 3, 512, 512), reduction: mean, fwd: 1.54 (ms), back: 5.65 (ms)
input size(512, 3, 128, 128), reduction: mean, fwd: 0.22 (ms), back: 0.74 (ms)
input size(512, 3, 256, 256), reduction: mean, fwd: 0.78 (ms), back: 2.87 (ms)
input size(512, 3, 512, 512), reduction: mean, fwd: 2.98 (ms), back: 11.23 (ms)
```
PR results

```
input size(128, 3, 128, 128), reduction: none, fwd: 0.33 (ms), back: 0.59 (ms)
input size(128, 3, 256, 256), reduction: none, fwd: 2.51 (ms), back: 3.92 (ms)
input size(128, 3, 512, 512), reduction: none, fwd: 14.52 (ms), back: 17.05 (ms)
input size(256, 3, 128, 128), reduction: none, fwd: 1.23 (ms), back: 1.85 (ms)
input size(256, 3, 256, 256), reduction: none, fwd: 7.07 (ms), back: 8.45 (ms)
input size(256, 3, 512, 512), reduction: none, fwd: 29.39 (ms), back: 34.21 (ms)
input size(512, 3, 128, 128), reduction: none, fwd: 3.40 (ms), back: 4.18 (ms)
input size(512, 3, 256, 256), reduction: none, fwd: 14.33 (ms), back: 16.90 (ms)
input size(512, 3, 512, 512), reduction: none, fwd: 59.04 (ms), back: 68.36 (ms)
input size(128, 3, 128, 128), reduction: sum, fwd: 0.07 (ms), back: 0.25 (ms)
input size(128, 3, 256, 256), reduction: sum, fwd: 0.21 (ms), back: 0.86 (ms)
input size(128, 3, 512, 512), reduction: sum, fwd: 0.82 (ms), back: 3.33 (ms)
input size(256, 3, 128, 128), reduction: sum, fwd: 0.12 (ms), back: 0.46 (ms)
input size(256, 3, 256, 256), reduction: sum, fwd: 0.42 (ms), back: 1.70 (ms)
input size(256, 3, 512, 512), reduction: sum, fwd: 1.53 (ms), back: 6.58 (ms)
input size(512, 3, 128, 128), reduction: sum, fwd: 0.21 (ms), back: 0.87 (ms)
input size(512, 3, 256, 256), reduction: sum, fwd: 0.78 (ms), back: 3.34 (ms)
input size(512, 3, 512, 512), reduction: sum, fwd: 2.98 (ms), back: 13.07 (ms)
input size(128, 3, 128, 128), reduction: mean, fwd: 0.07 (ms), back: 0.26 (ms)
input size(128, 3, 256, 256), reduction: mean, fwd: 0.21 (ms), back: 0.86 (ms)
input size(128, 3, 512, 512), reduction: mean, fwd: 0.82 (ms), back: 3.34 (ms)
input size(256, 3, 128, 128), reduction: mean, fwd: 0.12 (ms), back: 0.46 (ms)
input size(256, 3, 256, 256), reduction: mean, fwd: 0.42 (ms), back: 1.72 (ms)
input size(256, 3, 512, 512), reduction: mean, fwd: 1.53 (ms), back: 6.60 (ms)
input size(512, 3, 128, 128), reduction: mean, fwd: 0.21 (ms), back: 0.87 (ms)
input size(512, 3, 256, 256), reduction: mean, fwd: 0.78 (ms), back: 3.33 (ms)
input size(512, 3, 512, 512), reduction: mean, fwd: 2.98 (ms), back: 13.07 (ms)
```
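For context on what the new kernels in `NLLLoss2d.cu` compute: given a `(N, C, H, W)` log-probability input and a `(N, H, W)` target, the loss at each spatial location is `-input[b, target[b, h, w], h, w]` scaled by the class weight, entries equal to `ignore_index` contribute zero, and the `mean` reduction normalizes by the total weight of the non-ignored targets. A plain-PyTorch sketch of these semantics, for illustration only and not code from this PR:

```python
import torch

def nll_loss2d_reference(input, target, weight=None, reduction="mean", ignore_index=-100):
    # input: (N, C, H, W) log-probabilities; target: (N, H, W) class indices
    safe_target = target.clamp(min=0)  # keep gather in range for ignored entries
    loss = -input.gather(1, safe_target.unsqueeze(1)).squeeze(1)  # (N, H, W)
    w = weight[safe_target] if weight is not None else torch.ones_like(loss)
    mask = (target != ignore_index).to(loss.dtype)
    loss = loss * w * mask
    if reduction == "none":
        return loss
    if reduction == "sum":
        return loss.sum()
    # "mean": divide by the total weight of the non-ignored targets
    return loss.sum() / (w * mask).sum()
```

For well-formed inputs this should agree with `F.nll_loss(input, target, weight, reduction=reduction, ignore_index=ignore_index)` on both the CPU and CUDA paths.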
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62826 Reviewed By: bdhirsh Differential Revision: D30282279 Pulled By: ngimel fbshipit-source-id: 4aa0ff3f8af0632957417931d332ec486a12b52d --- BUILD.bazel | 2 +- aten/src/ATen/LegacyTHFunctionsCUDA.h | 4 - aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp | 200 -------- aten/src/ATen/native/cuda/NLLLoss2d.cu | 523 +++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 8 +- aten/src/THCUNN/CMakeLists.txt | 1 - aten/src/THCUNN/SpatialClassNLLCriterion.cu | 176 ------- .../src/THCUNN/generic/SpatialClassNLLCriterion.cu | 247 ---------- aten/src/THCUNN/generic/THCUNN.h | 21 - test/test_torch.py | 2 +- 10 files changed, 529 insertions(+), 655 deletions(-) create mode 100644 aten/src/ATen/native/cuda/NLLLoss2d.cu delete mode 100644 aten/src/THCUNN/SpatialClassNLLCriterion.cu delete mode 100644 aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu diff --git a/BUILD.bazel b/BUILD.bazel index 9122ec4..ca8874d 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -420,7 +420,6 @@ filegroup( "aten/src/THCUNN/SoftMarginCriterion.cu.cc", "aten/src/THCUNN/SoftPlus.cu.cc", "aten/src/THCUNN/SoftShrink.cu.cc", - "aten/src/THCUNN/SpatialClassNLLCriterion.cu.cc", "aten/src/THCUNN/SpatialConvolutionMM.cu.cc", "aten/src/THCUNN/Tanh.cu.cc", ], @@ -472,6 +471,7 @@ filegroup( "aten/src/ATen/native/cuda/NaiveConvolutionTranspose2d.cu.cc", "aten/src/ATen/native/cuda/NaiveConvolutionTranspose3d.cu.cc", "aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu.cc", + "aten/src/ATen/native/cuda/NLLLoss2d.cu.cc", "aten/src/ATen/native/cuda/Normalization.cu.cc", "aten/src/ATen/native/cuda/PointwiseOpsKernel.cu.cc", "aten/src/ATen/native/cuda/PowKernel.cu.cc", diff --git a/aten/src/ATen/LegacyTHFunctionsCUDA.h b/aten/src/ATen/LegacyTHFunctionsCUDA.h index 6a5c9ef..5670f31 100644 --- a/aten/src/ATen/LegacyTHFunctionsCUDA.h +++ b/aten/src/ATen/LegacyTHFunctionsCUDA.h @@ -23,10 +23,6 @@ std::tuple _th_gels(const Tensor & self, const Tensor & A); Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper); Tensor _th_potri(const Tensor & self, bool upper); Tensor & _th_copy_ignoring_overlaps_(Tensor & self, const Tensor & src); -std::tuple _thnn_nll_loss2d_forward_out(const Tensor & self, const Tensor & target, const c10::optional& weight_opt, int64_t reduction, int64_t ignore_index, Tensor & output, Tensor & total_weight); -std::tuple _thnn_nll_loss2d_forward(const Tensor & self, const Tensor & target, const optional & weight, int64_t reduction, int64_t ignore_index); -Tensor & _thnn_nll_loss2d_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional& weight_opt, int64_t reduction, int64_t ignore_index, const Tensor & total_weight, Tensor & grad_input); -Tensor _thnn_nll_loss2d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, const optional & weight, int64_t reduction, int64_t ignore_index, const Tensor & total_weight); Tensor _thnn_rrelu_with_noise_backward(const Tensor & grad_output, const Tensor & self, const Tensor & noise, const Scalar& lower, const Scalar& upper, bool training); std::tuple _thnn_conv2d_forward_out(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const c10::optional& bias_opt, IntArrayRef stride, IntArrayRef padding, Tensor & output, Tensor & columns, Tensor & ones); std::tuple _thnn_conv2d_forward(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const optional & bias, IntArrayRef stride, IntArrayRef padding); diff 
--git a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp index 9ad1d44..30c61a3 100644 --- a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp +++ b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp @@ -175,206 +175,6 @@ Tensor & _th_copy_ignoring_overlaps_(Tensor & self, const Tensor & src) { } return self; } -std::tuple _thnn_nll_loss2d_forward_out(const Tensor & self, const Tensor & target, const c10::optional& weight_opt, int64_t reduction, int64_t ignore_index, Tensor & output, Tensor & total_weight) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - - const OptionalDeviceGuard device_guard(device_of(self)); - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto output_ = checked_dense_tensor_unwrap(output, "output", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaDoubleSpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto output_ = checked_dense_tensor_unwrap(output, "output", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaSpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? 
weight_ : NULL, total_weight_, ignore_index); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto output_ = checked_dense_tensor_unwrap(output, "output", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaHalfSpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index); - break; - } - case ScalarType::BFloat16: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto output_ = checked_dense_tensor_unwrap(output, "output", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 5, "_thnn_nll_loss2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaBFloat16SpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? 
weight_ : NULL, total_weight_, ignore_index); - break; - } - default: - AT_ERROR("_thnn_nll_loss2d_forward_out not supported on CUDAType for ", dispatch_scalar_type); - } - return std::tuple(output, total_weight); -} -std::tuple _thnn_nll_loss2d_forward(const Tensor & self, const Tensor & target, const c10::optional& weight_opt, int64_t reduction, int64_t ignore_index) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - - const OptionalDeviceGuard device_guard(device_of(self)); - auto dispatch_scalar_type = infer_scalar_type(self); - auto output_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto output = Tensor(c10::intrusive_ptr::reclaim(output_)); - auto total_weight_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto total_weight = Tensor(c10::intrusive_ptr::reclaim(total_weight_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward", true, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaDoubleSpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward", true, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaSpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward", true, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaHalfSpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? 
weight_ : NULL, total_weight_, ignore_index); - break; - } - case ScalarType::BFloat16: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 2, "_thnn_nll_loss2d_forward", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_nll_loss2d_forward", true, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaBFloat16SpatialClassNLLCriterion_updateOutput(globalContext().getTHCState(), self_, target_, output_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index); - break; - } - default: - AT_ERROR("_thnn_nll_loss2d_forward not supported on CUDAType for ", dispatch_scalar_type); - } - return std::tuple(output, total_weight); -} -Tensor & _thnn_nll_loss2d_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional& weight_opt, int64_t reduction, int64_t ignore_index, const Tensor & total_weight, Tensor & grad_input) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - - const OptionalDeviceGuard device_guard(device_of(self)); - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaDoubleSpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? 
weight_ : NULL, total_weight_, ignore_index); - break; - } - case ScalarType::Float: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaSpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index); - break; - } - case ScalarType::Half: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaHalfSpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index); - break; - } - case ScalarType::BFloat16: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 7, "_thnn_nll_loss2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaBFloat16SpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? 
weight_ : NULL, total_weight_, ignore_index); - break; - } - default: - AT_ERROR("_thnn_nll_loss2d_backward_out not supported on CUDAType for ", dispatch_scalar_type); - } - return grad_input; -} -Tensor _thnn_nll_loss2d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional& weight_opt, int64_t reduction, int64_t ignore_index, const Tensor & total_weight) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - - const OptionalDeviceGuard device_guard(device_of(self)); - auto dispatch_scalar_type = infer_scalar_type(self); - auto grad_input_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto grad_input = Tensor(c10::intrusive_ptr::reclaim(grad_input_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward", true, DeviceType::CUDA, dispatch_scalar_type); - auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaDoubleSpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index); - break; - } - case ScalarType::Float: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward", true, DeviceType::CUDA, dispatch_scalar_type); - auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaSpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? 
weight_ : NULL, total_weight_, ignore_index); - break; - } - case ScalarType::Half: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward", true, DeviceType::CUDA, dispatch_scalar_type); - auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaHalfSpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index); - break; - } - case ScalarType::BFloat16: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, ScalarType::Long); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss2d_backward", true, DeviceType::CUDA, dispatch_scalar_type); - auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaBFloat16SpatialClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index); - break; - } - default: - AT_ERROR("_thnn_nll_loss2d_backward not supported on CUDAType for ", dispatch_scalar_type); - } - return grad_input; -} std::tuple _thnn_conv2d_forward_out(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const c10::optional& bias_opt, IntArrayRef stride, IntArrayRef padding, Tensor & output, Tensor & columns, Tensor & ones) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); diff --git a/aten/src/ATen/native/cuda/NLLLoss2d.cu b/aten/src/ATen/native/cuda/NLLLoss2d.cu new file mode 100644 index 0000000..d5ea3d6 --- /dev/null +++ b/aten/src/ATen/native/cuda/NLLLoss2d.cu @@ -0,0 +1,523 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +namespace { + +// Returns a contiguous tensor if the source tensor +// is defined. Otherwise returns the undefined +// source tensor unmodified. +inline Tensor optional_contiguous(const Tensor& source) { + return source.defined() ? source.contiguous() : source; +} + +// Returns the address of the first element of a tensor +// or nullptr if the tensor is undefined. +template +inline scalar_t* optional_data(const Tensor& source) { + return source.defined() ? 
source.data_ptr() : nullptr; +} + +using at::cuda::detail::CUDA_NUM_THREADS; +using at::cuda::detail::GET_BLOCKS; + +template +C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) +__global__ void nll_loss2d_forward_no_reduce_kernel( + int64_t n_threads, + PackedTensorAccessor64 input, + PackedTensorAccessor64 target, + PackedTensorAccessor64 output, + scalar_t* weight, + int64_t ignore_index +) { + int64_t batch_size = input.size(0); + int64_t H = input.size(2); + int64_t W = input.size(3); + + CUDA_KERNEL_LOOP(index, n_threads) { + const int64_t b = index % batch_size; + const int64_t h = (index / batch_size) % H; + const int64_t w = (index / (batch_size * H)) % W; + + int64_t cur_target = target[b][h][w]; + if (cur_target == ignore_index) { + output[b][h][w] = static_cast(0); + continue; + } + scalar_t value = input[b][cur_target][h][w]; + scalar_t cur_weight = weight != nullptr ? weight[cur_target] : static_cast(1); + output[b][h][w] = -value * cur_weight; + } +} + +template +C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) +__global__ void nll_loss2d_forward_kernel( + scalar_t* output, + scalar_t* total_weight, + scalar_t* input, + int64_t* target, + scalar_t* weight, + bool size_average, + int batch_size, + int n_classes, + int map_nelem, + int blocks_per_sample, + int64_t ignore_index) { + + scalar_t cur_weight; + accscalar_t input_sum = 0; + accscalar_t acc_weight = 0; + + int sample = blockIdx.x / blocks_per_sample; + int toffset = sample * map_nelem; + int ioffset = sample * map_nelem * n_classes; + int step = blockDim.x * blocks_per_sample; + for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x; + i < map_nelem; + i += step) { + int t = target[toffset + i]; + if (t != ignore_index) { + CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes); + cur_weight = weight != nullptr ? weight[t] : static_cast(1); + input_sum -= input[ioffset + i + map_nelem * t] * cur_weight; + acc_weight += cur_weight; + } + } + + __shared__ accscalar_t acc_weight_smem[CUDA_NUM_THREADS]; + __shared__ accscalar_t input_sum_smem[CUDA_NUM_THREADS]; + auto acc_weight_ = cuda_utils::BlockReduceSum(acc_weight, acc_weight_smem); + auto input_sum_ = cuda_utils::BlockReduceSum(input_sum, input_sum_smem); + + if (threadIdx.x == 0) { + gpuAtomicAdd(total_weight, static_cast(acc_weight_)); + gpuAtomicAdd(output, static_cast(input_sum_)); + } +} + +template +C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) +__global__ void nll_loss2d_forward_size_average_kernel( + scalar_t* output, + scalar_t* total_weight, + int n_elements +) { + if (n_elements == 0) { + // Mean reduction on empty tensors produces NaN + *output = std::numeric_limits::quiet_NaN(); + } + if (*total_weight != 0) { + *output /= *total_weight; + } +} + +template +C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) +__global__ void nll_loss2d_backward_no_reduce_kernel( + int64_t n_threads, + PackedTensorAccessor64 target, + PackedTensorAccessor64 grad_output, + PackedTensorAccessor64 grad_input, + scalar_t* weight, + int64_t ignore_index +) { + int64_t batch_size = target.size(0); + int64_t H = target.size(1); + int64_t W = target.size(2); + + CUDA_KERNEL_LOOP(index, n_threads) { + const int64_t b = index % batch_size; + const int64_t h = (index / batch_size) % H; + const int64_t w = (index / (batch_size * H)) % W; + + int64_t cur_target = target[b][h][w]; + if (cur_target == ignore_index) { + continue; + } + scalar_t value = -(weight != nullptr ? 
weight[cur_target] : static_cast(1)); + grad_input[b][cur_target][h][w] = value * grad_output[b][h][w]; + } +} + +template +C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) +__global__ void nll_loss2d_backward_kernel( + scalar_t* grad_input, + scalar_t* grad_output, + int64_t* target, + scalar_t* weight, + scalar_t* total_weight, + bool size_average, + int batch_size, + int n_classes, + int map_nelem, + int blocks_per_sample, + int64_t ignore_index +) { + if (*total_weight <= 0) { + return; + } + + scalar_t norm = size_average ? (static_cast(1) / *total_weight) : static_cast(1); + + int sample = blockIdx.x / blocks_per_sample; + int step = blockDim.x * blocks_per_sample; + int toffset = sample * map_nelem; + int ioffset = sample * map_nelem * n_classes; + for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x; + i < map_nelem; + i += step) { + int t = (int)target[toffset + i]; + if (t != ignore_index) { + CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes); + grad_input[ioffset + i + map_nelem * t] = -(weight != nullptr ? weight[t] : static_cast(1)) * norm * grad_output[0]; + } + } +} + +void check_inputs_nll_loss2d( + const Tensor& input, + const Tensor& target, + const Tensor& weight) { + TORCH_CHECK( + target.dim() == 3, + "only batches of spatial targets supported (3D tensors)" + " but got targets of size: : ", + target.sizes()); + TORCH_CHECK( + input.dim() == 4, + "only batches of spatial inputs supported (4D tensors), " + "but got input of size: ", + input.sizes()); + TORCH_CHECK( + !weight.defined() || weight.numel() == input.size(1), + "weight tensor should be defined either for all or no classes"); + + TORCH_CHECK( + input.size(0) == target.size(0) && input.size(2) == target.size(1) && + input.size(3) == target.size(2), + "input and target batch or spatial sizes don't match: target ", + target.sizes(), + ", input ", + input.sizes()); +} + +void nll_loss2d_forward_out_cuda_template( + Tensor& output, + Tensor& total_weight, + const Tensor& input, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + int64_t ignore_index) { + // See Note [Writing Nondeterministic Operations] + // Nondeterministic because of atomicAdd usage in 'sum' or 'mean' reductions. + if (reduction != at::Reduction::None) { + at::globalContext().alertNotDeterministic("nll_loss2d_forward_out_cuda_template"); + } + + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + check_inputs_nll_loss2d(input, target, weight); + total_weight.resize_({}); + + if (reduction == at::Reduction::None) { + int64_t batch_size = input.size(0); + int64_t H = input.size(2); + int64_t W = input.size(3); + int64_t count = batch_size * H * W; + + resize_output(output, {batch_size, H, W}); + if (count == 0) { + // This guards from unnecessary operations and launching CUDA kernel with + // 0 blocks. 
+ return; + } + auto weight_ = optional_contiguous(weight); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "nll_loss2d_forward_no_reduce_kernel", + [&] { + nll_loss2d_forward_no_reduce_kernel + <<>>( + count, + input.packed_accessor(), + target.packed_accessor(), + output.packed_accessor(), + optional_data(weight_), + ignore_index); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + return; + } + + // produce scalar outputs for the reduction case + resize_output(output, {}); + + auto input_ = input.contiguous(); + auto weight_ = optional_contiguous(weight); + auto target_ = target.contiguous(); + + output.fill_(0); + total_weight.fill_(0); + + auto batch_size = target.size(0); + auto target_numel = target.numel(); + if (batch_size != 0 && target_numel != 0) { + // This guards from unnecessary operations and launching CUDA kernel with 0 + // blocks. launch kernel + int64_t map_nelem = target_numel / batch_size; + int blocks_per_sample = GET_BLOCKS(map_nelem) / 128; + blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample; + int total_blocks = blocks_per_sample * batch_size; + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "nll_loss2d_forward_kernel", + [&] { + using accscalar_t = acc_type; + nll_loss2d_forward_kernel + <<>>( + output.data_ptr(), + total_weight.data_ptr(), + input_.data_ptr(), + target_.data_ptr(), + optional_data(weight_), + reduction == at::Reduction::Mean, + input_.size(0), + input_.size(1), + input_.size(2) * input_.size(3), + blocks_per_sample, + ignore_index); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } + if (reduction == at::Reduction::Mean) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "nll_loss2d_forward_size_average_kernel", + [&] { + nll_loss2d_forward_size_average_kernel + <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + output.data_ptr(), + total_weight.data_ptr(), + input_.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } +} + +void nll_loss2d_backward_out_cuda_template( + Tensor& grad_input, + const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + int64_t ignore_index, + const Tensor& total_weight) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + check_inputs_nll_loss2d(input, target, weight); + grad_input.resize_as_(input); + grad_input.zero_(); + TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous"); + TORCH_CHECK( + total_weight.numel() == 1, + "expected total_weight to be a single element tensor, got: ", + total_weight.sizes(), + " (", + total_weight.numel(), + " elements)"); + + + if (reduction == at::Reduction::None) { + TORCH_CHECK( + grad_output.dim() == 3, + "grad_output must have same dimension as target (3) but got dimension: ", + grad_output.sizes()); + TORCH_CHECK( + grad_output.size(0) == target.size(0) && + grad_output.size(1) == target.size(1) && + grad_output.size(2) == target.size(2), + "grad_output sizes don't match target sizes: target ", + target.sizes(), + ", grad_output ", + grad_output.sizes()) + int64_t batch_size = input.size(0); + int64_t H = input.size(2); + int64_t W = input.size(3); + int64_t count = batch_size * H * W; + + if (count == 0) { + // This guards from unnecessary operations 
and launching CUDA kernel with + // 0 blocks. + return; + } + auto weight_ = optional_contiguous(weight); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "nll_loss2d_backward_no_reduce_kernel", + [&] { + nll_loss2d_backward_no_reduce_kernel + <<>>( + count, + target.packed_accessor(), + grad_output.packed_accessor(), + grad_input.packed_accessor(), + optional_data(weight_), + ignore_index); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + return; + } + + int64_t batch_size = target.size(0); + auto target_numel = target.numel(); + if (batch_size != 0 && target_numel != 0) { + // This guards from unnecessary operations and launching CUDA kernel with 1 + // blocks. + auto target_ = target.contiguous(); + auto weight_ = optional_contiguous(weight); + + int64_t map_nelem = target_numel / batch_size; + int blocks_per_sample = GET_BLOCKS(map_nelem) / 128; + blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample; + int total_blocks = blocks_per_sample * batch_size; + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "nll_loss2d_backward_kernel", + [&] { + nll_loss2d_backward_kernel + <<>>( + grad_input.data_ptr(), + grad_output.data_ptr(), + target_.data_ptr(), + optional_data(weight_), + total_weight.data_ptr(), + reduction == at::Reduction::Mean, + input.size(0), + input.size(1), + map_nelem, + blocks_per_sample, + ignore_index); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } +} +} // namespace + +std::tuple nll_loss2d_forward_out_cuda( + const Tensor& self, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + int64_t ignore_index, + Tensor& output, + Tensor& total_weight) { + nll_loss2d_forward_out_cuda_template( + output, total_weight, self, target, weight_opt, reduction, ignore_index); + return std::tuple(output, total_weight); +} + +std::tuple nll_loss2d_forward_cuda( + const Tensor& self, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + int64_t ignore_index) { + auto output = at::empty({0}, self.options()); + auto total_weight = at::empty({0}, self.options()); + nll_loss2d_forward_out_cuda_template( + output, total_weight, self, target, weight_opt, reduction, ignore_index); + return std::make_tuple(output, total_weight); +} + +Tensor& nll_loss2d_backward_out_cuda( + const Tensor& grad_output, + const Tensor& self, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + int64_t ignore_index, + const Tensor& total_weight, + Tensor& grad_input) { + nll_loss2d_backward_out_cuda_template( + grad_input, + grad_output, + self, + target, + weight_opt, + reduction, + ignore_index, + total_weight); + return grad_input; +} + +Tensor nll_loss2d_backward_cuda( + const Tensor& grad_output, + const Tensor& self, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + int64_t ignore_index, + const Tensor& total_weight) { + auto grad_input = at::empty_like(self); + nll_loss2d_backward_out_cuda_template( + grad_input, + grad_output, + self, + target, + weight_opt, + reduction, + ignore_index, + total_weight); + return grad_input; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 80cdd20..b0d7131 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8339,25 +8339,25 @@ python_module: nn dispatch: CPU: 
nll_loss2d_forward_out_cpu - CUDA: legacy::cuda::_thnn_nll_loss2d_forward_out + CUDA: nll_loss2d_forward_out_cuda - func: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight) python_module: nn dispatch: CPU: nll_loss2d_forward_cpu - CUDA: legacy::cuda::_thnn_nll_loss2d_forward + CUDA: nll_loss2d_forward_cuda - func: nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: nll_loss2d_backward_out_cpu - CUDA: legacy::cuda::_thnn_nll_loss2d_backward_out + CUDA: nll_loss2d_backward_out_cuda - func: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor python_module: nn dispatch: CPU: nll_loss2d_backward_cpu - CUDA: legacy::cuda::_thnn_nll_loss2d_backward + CUDA: nll_loss2d_backward_cuda - func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator diff --git a/aten/src/THCUNN/CMakeLists.txt b/aten/src/THCUNN/CMakeLists.txt index 3361bb1..5519727 100644 --- a/aten/src/THCUNN/CMakeLists.txt +++ b/aten/src/THCUNN/CMakeLists.txt @@ -1,5 +1,4 @@ set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} -${CMAKE_CURRENT_SOURCE_DIR}/SpatialClassNLLCriterion.cu ${CMAKE_CURRENT_SOURCE_DIR}/SpatialConvolutionMM.cu PARENT_SCOPE) diff --git a/aten/src/THCUNN/SpatialClassNLLCriterion.cu b/aten/src/THCUNN/SpatialClassNLLCriterion.cu deleted file mode 100644 index 21e4d78..0000000 --- a/aten/src/THCUNN/SpatialClassNLLCriterion.cu +++ /dev/null @@ -1,176 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -template -C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) -__global__ void SpatialClassNLLCriterion_updateOutput_no_reduce_kernel( - int64_t nthreads, - THCDeviceTensor input, - THCDeviceTensor target, - THCDeviceTensor output, - Dtype* weights, - int64_t ignore_index) { - int64_t batch_size = input.getSize(0); - int64_t H = input.getSize(2); - int64_t W = input.getSize(3); - - CUDA_KERNEL_LOOP(index, nthreads) { - const int64_t b = index % batch_size; - const int64_t h = (index / batch_size) % H; - const int64_t w = (index / (batch_size * H)) % W; - - int64_t cur_target = target[b][h][w]; - if (cur_target == ignore_index) { - output[b][h][w] = ScalarConvert::to(0); - continue; - } - Dtype value = input[b][cur_target][h][w]; - Dtype weight = - weights ? weights[cur_target] : ScalarConvert::to(1); - output[b][h][w] = -value * weight; - } -} - -template -C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) -__global__ void SpatialClassNLLCriterion_updateGradInput_no_reduce_kernel( - int64_t nthreads, - THCDeviceTensor target, - THCDeviceTensor gradOutput, - THCDeviceTensor gradInput, - Dtype* weights, - int64_t ignore_index) { - int64_t batch_size = target.getSize(0); - int64_t H = target.getSize(1); - int64_t W = target.getSize(2); - - CUDA_KERNEL_LOOP(index, nthreads) { - const int64_t b = index % batch_size; - const int64_t h = (index / batch_size) % H; - const int64_t w = (index / (batch_size * H)) % W; - - int64_t cur_target = target[b][h][w]; - if (cur_target == ignore_index) { - continue; - } - Dtype value = - -(weights ? 
weights[cur_target] : ScalarConvert::to(1)); - gradInput[b][cur_target][h][w] = value * gradOutput[b][h][w]; - } -} - -template -C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) -__global__ void cunn_SpatialClassNLLCriterion_updateOutput_kernel( - T* output, - T* total_weight, - T* input, - THCIndex_t* target, - T* weights, - int size_average, - int batch_size, - int n_classes, - int map_nelem, - int blocks_per_sample, - int64_t ignore_index) { - __shared__ AccumT partial_sums[CUDA_NUM_THREADS]; - - int i, t; - T cur_weight; - AccumT input_sum = 0; - AccumT acc_weight = 0; - - int sample = blockIdx.x / blocks_per_sample; - int toffset = sample * map_nelem; - int ioffset = sample * map_nelem * n_classes; - int step = blockDim.x * blocks_per_sample; - for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x; - i < map_nelem; - i += step) { - t = target[toffset + i]; - if (t != ignore_index) { - CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes); - cur_weight = weights ? weights[t] : ScalarConvert::to(1); - input_sum -= input[ioffset + i + map_nelem * t] * cur_weight; - acc_weight += cur_weight; - } - } - - input_sum = reduceBlock(partial_sums, blockDim.x, input_sum, thrust::plus(), AccumT(0)); - __syncthreads(); - acc_weight = reduceBlock(partial_sums, blockDim.x, acc_weight, thrust::plus(), AccumT(0)); - - if (threadIdx.x == 0) { - gpuAtomicAdd(total_weight, ScalarConvert::to(acc_weight)); - gpuAtomicAdd(output, ScalarConvert::to(input_sum)); - } -} - -template -__global__ void cunn_SpatialClassNLLCriterion_sizeAverage_kernel( - T *output, - T *total_weight, - int nElement) -{ - if (nElement == 0) { - // Mean reduction on empty tensors produces NaN - *output = std::numeric_limits::quiet_NaN(); - } - if (*total_weight != 0) { - *output = THCNumerics::div(*output, *total_weight); - } -} - -template -C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) -__global__ void cunn_SpatialClassNLLCriterion_updateGradInput_kernel( - T* gradInput, - T* gradOutput, - THCIndex_t* target, - T* weights, - T* total_weight, - int size_average, - int batch_size, - int n_classes, - int map_nelem, - int blocks_per_sample, - int64_t ignore_index) { - if (*total_weight <= 0) - return; - - int i, t; - T norm = size_average ? (ScalarConvert::to(1) / *total_weight) : ScalarConvert::to(1); - - int sample = blockIdx.x / blocks_per_sample; - int step = blockDim.x * blocks_per_sample; - int toffset = sample * map_nelem; - int ioffset = sample * map_nelem * n_classes; - for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x; - i < map_nelem; - i += step) { - t = (int)target[toffset + i]; - if (t != ignore_index) { - CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes); - gradInput[ioffset + i + map_nelem * t] = -(weights ? 
weights[t] : ScalarConvert::to(1)) * norm * gradOutput[0]; - } - } -} - -#include -#include - -#include -#include diff --git a/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu b/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu deleted file mode 100644 index 29eba1b..0000000 --- a/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu +++ /dev/null @@ -1,247 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THCUNN/generic/SpatialClassNLLCriterion.cu" -#else - -void THNN_(SpatialClassNLLCriterion_shapeCheck)( - THCState *state, - THCTensor *input, - THCIndexTensor *target, - THCTensor *weights) -{ - TORCH_CHECK(target->dim() == 3, 1, - "only batches of spatial targets supported (3D tensors)" \ - " but got targets of size: : ", target->sizes()); - TORCH_CHECK(input->dim() == 4, 2, - "only batches of spatial inputs supported (4D tensors), " \ - "but got input of size: ", input->sizes()); - if (THCTensor_(size)(state, input, 0) != THCIndexTensor_(size)(state, target, 0) || - THCTensor_(size)(state, input, 2) != THCIndexTensor_(size)(state, target, 1) || - THCTensor_(size)(state, input, 3) != THCIndexTensor_(size)(state, target, 2)) { - THCDescBuff input_size = THCTensor_(sizeDesc)(state, input); - THCDescBuff target_size = THCIndexTensor_(sizeDesc)(state, target); - THError("input and target batch or spatial sizes don't match: target %s, input %s", - target_size.str, input_size.str); - } - - if (weights && THCTensor_(nElement)(state, weights) != THCTensor_(size)(state, input, 1)) { - THError("weight tensor should be defined either for all or no classes"); - } -} - -static void THNN_(SpatialClassNLLCriterion_gradOutput_no_reduce_shapeCheck)( - THCState *state, - THCTensor *gradOutput, - THCIndexTensor *target) -{ - TORCH_CHECK(THCTensor_(nDimensionLegacyNoScalars)(state, gradOutput) == 3, 2, - "gradOutput must have same dimension as target (3) but got dimension: ", gradOutput->sizes()); - if (THCTensor_(size)(state, gradOutput, 0) != THCIndexTensor_(size)(state, target, 0) || - THCTensor_(size)(state, gradOutput, 1) != THCIndexTensor_(size)(state, target, 1) || - THCTensor_(size)(state, gradOutput, 2) != THCIndexTensor_(size)(state, target, 2)) { - THCDescBuff gradOutput_size = THCTensor_(sizeDesc)(state, gradOutput); - THCDescBuff target_size = THCIndexTensor_(sizeDesc)(state, target); - THError("gradOutput sizes don't match target sizes: target %s, gradOutput %s", - target_size.str, gradOutput_size.str); - } -} - -void THNN_(SpatialClassNLLCriterion_updateOutput)( - THCState *state, - THCTensor *input, - THCIndexTensor *target, - THCTensor *output, - int64_t reduction, - THCTensor *weights, - THCTensor *total_weight, - int64_t ignore_index) -{ - // See Note [Writing Nondeterministic Operations] - // Nondeterministic because of atomicAdd usage - at::globalContext().alertNotDeterministic("SpatialClassNLLCriterion_updateOutput"); - THNN_(SpatialClassNLLCriterion_shapeCheck)(state, input, target, weights); - THCTensor_(resize0d)(state, output); - THCTensor_(resize0d)(state, total_weight); - - if (weights) - THCUNN_assertSameGPU(state, 5, input, target, weights, output, total_weight); - else - THCUNN_assertSameGPU(state, 4, input, target, output, total_weight); - - if (reduction == at::Reduction::None) { - int64_t batch_size = THCTensor_(size)(state, input, 0); - int64_t H = THCTensor_(size)(state, input, 2); - int64_t W = THCTensor_(size)(state, input, 3); - int64_t count = batch_size * H * W; - - THCTensor_(resize3d)(state, output, batch_size, H, W); - - if (count == 0) { - // This 
guards from unnecessary operations and launching CUDA kernel with 0 blocks. - return; - } - if (weights) { - weights = THCTensor_(newContiguous)(state, weights); - } - - SpatialClassNLLCriterion_updateOutput_no_reduce_kernel - <<>>( - count, - toDeviceTensor(state, input), - toDeviceTensor(state, target), - toDeviceTensor(state, output), - weights ? THCTensor_(data)(state, weights) : NULL, - ignore_index); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - if (weights) { - THCTensor_(free)(state, weights); - } - return; - } - - input = THCTensor_(newContiguous)(state, input); - weights = weights ? THCTensor_(newContiguous)(state, weights) : NULL; - target = THCIndexTensor_(newContiguous)(state, target); - - scalar_t *input_data = THCTensor_(data)(state, input); - scalar_t *weights_data = weights ? THCTensor_(data)(state, weights) : NULL; - THCIndex_t *target_data = THCIndexTensor_(data)(state, target); - scalar_t *output_data = THCTensor_(data)(state, output); - scalar_t *total_weight_data = THCTensor_(data)(state, total_weight); - THCTensor_(fill)(state, output, ScalarConvert::to(0)); - THCTensor_(fill)(state, total_weight, ScalarConvert::to(0)); - - THCIndex_t batch_size = THCIndexTensor_(size)(state, target, 0); - if (batch_size != 0) { // This guards from unnecessary operations and launching CUDA kernel with 0 blocks. - THCIndex_t map_nelem = THCIndexTensor_(nElement)(state, target) / batch_size; - int blocks_per_sample = GET_BLOCKS(map_nelem) / 128; - blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample; - int total_blocks = blocks_per_sample * batch_size; - - cunn_SpatialClassNLLCriterion_updateOutput_kernel - <<>>( - output_data, - total_weight_data, - input_data, - target_data, - weights_data, - reduction == at::Reduction::Mean, - THCTensor_(size)(state, input, 0), - THCTensor_(size)(state, input, 1), - THCTensor_(size)(state, input, 2) * THCTensor_(size)(state, input, 3), - blocks_per_sample, - ignore_index - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - if (reduction == at::Reduction::Mean) { - cunn_SpatialClassNLLCriterion_sizeAverage_kernel<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>( - output_data, total_weight_data, THCTensor_(nElement)(state, input) - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - - if (weights) - THCTensor_(free)(state, weights); - THCIndexTensor_(free)(state, target); - THCTensor_(free)(state, input); -} - -void THNN_(SpatialClassNLLCriterion_updateGradInput)( - THCState *state, - THCTensor *input, - THCIndexTensor *target, - THCTensor *gradOutput, - THCTensor *gradInput, - int64_t reduction, - THCTensor *weights, - THCTensor *total_weight, - int64_t ignore_index) -{ - THNN_(SpatialClassNLLCriterion_shapeCheck)(state, input, target, weights); - THCTensor_(resizeAs)(state, gradInput, input); - THCTensor_(zero)(state, gradInput); - THArgCheck(THCTensor_(isContiguous)(state, gradInput), 4, - "gradInput must be contiguous"); - - if (weights) - THCUNN_assertSameGPU(state, 5, weights, input, target, gradInput, total_weight); - else - THCUNN_assertSameGPU(state, 4, input, target, gradInput, total_weight); - - if (reduction == at::Reduction::None) { - THNN_(SpatialClassNLLCriterion_gradOutput_no_reduce_shapeCheck)( - state, - gradOutput, - target); - - int64_t batch_size = THCTensor_(size)(state, input, 0); - int64_t H = THCTensor_(size)(state, input, 2); - int64_t W = THCTensor_(size)(state, input, 3); - int64_t count = batch_size * H * W; - - if (count == 0) { - // This guards from unnecessary operations and launching CUDA kernel with 0 blocks. 
- return; - } - if (weights) { - weights = THCTensor_(newContiguous)(state, weights); - } - - SpatialClassNLLCriterion_updateGradInput_no_reduce_kernel - <<>>( - count, - toDeviceTensor(state, target), - toDeviceTensor(state, gradOutput), - toDeviceTensor(state, gradInput), - weights ? THCTensor_(data)(state, weights) : NULL, - ignore_index); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - if (weights) { - THCTensor_(free)(state, weights); - } - return; - } - - input = THCTensor_(newContiguous)(state, input); - weights = weights ? THCTensor_(newContiguous)(state, weights) : NULL; - target = THCIndexTensor_(newContiguous)(state, target); - - scalar_t *gradOutput_data = THCTensor_(data)(state, gradOutput); - scalar_t *weights_data = weights ? THCTensor_(data)(state, weights) : NULL; - scalar_t *gradInput_data = THCTensor_(data)(state, gradInput); - THCIndex_t *target_data = THCIndexTensor_(data)(state, target); - scalar_t *total_weight_data = THCTensor_(data)(state, total_weight); - - THCIndex_t batch_size = THCIndexTensor_(size)(state, target, 0); - if (batch_size != 0) { // This guards from unnecessary operations and launching CUDA kernel with 0 blocks. - THCIndex_t map_nelem = THCIndexTensor_(nElement)(state, target) / batch_size; - int blocks_per_sample = GET_BLOCKS(map_nelem) / 128; - blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample; - int total_blocks = blocks_per_sample * batch_size; - - cunn_SpatialClassNLLCriterion_updateGradInput_kernel - <<>>( - gradInput_data, - gradOutput_data, - target_data, - weights_data, - total_weight_data, - reduction == at::Reduction::Mean, - THCTensor_(size)(state, input, 0), - THCTensor_(size)(state, input, 1), - THCTensor_(size)(state, input, 2) *THCTensor_(size)(state, input, 3), - blocks_per_sample, - ignore_index - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - - if (weights) - THCTensor_(free)(state, weights); - THCIndexTensor_(free)(state, target); - THCTensor_(free)(state, input); -} - -#endif diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h index 5e017c6..87a6105 100644 --- a/aten/src/THCUNN/generic/THCUNN.h +++ b/aten/src/THCUNN/generic/THCUNN.h @@ -26,27 +26,6 @@ TORCH_CUDA_CU_API void THNN_(MultiMarginCriterion_updateGradInput)( THCTensor* weights, // [OPTIONAL] accreal margin); -TORCH_CUDA_CU_API void THNN_(SpatialClassNLLCriterion_updateOutput)( - THCState* state, - THCTensor* input, - THCIndexTensor* target, - THCTensor* output, - int64_t reduction, - THCTensor* weights, // [OPTIONAL] - THCTensor* total_weight, - int64_t ignore_index); - -TORCH_CUDA_CU_API void THNN_(SpatialClassNLLCriterion_updateGradInput)( - THCState* state, - THCTensor* input, - THCIndexTensor* target, - THCTensor* gradOutput, - THCTensor* gradInput, - int64_t reduction, - THCTensor* weights, // [OPTIONAL] - THCTensor* total_weight, - int64_t ignore_index); - TORCH_CUDA_CU_API void THNN_(SpatialConvolutionMM_updateOutput)( THCState* state, THCTensor* input, diff --git a/test/test_torch.py b/test/test_torch.py index 81cf0b8..6766d50 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -3947,7 +3947,7 @@ else: input = torch.randn(2, 3, 5, 5, device=device) target = torch.rand(2, 5, 5, device=device).mul(3).floor().long() - @expectedAlertNondeterministic('SpatialClassNLLCriterion_updateOutput', 'cuda') + @expectedAlertNondeterministic('nll_loss2d_forward_out_cuda_template', 'cuda') def forward_func(slf, device): module(input, target) -- 2.7.4
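Not part of the patch itself, but a quick local sanity check of the port is to compare the CUDA forward and backward against the CPU path for every reduction mode (a sketch; assumes a CUDA-capable build):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x_cpu = torch.randn(4, 3, 8, 8, requires_grad=True)
t_cpu = torch.randint(0, 3, (4, 8, 8))
x_cuda = x_cpu.detach().cuda().requires_grad_()
t_cuda = t_cpu.cuda()

for reduction in ("none", "sum", "mean"):
    # forward results should match between the CPU and the new CUDA kernels
    out_cpu = F.nll_loss(x_cpu, t_cpu, reduction=reduction)
    out_cuda = F.nll_loss(x_cuda, t_cuda, reduction=reduction)
    assert torch.allclose(out_cpu, out_cuda.cpu(), atol=1e-6)

    # backward results should match as well
    grad = torch.ones_like(out_cpu)
    g_cpu, = torch.autograd.grad(out_cpu, x_cpu, grad)
    g_cuda, = torch.autograd.grad(out_cuda, x_cuda, grad.cuda())
    assert torch.allclose(g_cpu, g_cuda.cpu(), atol=1e-6)
```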