Fix math::Set for large tensor (#17539)
authorXiaomeng Yang <yangxm@fb.com>
Wed, 27 Feb 2019 20:18:52 +0000 (12:18 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 27 Feb 2019 20:34:58 +0000 (12:34 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17539

Fix math::Set for large tensor

i-am-not-moving-c2-to-c10

Reviewed By: dzhulgakov, houseroad

Differential Revision: D14240756

fbshipit-source-id: 0ade26790be41fb26d2cc193bfa3082c7bd4e69d

caffe2/contrib/aten/aten_op.cc
caffe2/utils/math/elementwise.cc
caffe2/utils/math/elementwise.cu
caffe2/utils/math/elementwise.h

index 2abc25a..aa8b7fe 100644 (file)
@@ -4,7 +4,7 @@
 namespace caffe2 {
 
 REGISTER_CPU_OPERATOR(ATen, ATenOp<CPUContext>);
-template<>
+template <>
 at::Backend ATenOp<CPUContext>::backend() const {
   return at::Backend::CPU;
 }
@@ -12,14 +12,16 @@ at::Backend ATenOp<CPUContext>::backend() const {
 OPERATOR_SCHEMA(ATen);
 
 namespace math {
+
 template <>
 void Set<at::Half, CPUContext>(
-    const int /*N*/,
+    const std::int64_t /* N */,
     const at::Half h,
     at::Half* v,
     CPUContext* c) {
-  Set(0, h.x, (uint16_t*) v, c);
-}
+  Set(0, h.x, (uint16_t*)v, c);
 }
 
-}
+} // namespace math
+
+} // namespace caffe2
index 08d723c..96e4644 100644 (file)
@@ -356,18 +356,18 @@ DELEGATE_SCALE(float, double, cblas_dscal)
 // Eigen or via custom code.
 ////////////////////////////////////////////////////////////////////////////////
 
-#define CAFFE2_SPECIALIZED_SET(T)                                    \
-  template <>                                                        \
-  C10_EXPORT void Set<T, CPUContext>(                                \
-      const int N, const T alpha, T* Y, CPUContext* /* context */) { \
-    if (N == 0) {                                                    \
-      return;                                                        \
-    }                                                                \
-    if (alpha == T(0)) {                                             \
-      std::memset(Y, 0, N * sizeof(T));                              \
-    } else {                                                         \
-      EigenVectorArrayMap<T>(Y, N).setConstant(alpha);               \
-    }                                                                \
+#define CAFFE2_SPECIALIZED_SET(T)                                             \
+  template <>                                                                 \
+  C10_EXPORT void Set<T, CPUContext>(                                         \
+      const std::int64_t N, const T alpha, T* Y, CPUContext* /* context */) { \
+    if (N == 0) {                                                             \
+      return;                                                                 \
+    }                                                                         \
+    if (alpha == T(0)) {                                                      \
+      std::memset(Y, 0, N * sizeof(T));                                       \
+    } else {                                                                  \
+      EigenVectorArrayMap<T>(Y, N).setConstant(alpha);                        \
+    }                                                                         \
   }
 CAFFE2_SPECIALIZED_SET(float)
 CAFFE2_SPECIALIZED_SET(double)
index 006fbd0..7509fb2 100644 (file)
@@ -34,7 +34,7 @@ __global__ void SinCosCUDAKernel(const int N, const T* X, T* S, T* C) {
 #define CAFFE2_SPECIALIZED_CUDA_SET(T)                                    \
   template <>                                                             \
   CAFFE2_CUDA_EXPORT void Set<T, CUDAContext>(                            \
-      const int N, const T alpha, T* Y, CUDAContext* context) {           \
+      const std::int64_t N, const T alpha, T* Y, CUDAContext* context) {  \
     if (N == 0) {                                                         \
       return;                                                             \
     }                                                                     \
index 32ef6be..904853c 100644 (file)
@@ -57,7 +57,7 @@ template <typename T, class Context>
 CAFFE2_API void Erf(int N, const T* X, T* Y, Context* context);
 
 template <typename T, class Context>
-CAFFE2_API void Set(int N, T alpha, T* X, Context* context);
+CAFFE2_API void Set(std::int64_t N, T alpha, T* X, Context* context);
 
 template <typename TAlpha, typename TData, class Context>
 CAFFE2_API void