Migrate legacy lstsq from THC to ATen (CUDA) (#63504)
authorPeter Bell <peterbell10@live.co.uk>
Tue, 24 Aug 2021 19:43:27 +0000 (12:43 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 24 Aug 2021 19:47:16 +0000 (12:47 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63504

Closes gh-24592

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D30441304

Pulled By: ngimel

fbshipit-source-id: ec176596f54bc084af48a73d1dbb0dcb82fec593

BUILD.bazel
aten/src/ATen/LegacyTHFunctionsCUDA.h
aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
aten/src/ATen/native/native_functions.yaml
aten/src/THC/CMakeLists.txt
aten/src/THC/THCTensorMath.h
aten/src/THC/THCTensorMathMagma.cpp
aten/src/THC/THCTensorMathMagma.h [deleted file]
aten/src/THC/generic/THCTensorMathMagma.cpp [deleted file]
aten/src/THC/generic/THCTensorMathMagma.h [deleted file]

index 5acbe40..afdd469 100644 (file)
@@ -393,7 +393,6 @@ filegroup(
         "aten/src/THC/THCTensor.cu.cc",
         "aten/src/THC/THCTensorCopy.cu.cc",
         "aten/src/THC/THCTensorMath.cu.cc",
-        "aten/src/THC/THCTensorMathMagma.cu.cc",
         "aten/src/THC/THCTensorMathPairwise.cu.cc",
         "aten/src/THC/THCTensorMathScan.cu.cc",
         "aten/src/THC/THCTensorScatterGather.cu.cc",
index 1a20e0b..41cbdd6 100644 (file)
@@ -18,12 +18,7 @@ namespace native {
 namespace legacy {
 namespace cuda {
 
-std::tuple<Tensor &,Tensor &> _th_gels_out(const Tensor & self, const Tensor & A, Tensor & res1, Tensor & res2);
-std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A);
-Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper);
-Tensor _th_potri(const Tensor & self, bool upper);
 Tensor & _th_copy_ignoring_overlaps_(Tensor & self, const Tensor & src);
-Tensor _thnn_rrelu_with_noise_backward(const Tensor & grad_output, const Tensor & self, const Tensor & noise, const Scalar& lower, const Scalar& upper, bool training);
 
 } // namespace th
 } // namespace legacy
index 0ad6dc8..c4e9dfe 100644 (file)
@@ -39,83 +39,6 @@ namespace {
   }
 }
 
-std::tuple<Tensor &,Tensor &> _th_gels_out(const Tensor & self, const Tensor & A, Tensor & res1, Tensor & res2) {
-    TORCH_WARN_ONCE(
-      "torch.lstsq is deprecated in favor of torch.linalg.lstsq and will be removed in a future PyTorch release.\n",
-      "torch.linalg.lstsq has reversed arguments and does not return the QR decomposition in "
-      "the returned tuple (although it returns other information about the problem).\n",
-      "To get the qr decomposition consider using torch.linalg.qr.\n",
-      "The returned solution in torch.lstsq stored the residuals of the solution in the ",
-      "last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, the ",
-      "residuals in the field 'residuals' of the returned named tuple.\n",
-      "The unpacking of the solution, as in\n",
-      "X, _ = torch.lstsq(B, A).solution[:A.size(1)]\n",
-      "should be replaced with\n",
-      "X = torch.linalg.lstsq(A, B).solution"
-    );
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-
-    switch (dispatch_scalar_type) {
-        case ScalarType::Double: {
-            auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto A_ = checked_dense_tensor_unwrap(A, "A", 2, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaDoubleTensor_gels(globalContext().getTHCState(), res1_, res2_, self_, A_);
-            break;
-        }
-        case ScalarType::Float: {
-            auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto A_ = checked_dense_tensor_unwrap(A, "A", 2, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaTensor_gels(globalContext().getTHCState(), res1_, res2_, self_, A_);
-            break;
-        }
-        default:
-            AT_ERROR("_th_gels_out not supported on CUDAType for ", dispatch_scalar_type);
-    }
-    return std::tuple<Tensor &, Tensor &>(res1, res2);
-}
-std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A) {
-    TORCH_WARN_ONCE(
-      "torch.lstsq is deprecated in favor of torch.linalg.lstsq and will be removed in a future PyTorch release.\n",
-      "torch.linalg.lstsq has reversed arguments and does not return the QR decomposition in "
-      "the returned tuple (although it returns other information about the problem).\n",
-      "To get the qr decomposition consider using torch.linalg.qr.\n",
-      "The returned solution in torch.lstsq stored the residuals of the solution in the ",
-      "last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, the ",
-      "residuals in the field 'residuals' of the returned named tuple.\n",
-      "The unpacking of the solution, as in\n",
-      "X, _ = torch.lstsq(B, A).solution[:A.size(1)]\n",
-      "should be replaced with\n",
-      "X = torch.linalg.lstsq(A, B).solution"
-    );
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-    auto res1_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
-    auto res1 = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(res1_));
-    auto res2_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
-    auto res2 = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(res2_));
-    switch (dispatch_scalar_type) {
-        case ScalarType::Double: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_gels", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto A_ = checked_dense_tensor_unwrap(A, "A", 2, "_th_gels", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaDoubleTensor_gels(globalContext().getTHCState(), res1_, res2_, self_, A_);
-            break;
-        }
-        case ScalarType::Float: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_gels", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto A_ = checked_dense_tensor_unwrap(A, "A", 2, "_th_gels", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaTensor_gels(globalContext().getTHCState(), res1_, res2_, self_, A_);
-            break;
-        }
-        default:
-            AT_ERROR("_th_gels not supported on CUDAType for ", dispatch_scalar_type);
-    }
-    return std::tuple<Tensor, Tensor>(res1, res2);
-}
 Tensor & _th_copy_ignoring_overlaps_(Tensor & self, const Tensor & src) {
     // DeviceGuard omitted
     auto dispatch_scalar_type = infer_scalar_type(self);
index 0dae7a2..4e806f0 100644 (file)
@@ -3114,6 +3114,84 @@ void lstsq_kernel(const Tensor& a, Tensor& b, Tensor& /*rank*/, Tensor& /*singul
 
 REGISTER_DISPATCH(lstsq_stub, &lstsq_kernel);
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ legacy_lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+std::tuple<Tensor, Tensor> legacy_lstsq_cuda(const Tensor &B, const Tensor &A) {
+  TORCH_WARN_ONCE(
+      "torch.lstsq is deprecated in favor of torch.linalg.lstsq and will be removed in a future PyTorch release.\n",
+      "torch.linalg.lstsq has reversed arguments and does not return the QR decomposition in "
+      "the returned tuple (although it returns other information about the problem).\n",
+      "To get the qr decomposition consider using torch.linalg.qr.\n",
+      "The returned solution in torch.lstsq stored the residuals of the solution in the ",
+      "last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, the ",
+      "residuals in the field 'residuals' of the returned named tuple.\n",
+      "The unpacking of the solution, as in\n",
+      "X, _ = torch.lstsq(B, A).solution[:A.size(1)]\n",
+      "should be replaced with\n",
+      "X = torch.linalg.lstsq(A, B).solution"
+    );
+
+#ifndef USE_MAGMA
+  TORCH_CHECK(false, "solve: MAGMA library not found in "
+              "compilation. Please rebuild with MAGMA.");
+#else
+  const auto dtype = A.scalar_type();
+  TORCH_CHECK(B.scalar_type() == dtype, "exepected A and B dtypes to match but found ",
+              dtype, " and ", B.scalar_type());
+  TORCH_CHECK(A.numel() > 0 && A.dim() == 2, "A should be (non-empty) 2 dimensional");
+  TORCH_CHECK(B.numel() > 0 && B.dim() == 2, "B should be (non-empty) 2 dimensional");
+  auto a_sizes = A.sizes();
+  auto b_sizes = B.sizes();
+  TORCH_CHECK(a_sizes[0] == b_sizes[0], "Expected A and b to have same size "
+      "at dim 0, but A has ", a_sizes[0], " rows and B has ", b_sizes[0], " rows");
+  TORCH_CHECK(a_sizes[0] >= a_sizes[1], "Expected A with shape (m x n) to have "
+      "m >= n. The case for m < n is not implemented yet.");
+
+  Tensor A_working = cloneBatchedColumnMajor(A);
+  Tensor B_working = cloneBatchedColumnMajor(B);
+
+  int64_t m = a_sizes[0];
+  int64_t n = a_sizes[1];
+  int64_t nrhs = b_sizes[1];
+
+  int info;
+  AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "legacy_lstsq_cuda", [&] {
+    scalar_t *a_data = A_working.data_ptr<scalar_t>();
+    scalar_t *b_data = B_working.data_ptr<scalar_t>();
+    scalar_t wkopt;
+    magmaGels(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info);
+
+    const auto hwork_size = static_cast<magma_int_t>(wkopt);
+    scalar_t *hwork = nullptr;
+    ALLOCATE_ARRAY(hwork, scalar_t, hwork_size);
+
+    magmaGels(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, hwork_size, &info);
+  });
+
+  TORCH_CHECK(info == 0, "MAGMA gels : Argument %d : illegal value", -info);
+  return std::tuple<Tensor, Tensor>(B_working, A_working);
+#endif  // USE_MAGMA
+}
+
+std::tuple<Tensor&, Tensor&> legacy_lstsq_out_cuda(
+    const Tensor& B, const Tensor& A, Tensor& B_out, Tensor& A_out) {
+  const auto dtype = A.scalar_type();
+  TORCH_CHECK(B.scalar_type() == dtype, "exepected A and B dtypes to match but found ",
+              A.scalar_type(), " and ", B.scalar_type());
+  TORCH_CHECK(A_out.scalar_type() == dtype, "A_out to have scalar type ", dtype,
+              " but found", A_out.scalar_type());
+  TORCH_CHECK(B_out.scalar_type() == dtype, "A_out to have scalar type ", dtype,
+              " but found", B_out.scalar_type());
+  Tensor A_tmp, B_tmp;
+  std::tie(B_tmp, A_tmp) = native::legacy_lstsq_cuda(B, A);
+  resize_output(A_out, A_tmp.sizes());
+  A_out.copy_(A_tmp);
+  resize_output(B_out, B_tmp.sizes());
+  B_out.copy_(B_tmp);
+  return std::tuple<Tensor&, Tensor&>(B_out, A_out);
+}
+
+
 }}  // namespace at::native
 
 #undef ALLOCATE_ARRAY
index 9bce764..4f7d7e6 100644 (file)
 - func: lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
   dispatch:
     CPU: legacy_lstsq_out
-    CUDA: legacy::cuda::_th_gels_out
+    CUDA: legacy_lstsq_out_cuda
 
 - func: lstsq(Tensor self, Tensor A) -> (Tensor solution, Tensor QR)
   variants: method, function
   dispatch:
     CPU: legacy_lstsq
-    CUDA: legacy::cuda::_th_gels
+    CUDA: legacy_lstsq_cuda
 
 - func: triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient)
   dispatch:
index 7865060..f34b040 100644 (file)
@@ -66,7 +66,6 @@ install(FILES
           THCNumerics.cuh
           THCTensorInfo.cuh
           THCTensorTypeUtils.cuh
-          THCTensorMathMagma.h
           THCThrustAllocator.cuh
           # See Note [TH abstraction violation]
           THCTensor.hpp
@@ -88,8 +87,6 @@ install(FILES
           generic/THCTensorCopy.h
           generic/THCTensorMath.h
           generic/THCTensorMath.cu
-          generic/THCTensorMathMagma.h
-          generic/THCTensorMathMagma.cpp
           generic/THCTensorMathPairwise.h
           generic/THCTensorMathPairwise.cu
           DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/THC/generic")
index 422a423..b70d4d1 100644 (file)
@@ -13,9 +13,6 @@
 #include <THC/generic/THCTensorMath.h>
 #include <THC/THCGenerateBFloat16Type.h>
 
-#include <THC/generic/THCTensorMathMagma.h>
-#include <THC/THCGenerateAllTypes.h>
-
 #include <THC/generic/THCTensorMathPairwise.h>
 #include <THC/THCGenerateAllTypes.h>
 
index ca0cc8a..4360753 100644 (file)
@@ -1,23 +1,10 @@
 #include <THC/THCGeneral.h>
-#include <THC/THCTensorMath.h>
-#include <THC/THCTensorCopy.h>
-#include <THC/THCTensorMathMagma.h>
-#include <THC/THCTensor.hpp>
-#include <THC/THCStorage.hpp>
-#include <algorithm>
-#include <ATen/native/cuda/MiscUtils.h>
 #include <ATen/cuda/detail/CUDAHooks.h>
 
 #ifdef USE_MAGMA
 #include <magma_v2.h>
 #endif
 
-#ifndef DIVUP
-#define DIVUP(x, y) (((x) + (y) - 1) / (y))
-#endif
-
-#define NoMagma(name) "No CUDA implementation of '" #name "'. Install MAGMA and rebuild cutorch (http://icl.cs.utk.edu/magma/)"
-
 namespace {
 void _THCMagma_init() {
 #ifdef USE_MAGMA
@@ -31,6 +18,3 @@ struct Initializer {
   };
 } initializer;
 } // anonymous namespace
-
-#include <THC/generic/THCTensorMathMagma.cpp>
-#include <THC/THCGenerateAllTypes.h>
diff --git a/aten/src/THC/THCTensorMathMagma.h b/aten/src/THC/THCTensorMathMagma.h
deleted file mode 100644 (file)
index 1fb5821..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-#ifndef THC_TENSOR_MATH_MAGMA_CUH
-#define THC_TENSOR_MATH_MAGMA_CUH
-
-#ifdef USE_MAGMA
-#include <magma_v2.h>
-#endif
-
-#ifdef USE_MAGMA
-template <typename T>
-static inline T* th_magma_malloc_pinned(size_t n)
-{
-  void* ptr;
-  if (MAGMA_SUCCESS != magma_malloc_pinned(&ptr, n * sizeof(T)))
-    THError("$ Torch: not enough memory: you tried to allocate %dGB. Buy new RAM!", n/268435456);
-  return reinterpret_cast<T*>(ptr);
-}
-
-#endif
-
-#endif // THC_TENSOR_MATH_MAGMA_CUH
diff --git a/aten/src/THC/generic/THCTensorMathMagma.cpp b/aten/src/THC/generic/THCTensorMathMagma.cpp
deleted file mode 100644 (file)
index 0d94fc3..0000000
+++ /dev/null
@@ -1,83 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THC/generic/THCTensorMathMagma.cpp"
-#else
-
-#include <c10/cuda/CUDAException.h>
-
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
-
-static THCTensor* THCTensor_(newColumnMajor)(THCState *state, THCTensor *self, THCTensor *src)
-{
-  THAssert(src->dim() == 2);
-  if (self == src && self->stride(0) == 1 && self->stride(1) == self->size(0))
-  {
-    THCTensor_(retain)(state, self);
-    return self;
-  }
-
-  if (self == src)
-    self = THCTensor_(new)(state);
-  else
-    THCTensor_(retain)(state, self);
-
-  int64_t size[2] = { src->size(0), src->size(1) };
-  int64_t stride[2] = { 1, src->size(0) };
-
-  THCTensor_(resizeNd)(state, self, 2, size, stride);
-  THCTensor_(copy)(state, self, src);
-  return self;
-}
-
-void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_)
-{
-#ifdef USE_MAGMA
-  THArgCheck(!a_->is_empty() && a_->dim() == 2, 1, "A should be (non-empty) 2 dimensional");
-  THArgCheck(!b_->is_empty() && b_->dim() == 2, 1, "b should be (non-empty) 2 dimensional");
-  TORCH_CHECK(a_->size(0) == b_->size(0), "Expected A and b to have same size "
-      "at dim 0, but A has ", a_->size(0), " rows and B has ", b_->size(0), " rows");
-  THArgCheck(a_->size(0) >= a_->size(1), 2, "Expected A with shape (m x n) to have "
-      "m >= n. The case for m < n is not implemented yet.");
-
-  THCTensor *a = THCTensor_(newColumnMajor)(state, ra_, a_);
-  THCTensor *b = THCTensor_(newColumnMajor)(state, rb_, b_);
-  scalar_t *a_data = THCTensor_(data)(state, a);
-  scalar_t *b_data = THCTensor_(data)(state, b);
-
-  int64_t m = a->size(0);
-  int64_t n = a->size(1);
-  int64_t nrhs = b->size(1);
-  scalar_t wkopt;
-
-  int info;
-  {
-    at::native::MagmaStreamSyncGuard guard;
-#if defined(THC_REAL_IS_FLOAT)
-    magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info);
-#else
-    magma_dgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info);
-#endif
-
-    scalar_t *hwork = th_magma_malloc_pinned<scalar_t>((size_t)wkopt);
-
-#if defined(THC_REAL_IS_FLOAT)
-    magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, (int)wkopt, &info);
-#else
-    magma_dgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, (int)wkopt, &info);
-#endif
-
-    magma_free_pinned(hwork);
-  }
-
-  if (info != 0)
-    THError("MAGMA gels : Argument %d : illegal value", -info);
-
-  THCTensor_(freeCopyTo)(state, a, ra_);
-  THCTensor_(freeCopyTo)(state, b, rb_);
-#else
-  THError(NoMagma(gels));
-#endif
-}
-
-#endif
-
-#endif
diff --git a/aten/src/THC/generic/THCTensorMathMagma.h b/aten/src/THC/generic/THCTensorMathMagma.h
deleted file mode 100644 (file)
index 585d02c..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THC/generic/THCTensorMathMagma.h"
-#else
-
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
-
-// MAGMA (i.e. CUDA implementation of LAPACK functions)
-TORCH_CUDA_CU_API void THCTensor_(gels)(
-    THCState* state,
-    THCTensor* rb_,
-    THCTensor* ra_,
-    THCTensor* b_,
-    THCTensor* a_);
-
-#endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
-
-#endif