Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17539
Fix math::Set for large tensors (widen the element count N from int to std::int64_t)
i-am-not-moving-c2-to-c10
Reviewed By: dzhulgakov, houseroad
Differential Revision:
D14240756
fbshipit-source-id:
0ade26790be41fb26d2cc193bfa3082c7bd4e69d
namespace caffe2 {
// Registers ATenOp<CPUContext> under the operator name ATen (macro semantics
// are defined outside this hunk).
REGISTER_CPU_OPERATOR(ATen, ATenOp<CPUContext>);
// CPU specialization of ATenOp::backend(): always returns at::Backend::CPU.
// This hunk only reformats `template<>` to `template <>`.
-template<>
+template <>
at::Backend ATenOp<CPUContext>::backend() const {
return at::Backend::CPU;
}
// Schema entry for the ATen operator; no constraints are attached here.
OPERATOR_SCHEMA(ATen);
namespace math {
+
// Specialization of Set for at::Half on CPUContext. The element-count
// parameter is unnamed and unused; the call is forwarded to another Set
// instantiation (presumably T = uint16_t, matching the pointer cast) with the
// member h.x as the value, v cast to uint16_t*, and a literal count of 0.
// This hunk changes the count parameter type from int to std::int64_t.
// NOTE(review): because 0 is forwarded instead of the caller's count, the
// forwarded call presumably writes nothing if the target Set has the same
// `N == 0` early return as the float/double CPU specializations -- confirm
// that this is intentional.
template <>
void Set<at::Half, CPUContext>(
- const int /*N*/,
+ const std::int64_t /* N */,
const at::Half h,
at::Half* v,
CPUContext* c) {
- Set(0, h.x, (uint16_t*) v, c);
-}
+ Set(0, h.x, (uint16_t*)v, c);
}
-}
+} // namespace math
+
+} // namespace caffe2
// Eigen or via custom code.
////////////////////////////////////////////////////////////////////////////////
// CAFFE2_SPECIALIZED_SET(T) defines the explicit specialization
// Set<T, CPUContext>(N, alpha, Y, context):
//   - N == 0: returns immediately without touching Y.
//   - alpha == T(0): zeroes N * sizeof(T) bytes starting at Y via std::memset.
//   - otherwise: calls setConstant(alpha) on an EigenVectorArrayMap<T> built
//     from (Y, N).
// The context argument is unused. The only change in this hunk is the type of
// N, from int to std::int64_t; the body is otherwise unchanged.
// The macro is instantiated below for float and double.
-#define CAFFE2_SPECIALIZED_SET(T) \
- template <> \
- C10_EXPORT void Set<T, CPUContext>( \
- const int N, const T alpha, T* Y, CPUContext* /* context */) { \
- if (N == 0) { \
- return; \
- } \
- if (alpha == T(0)) { \
- std::memset(Y, 0, N * sizeof(T)); \
- } else { \
- EigenVectorArrayMap<T>(Y, N).setConstant(alpha); \
- } \
+#define CAFFE2_SPECIALIZED_SET(T) \
+ template <> \
+ C10_EXPORT void Set<T, CPUContext>( \
+ const std::int64_t N, const T alpha, T* Y, CPUContext* /* context */) { \
+ if (N == 0) { \
+ return; \
+ } \
+ if (alpha == T(0)) { \
+ std::memset(Y, 0, N * sizeof(T)); \
+ } else { \
+ EigenVectorArrayMap<T>(Y, N).setConstant(alpha); \
+ } \
}
CAFFE2_SPECIALIZED_SET(float)
CAFFE2_SPECIALIZED_SET(double)
#define CAFFE2_SPECIALIZED_CUDA_SET(T) \
template <> \
CAFFE2_CUDA_EXPORT void Set<T, CUDAContext>( \
- const int N, const T alpha, T* Y, CUDAContext* context) { \
+ const std::int64_t N, const T alpha, T* Y, CUDAContext* context) { \
if (N == 0) { \
return; \
} \
CAFFE2_API void Erf(int N, const T* X, T* Y, Context* context);
// Declaration of the generic Set routine, specialized per element type T and
// Context elsewhere. Judging by the CPU specializations in this change, it
// writes alpha to N elements of X. This hunk widens N from int to
// std::int64_t, presumably so element counts beyond the int range can be
// passed (see the commit title) -- specializations must use the same type.
template <typename T, class Context>
-CAFFE2_API void Set(int N, T alpha, T* X, Context* context);
+CAFFE2_API void Set(std::int64_t N, T alpha, T* X, Context* context);
template <typename TAlpha, typename TData, class Context>
CAFFE2_API void