Replace Vec256<T>::size with constexpr method (#15406)
authorEdward Yang <ezyang@fb.com>
Thu, 20 Dec 2018 04:31:09 +0000 (20:31 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 04:33:45 +0000 (20:33 -0800)
Summary:
Stack:
&nbsp;&nbsp;&nbsp;&nbsp;:black_circle:&nbsp; **#15406 Replace Vec256<T>::size with constexpr method**&nbsp;&nbsp;[:yellow_heart:](https://our.intern.facebook.com/intern/diff/D13519902/)

See Note [constexpr static function to avoid odr-usage compiler bug]
for detailed justification.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15406

Differential Revision: D13523774

Pulled By: ezyang

fbshipit-source-id: c0ab44298bb2ef3d68a66d026fc6bc156a909a6b

12 files changed:
aten/src/ATen/cpu/vec256/functional.h
aten/src/ATen/cpu/vec256/vec256.h
aten/src/ATen/cpu/vec256/vec256_base.h
aten/src/ATen/cpu/vec256/vec256_double.h
aten/src/ATen/cpu/vec256/vec256_float.h
aten/src/ATen/cpu/vec256/vec256_int.h
aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
aten/src/ATen/native/cpu/GridSamplerKernel.cpp
aten/src/ATen/native/cpu/Loops.h
aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
aten/src/ATen/native/cpu/SoftMaxKernel.cpp
aten/src/ATen/native/cpu/UnaryOpsKernel.cpp

index fbc526f..82d1faa 100644 (file)
@@ -10,10 +10,10 @@ inline scalar_t vec_reduce_all(
     vec256::Vec256<scalar_t> acc_vec,
     int64_t size) {
   using Vec = vec256::Vec256<scalar_t>;
-  scalar_t acc_arr[Vec::size];
+  scalar_t acc_arr[Vec::size()];
   acc_vec.store(acc_arr);
   for (int64_t i = 1; i < size; i++) {
-    scalar_t acc_arr_next[Vec::size];
+    scalar_t acc_arr_next[Vec::size()];
     acc_arr_next[0] = acc_arr[i];
     Vec acc_vec_next = Vec::loadu(acc_arr_next);
     acc_vec = vec_fun(acc_vec, acc_vec_next);
@@ -25,11 +25,11 @@ inline scalar_t vec_reduce_all(
 template <typename scalar_t, typename Op>
 inline scalar_t reduce_all(const Op& vec_fun, scalar_t* data, int64_t size) {
   using Vec = vec256::Vec256<scalar_t>;
-  if (size < Vec::size)
+  if (size < Vec::size())
     return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
-  int64_t d = Vec::size;
+  int64_t d = Vec::size();
   Vec acc_vec = Vec::loadu(data);
-  for (; d < size - (size % Vec::size); d += Vec::size) {
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
     Vec data_vec = Vec::loadu(data + d);
     acc_vec = vec_fun(acc_vec, data_vec);
   }
@@ -37,7 +37,7 @@ inline scalar_t reduce_all(const Op& vec_fun, scalar_t* data, int64_t size) {
     Vec data_vec = Vec::loadu(data + d, size - d);
     acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
   }
-  return vec_reduce_all(vec_fun, acc_vec, Vec::size);
+  return vec_reduce_all(vec_fun, acc_vec, Vec::size());
 }
 
 template <typename scalar_t, typename MapOp, typename ReduceOp>
@@ -47,11 +47,11 @@ inline scalar_t map_reduce_all(
     scalar_t* data,
     int64_t size) {
   using Vec = vec256::Vec256<scalar_t>;
-  if (size < Vec::size)
+  if (size < Vec::size())
     return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size);
-  int64_t d = Vec::size;
+  int64_t d = Vec::size();
   Vec acc_vec = map_fun(Vec::loadu(data));
-  for (; d < size - (size % Vec::size); d += Vec::size) {
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
     Vec data_vec = Vec::loadu(data + d);
     data_vec = map_fun(data_vec);
     acc_vec = red_fun(acc_vec, data_vec);
@@ -61,7 +61,7 @@ inline scalar_t map_reduce_all(
     data_vec = map_fun(data_vec);
     acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
   }
-  return vec_reduce_all(red_fun, acc_vec, Vec::size);
+  return vec_reduce_all(red_fun, acc_vec, Vec::size());
 }
 
 template <typename scalar_t, typename MapOp, typename ReduceOp>
@@ -72,15 +72,15 @@ inline scalar_t map2_reduce_all(
     const scalar_t* data2,
     int64_t size) {
   using Vec = vec256::Vec256<scalar_t>;
-  if (size < Vec::size) {
+  if (size < Vec::size()) {
     Vec data_vec = Vec::loadu(data, size);
     Vec data2_vec = Vec::loadu(data2, size);
     data_vec = map_fun(data_vec, data2_vec);
     return vec_reduce_all(red_fun, data_vec, size);
   }
-  int64_t d = Vec::size;
+  int64_t d = Vec::size();
   Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2));
-  for (; d < size - (size % Vec::size); d += Vec::size) {
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
     Vec data_vec = Vec::loadu(data + d);
     Vec data2_vec = Vec::loadu(data2 + d);
     data_vec = map_fun(data_vec, data2_vec);
@@ -92,7 +92,7 @@ inline scalar_t map2_reduce_all(
     data_vec = map_fun(data_vec, data2_vec);
     acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
   }
-  return vec_reduce_all(red_fun, acc_vec, Vec::size);
+  return vec_reduce_all(red_fun, acc_vec, Vec::size());
 }
 
 template <typename scalar_t, typename Op>
@@ -103,7 +103,7 @@ inline void map(
     int64_t size) {
   using Vec = vec256::Vec256<scalar_t>;
   int64_t d = 0;
-  for (; d < size - (size % Vec::size); d += Vec::size) {
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
     Vec output_vec = vec_fun(Vec::loadu(input_data + d));
     output_vec.store(output_data + d);
   }
@@ -122,7 +122,7 @@ inline void map2(
     int64_t size) {
   using Vec = vec256::Vec256<scalar_t>;
   int64_t d = 0;
-  for (; d < size - (size % Vec::size); d += Vec::size) {
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
     Vec data_vec = Vec::loadu(input_data + d);
     Vec data_vec2 = Vec::loadu(input_data2 + d);
     Vec output_vec = vec_fun(data_vec, data_vec2);
index 1cbe915..3b7d478 100644 (file)
 
 namespace at {
 namespace vec256 {
+
+// Note [Acceptable use of anonymous namespace in header]
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// Yes you saw right, this is an anonymous namespace in a header.  This header,
+// and all of its subheaders, REQUIRE their code to be entirely inlined into
+// the compilation unit that uses them.  It's important that these functions have
+// internal linkage so that kernels for different architectures don't get
+// combined during linking. It's sufficient to label functions "static", but
+// class methods must be an unnamed namespace to have internal linkage (since
+// static means something different in the context of classes).
 namespace {
 
 template <typename T>
 std::ostream& operator<<(std::ostream& stream, const Vec256<T>& vec) {
-  T buf[Vec256<T>::size];
+  T buf[Vec256<T>::size()];
   vec.store(buf);
   stream << "vec[";
-  for (int i = 0; i != Vec256<T>::size; i++) {
+  for (int i = 0; i != Vec256<T>::size(); i++) {
     if (i != 0) {
       stream << ", ";
     }
index e97070e..5217b90 100644 (file)
@@ -20,6 +20,7 @@
 
 namespace at {
 namespace vec256 {
+// See Note [Acceptable use of anonymous namespace in header]
 namespace {
 
 template<size_t n> struct int_of_size;
@@ -45,15 +46,49 @@ struct Vec256 {
 private:
   T values[32 / sizeof(T)] = {0};
 public:
-  static constexpr int size = 32 / sizeof(T);
+  // Note [constexpr static function to avoid odr-usage compiler bug]
+  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+  // Why, you might ask, is size defined to be a static constexpr function,
+  // rather than a more ordinary 'static constexpr int size;' variable?
+  // The problem lies within ODR rules for static constexpr members versus
+  // static constexpr functions.  First, recall that this class (along with all
+  // of its derivations) live in an anonymous namespace: they are intended to be
+  // *completely* inlined at their use-sites, because we need to compile it
+  // multiple times for different instruction sets.
+  //
+  // Because of this constraint, we CANNOT provide a single definition for
+  // any static members in this class; since we want to compile the class
+  // multiple times, there wouldn't actually be any good place to put the
+  // definition.  Now here is the problem: if we ODR-use a static constexpr
+  // member, we are *obligated* to provide a definition.  Without the
+  // definition, you get a compile error like:
+  //
+  //    relocation R_X86_64_PC32 against undefined symbol
+  //    `_ZN2at6vec25612_GLOBAL__N_16Vec256IdE4sizeE' can not be used when making
+  //    a shared object; recompile with -fPIC
+  //
+  // If this were C++17, we could replace a static constexpr variable with
+  // an inline variable which doesn't require one definition. But we are not
+  // C++17.  So the next best thing is to replace the member with a static
+  // constexpr (and therefore inline) function, which does not require ODR
+  // either.
+  //
+  // Also, technically according to the C++ standard, we don't have to define
+  // a constexpr variable if we never odr-use it.  But it seems that some
+  // versions GCC/Clang have buggy determinations on whether or not an
+  // identifier is odr-used or not, and in any case it's hard to tel if
+  // a variabe is odr-used or not.  So best to just cut the probem at the root.
+  static constexpr int size() {
+    return 32 / sizeof(T);
+  }
   Vec256() {}
   Vec256(T val) {
-    for (int i = 0; i != size; i++) {
+    for (int i = 0; i != size(); i++) {
       values[i] = val;
     }
   }
   template<typename... Args,
-           typename = c10::guts::enable_if_t<(sizeof...(Args) == size)>>
+           typename = c10::guts::enable_if_t<(sizeof...(Args) == size())>>
   Vec256(Args... vals) {
     values = { vals... };
   }
@@ -61,7 +96,7 @@ public:
   static Vec256<T> blend(const Vec256<T>& a, const Vec256<T>& b) {
     int64_t mask = mask_;
     Vec256 vec;
-    for (int64_t i = 0; i < size; i++) {
+    for (int64_t i = 0; i < size(); i++) {
       if (mask & 0x01) {
         vec[i] = b[i];
       } else {
@@ -74,9 +109,9 @@ public:
   static Vec256<T> blendv(const Vec256<T>& a, const Vec256<T>& b,
                           const Vec256<T>& mask) {
     Vec256 vec;
-    int_same_size_t<T> buffer[size];
+    int_same_size_t<T> buffer[size()];
     mask.store(buffer);
-    for (int64_t i = 0; i < size; i++) {
+    for (int64_t i = 0; i < size(); i++) {
       if (buffer[i] & 0x01)
        {
         vec[i] = b[i];
@@ -88,14 +123,14 @@ public:
   }
   static Vec256<T> arange(T base = static_cast<T>(0), T step = static_cast<T>(1)) {
     Vec256 vec;
-    for (int64_t i = 0; i < size; i++) {
+    for (int64_t i = 0; i < size(); i++) {
       vec.values[i] = base + i * step;
     }
     return vec;
   }
-  static Vec256<T> set(const Vec256<T>& a, const Vec256<T>& b, int64_t count = size) {
+  static Vec256<T> set(const Vec256<T>& a, const Vec256<T>& b, int64_t count = size()) {
     Vec256 vec;
-    for (int64_t i = 0; i < size; i++) {
+    for (int64_t i = 0; i < size(); i++) {
       if (i < count) {
         vec[i] = b[i];
       } else {
@@ -114,7 +149,7 @@ public:
     std::memcpy(vec.values, ptr, count * sizeof(T));
     return vec;
   }
-  void store(void* ptr, int count = size) const {
+  void store(void* ptr, int count = size()) const {
     std::memcpy(ptr, values, count * sizeof(T));
   }
   const T& operator[](int idx) const {
@@ -125,14 +160,14 @@ public:
   }
   Vec256<T> map(T (*f)(T)) const {
     Vec256<T> ret;
-    for (int64_t i = 0; i != size; i++) {
+    for (int64_t i = 0; i != size(); i++) {
       ret[i] = f(values[i]);
     }
     return ret;
   }
   Vec256<T> abs() const {
     Vec256<T> ret;
-    for (int64_t i = 0; i < size; i++) {
+    for (int64_t i = 0; i < size(); i++) {
       ret[i] = values[i] < 0 ? -values[i] : values[i];
     }
     return ret;
@@ -214,7 +249,7 @@ public:
   }
   Vec256<T> pow(const Vec256<T> &exp) const {
     Vec256<T> ret;
-    for (int64_t i = 0; i < size; i++) {
+    for (int64_t i = 0; i < size(); i++) {
       ret[i] = std::pow(values[i], exp[i]);
     }
     return ret;
@@ -222,7 +257,7 @@ public:
 #define DEFINE_COMP(binary_pred)                                              \
   Vec256<T> operator binary_pred(const Vec256<T> &other) const {              \
     Vec256<T> vec;                                                            \
-    for (int64_t i = 0; i != size; i++) {                                     \
+    for (int64_t i = 0; i != size(); i++) {                                     \
       if (values[i] binary_pred other.values[i]) {                            \
         std::memset(static_cast<void*>(vec.values + i), 0xFF, sizeof(T));     \
       } else {                                                                \
@@ -242,7 +277,7 @@ public:
 
 template <class T> Vec256<T> inline operator+(const Vec256<T> &a, const Vec256<T> &b) {
   Vec256<T> c = Vec256<T>();
-  for (int i = 0; i != Vec256<T>::size; i++) {
+  for (int i = 0; i != Vec256<T>::size(); i++) {
     c[i] = a[i] + b[i];
   }
   return c;
@@ -250,7 +285,7 @@ template <class T> Vec256<T> inline operator+(const Vec256<T> &a, const Vec256<T
 
 template <class T> Vec256<T> inline operator-(const Vec256<T> &a, const Vec256<T> &b) {
   Vec256<T> c = Vec256<T>();
-  for (int i = 0; i != Vec256<T>::size; i++) {
+  for (int i = 0; i != Vec256<T>::size(); i++) {
     c[i] = a[i] - b[i];
   }
   return c;
@@ -258,7 +293,7 @@ template <class T> Vec256<T> inline operator-(const Vec256<T> &a, const Vec256<T
 
 template <class T> Vec256<T> inline operator*(const Vec256<T> &a, const Vec256<T> &b) {
   Vec256<T> c = Vec256<T>();
-  for (int i = 0; i != Vec256<T>::size; i++) {
+  for (int i = 0; i != Vec256<T>::size(); i++) {
     c[i] = a[i] * b[i];
   }
   return c;
@@ -266,7 +301,7 @@ template <class T> Vec256<T> inline operator*(const Vec256<T> &a, const Vec256<T
 
 template <class T> Vec256<T> inline operator/(const Vec256<T> &a, const Vec256<T> &b) __ubsan_ignore_float_divide_by_zero__ {
   Vec256<T> c = Vec256<T>();
-  for (int i = 0; i != Vec256<T>::size; i++) {
+  for (int i = 0; i != Vec256<T>::size(); i++) {
     c[i] = a[i] / b[i];
   }
   return c;
@@ -276,7 +311,7 @@ template <class T> Vec256<T> inline operator/(const Vec256<T> &a, const Vec256<T
 // either input is a NaN.
 template <class T> Vec256<T> inline maximum(const Vec256<T> &a, const Vec256<T> &b) {
   Vec256<T> c = Vec256<T>();
-  for (int i = 0; i != Vec256<T>::size; i++) {
+  for (int i = 0; i != Vec256<T>::size(); i++) {
     c[i] = (a[i] > b[i]) ? a[i] : b[i];
     if (std::is_floating_point<T>::value && std::isnan(a[i])) {
       // If either input is NaN, propagate a NaN.
@@ -301,7 +336,7 @@ inline T maximum(const T& a, const T& b) {
 // either input is a NaN.
 template <class T> Vec256<T> inline minimum(const Vec256<T> &a, const Vec256<T> &b) {
   Vec256<T> c = Vec256<T>();
-  for (int i = 0; i != Vec256<T>::size; i++) {
+  for (int i = 0; i != Vec256<T>::size(); i++) {
     c[i] = (a[i] < b[i]) ? a[i] : b[i];
     if (std::is_floating_point<T>::value && std::isnan(a[i])) {
       // If either input is NaN, propagate a NaN.
@@ -327,8 +362,8 @@ inline T minimum(const T& a, const T& b) {
 template <class T>                                                          \
 Vec256<T> inline operator op(const Vec256<T> &a, const Vec256<T> &b) {      \
   using iT = int_same_size_t<T>;                                            \
-  iT buffer[Vec256<T>::size];                                               \
-  for (int64_t i = 0; i != Vec256<T>::size; i++) {                          \
+  iT buffer[Vec256<T>::size()];                                               \
+  for (int64_t i = 0; i != Vec256<T>::size(); i++) {                          \
     auto a_val = a[i];                                                      \
     auto b_val = b[i];                                                      \
     iT *i_a_ptr = reinterpret_cast<iT*>(&a_val);                            \
@@ -350,7 +385,7 @@ inline T fmadd(const T& a, const T& b, const T& c) {
 template <int64_t scale = 1, typename T = void>
 c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<T>>
 inline gather(T const* base_addr, const Vec256<int_same_size_t<T>>& vindex) {
-  static constexpr int size = Vec256<T>::size;
+  static constexpr int size = Vec256<T>::size();
   int_same_size_t<T> index_arr[size];
   vindex.store(static_cast<void*>(index_arr));
   T buffer[size];
@@ -364,7 +399,7 @@ template <int64_t scale = 1, typename T = void>
 c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<T>>
 inline mask_gather(const Vec256<T>& src, T const* base_addr,
                    const Vec256<int_same_size_t<T>>& vindex, Vec256<T>& mask) {
-  static constexpr int size = Vec256<T>::size;
+  static constexpr int size = Vec256<T>::size();
   T src_arr[size];
   int_same_size_t<T> mask_arr[size];  // use int type so we can logical and
   int_same_size_t<T> index_arr[size];
@@ -392,7 +427,7 @@ namespace {
   template<typename dst_t, typename src_t>
   struct CastImpl {
     static inline Vec256<dst_t> apply(const Vec256<src_t>& src) {
-      src_t src_arr[Vec256<src_t>::size];
+      src_t src_arr[Vec256<src_t>::size()];
       src.store(static_cast<void*>(src_arr));
       return Vec256<dst_t>::loadu(static_cast<const void*>(src_arr));
     }
@@ -412,7 +447,7 @@ Vec256<dst_t> cast(const Vec256<src_t>& src) {
 
 template <typename T>
 inline Vec256<int_same_size_t<T>> convert_to_int_of_same_size(const Vec256<T>& src) {
-  static constexpr int size = Vec256<T>::size;
+  static constexpr int size = Vec256<T>::size();
   T src_arr[size];
   src.store(static_cast<void*>(src_arr));
   int_same_size_t<T> buffer[size];
@@ -427,9 +462,9 @@ inline Vec256<int_same_size_t<T>> convert_to_int_of_same_size(const Vec256<T>& s
 //       returns:            Vec256<float>   = {a0, a1, a2, a3, a4, a5, a6, a7}
 //                           Vec256<float>   = {b0, b1, b2, b3, b4, b5, b6, b7}
 template <typename T>
-inline c10::guts::enable_if_t<Vec256<T>::size % 2 == 0, std::pair<Vec256<T>, Vec256<T>>>
+inline c10::guts::enable_if_t<Vec256<T>::size() % 2 == 0, std::pair<Vec256<T>, Vec256<T>>>
 deinterleave2(const Vec256<T>& a, const Vec256<T>& b) {
-  static constexpr int size = Vec256<T>::size;
+  static constexpr int size = Vec256<T>::size();
   static constexpr int half_size = size / 2;
   T a_arr[size];
   T b_arr[size];
@@ -453,9 +488,9 @@ deinterleave2(const Vec256<T>& a, const Vec256<T>& b) {
 //       returns:            Vec256<float>   = {a0, b0, a1, b1, a2, b2, a3, b3}
 //                           Vec256<float>   = {a4, b4, a5, b5, a6, b6, a7, b7}
 template <typename T>
-inline c10::guts::enable_if_t<Vec256<T>::size % 2 == 0, std::pair<Vec256<T>, Vec256<T>>>
+inline c10::guts::enable_if_t<Vec256<T>::size() % 2 == 0, std::pair<Vec256<T>, Vec256<T>>>
 interleave2(const Vec256<T>& a, const Vec256<T>& b) {
-  static constexpr int size = Vec256<T>::size;
+  static constexpr int size = Vec256<T>::size();
   static constexpr int half_size = size / 2;
   T a_arr[size];
   T b_arr[size];
index bd50cda..c5fea7d 100644 (file)
@@ -8,6 +8,7 @@
 
 namespace at {
 namespace vec256 {
+// See Note [Acceptable use of anonymous namespace in header]
 namespace {
 
 #if defined(__AVX__) && !defined(_MSC_VER)
@@ -16,7 +17,9 @@ template <> class Vec256<double> {
 private:
   __m256d values;
 public:
-  static constexpr int size = 4;
+  static constexpr int size() {
+    return 4;
+  }
   Vec256() {}
   Vec256(__m256d v) : values(v) {}
   Vec256(double val) {
@@ -40,7 +43,7 @@ public:
     return Vec256<double>(base, base + step, base + 2 * step, base + 3 * step);
   }
   static Vec256<double> set(const Vec256<double>& a, const Vec256<double>& b,
-                            int64_t count = size) {
+                            int64_t count = size()) {
     switch (count) {
       case 0:
         return a;
@@ -53,22 +56,22 @@ public:
     }
     return b;
   }
-  static Vec256<double> loadu(const void* ptr, int64_t count = size) {
-    if (count == size)
+  static Vec256<double> loadu(const void* ptr, int64_t count = size()) {
+    if (count == size())
       return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));
 
-    __at_align32__ double tmp_values[size];
+    __at_align32__ double tmp_values[size()];
     std::memcpy(
         tmp_values,
         reinterpret_cast<const double*>(ptr),
         count * sizeof(double));
     return _mm256_load_pd(tmp_values);
   }
-  void store(void* ptr, int count = size) const {
-    if (count == size) {
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
       _mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
     } else if (count > 0) {
-      double tmp_values[size];
+      double tmp_values[size()];
       _mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
       std::memcpy(ptr, tmp_values, count * sizeof(double));
     }
@@ -252,7 +255,7 @@ template <>
 void convert(const double* src, double* dst, int64_t n) {
   int64_t i;
 #pragma unroll
-  for (i = 0; i <= (n - Vec256<double>::size); i += Vec256<double>::size) {
+  for (i = 0; i <= (n - Vec256<double>::size()); i += Vec256<double>::size()) {
     _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i));
   }
 #pragma unroll
index c769c32..dfaa126 100644 (file)
@@ -8,6 +8,7 @@
 
 namespace at {
 namespace vec256 {
+// See Note [Acceptable use of anonymous namespace in header]
 namespace {
 
 #if defined(__AVX__) && !defined(_MSC_VER)
@@ -16,7 +17,9 @@ template <> class Vec256<float> {
 private:
   __m256 values;
 public:
-  static constexpr int size = 8;
+  static constexpr int size() {
+    return 8;
+  }
   Vec256() {}
   Vec256(__m256 v) : values(v) {}
   Vec256(float val) {
@@ -43,7 +46,7 @@ public:
       base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
   }
   static Vec256<float> set(const Vec256<float>& a, const Vec256<float>& b,
-                           int64_t count = size) {
+                           int64_t count = size()) {
     switch (count) {
       case 0:
         return a;
@@ -64,19 +67,19 @@ public:
     }
     return b;
   }
-  static Vec256<float> loadu(const void* ptr, int64_t count = size) {
-    if (count == size)
+  static Vec256<float> loadu(const void* ptr, int64_t count = size()) {
+    if (count == size())
       return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
-    __at_align32__ float tmp_values[size];
+    __at_align32__ float tmp_values[size()];
     std::memcpy(
         tmp_values, reinterpret_cast<const float*>(ptr), count * sizeof(float));
     return _mm256_loadu_ps(tmp_values);
   }
-  void store(void* ptr, int64_t count = size) const {
-    if (count == size) {
+  void store(void* ptr, int64_t count = size()) const {
+    if (count == size()) {
       _mm256_storeu_ps(reinterpret_cast<float*>(ptr), values);
     } else if (count > 0) {
-      float tmp_values[size];
+      float tmp_values[size()];
       _mm256_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
       std::memcpy(ptr, tmp_values, count * sizeof(float));
     }
@@ -260,7 +263,7 @@ template <>
 void convert(const float* src, float* dst, int64_t n) {
   int64_t i;
 #pragma unroll
-  for (i = 0; i <= (n - Vec256<float>::size); i += Vec256<float>::size) {
+  for (i = 0; i <= (n - Vec256<float>::size()); i += Vec256<float>::size()) {
     _mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i));
   }
 #pragma unroll
index f793d7a..9d2581e 100644 (file)
@@ -22,7 +22,9 @@ public:
 
 template <>
 struct Vec256<int64_t> : public Vec256i {
-  static constexpr int size = 4;
+  static constexpr int size() {
+    return 4;
+  }
   using Vec256i::Vec256i;
   Vec256() {}
   Vec256(int64_t v) { values = _mm256_set1_epi64x(v); }
@@ -31,7 +33,7 @@ struct Vec256<int64_t> : public Vec256i {
   }
   template <int64_t mask>
   static Vec256<int64_t> blend(Vec256<int64_t> a, Vec256<int64_t> b) {
-    __at_align32__ int64_t tmp_values[size];
+    __at_align32__ int64_t tmp_values[size()];
     a.store(tmp_values);
     if (mask & 0x01)
       tmp_values[0] = _mm256_extract_epi64(b.values, 0);
@@ -51,7 +53,7 @@ struct Vec256<int64_t> : public Vec256i {
     return Vec256<int64_t>(base, base + step, base + 2 * step, base + 3 * step);
   }
   static Vec256<int64_t>
-  set(Vec256<int64_t> a, Vec256<int64_t> b, int64_t count = size) {
+  set(Vec256<int64_t> a, Vec256<int64_t> b, int64_t count = size()) {
     switch (count) {
       case 0:
         return a;
@@ -68,15 +70,15 @@ struct Vec256<int64_t> : public Vec256i {
     return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
   }
   static Vec256<int64_t> loadu(const void* ptr, int64_t count) {
-    __at_align32__ int64_t tmp_values[size];
+    __at_align32__ int64_t tmp_values[size()];
     std::memcpy(tmp_values, ptr, count * sizeof(int64_t));
     return loadu(tmp_values);
   }
-  void store(void* ptr, int count = size) const {
-    if (count == size) {
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
       _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
     } else if (count > 0) {
-      __at_align32__ int64_t tmp_values[size];
+      __at_align32__ int64_t tmp_values[size()];
       _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
       std::memcpy(ptr, tmp_values, count * sizeof(int64_t));
     }
@@ -117,7 +119,9 @@ struct Vec256<int64_t> : public Vec256i {
 
 template <>
 struct Vec256<int32_t> : public Vec256i {
-  static constexpr int size = 8;
+  static constexpr int size() {
+    return 8;
+  }
   using Vec256i::Vec256i;
   Vec256() {}
   Vec256(int32_t v) { values = _mm256_set1_epi32(v); }
@@ -139,7 +143,7 @@ struct Vec256<int32_t> : public Vec256i {
       base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
   }
   static Vec256<int32_t>
-  set(Vec256<int32_t> a, Vec256<int32_t> b, int32_t count = size) {
+  set(Vec256<int32_t> a, Vec256<int32_t> b, int32_t count = size()) {
     switch (count) {
       case 0:
         return a;
@@ -164,15 +168,15 @@ struct Vec256<int32_t> : public Vec256i {
     return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
   }
   static Vec256<int32_t> loadu(const void* ptr, int32_t count) {
-    __at_align32__ int32_t tmp_values[size];
+    __at_align32__ int32_t tmp_values[size()];
     std::memcpy(tmp_values, ptr, count * sizeof(int32_t));
     return loadu(tmp_values);
   }
-  void store(void* ptr, int count = size) const {
-    if (count == size) {
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
       _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
     } else if (count > 0) {
-      __at_align32__ int32_t tmp_values[size];
+      __at_align32__ int32_t tmp_values[size()];
       _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
       std::memcpy(ptr, tmp_values, count * sizeof(int32_t));
     }
@@ -215,7 +219,7 @@ void convert(const int32_t *src, float *dst, int64_t n) {
 #ifndef _MSC_VER  
 # pragma unroll  
 #endif
-  for (i = 0; i <= (n - Vec256<int32_t>::size); i += Vec256<int32_t>::size) {
+  for (i = 0; i <= (n - Vec256<int32_t>::size()); i += Vec256<int32_t>::size()) {
     auto input_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
     auto output_vec = _mm256_cvtepi32_ps(input_vec);
     _mm256_storeu_ps(reinterpret_cast<float*>(dst + i), output_vec);
@@ -235,7 +239,7 @@ void convert(const int32_t *src, double *dst, int64_t n) {
 #ifndef _MSC_VER  
 # pragma unroll  
 #endif
-  for (i = 0; i <= (n - Vec256<double>::size); i += Vec256<double>::size) {
+  for (i = 0; i <= (n - Vec256<double>::size()); i += Vec256<double>::size()) {
     auto input_128_vec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
     auto output_vec = _mm256_cvtepi32_pd(input_128_vec);
     _mm256_storeu_pd(reinterpret_cast<double*>(dst + i), output_vec);
@@ -250,7 +254,9 @@ void convert(const int32_t *src, double *dst, int64_t n) {
 
 template <>
 struct Vec256<int16_t> : public Vec256i {
-  static constexpr int size = 16;
+  static constexpr int size() {
+    return 16;
+  }
   using Vec256i::Vec256i;
   Vec256() {}
   Vec256(int16_t v) { values = _mm256_set1_epi16(v); }
@@ -263,7 +269,7 @@ struct Vec256<int16_t> : public Vec256i {
   }
   template <int64_t mask>
   static Vec256<int16_t> blend(Vec256<int16_t> a, Vec256<int16_t> b) {
-    __at_align32__ int16_t tmp_values[size];
+    __at_align32__ int16_t tmp_values[size()];
     a.store(tmp_values);
     if (mask & 0x01)
       tmp_values[0] = _mm256_extract_epi16(b.values, 0);
@@ -311,7 +317,7 @@ struct Vec256<int16_t> : public Vec256i {
       base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
   }
   static Vec256<int16_t>
-  set(Vec256<int16_t> a, Vec256<int16_t> b, int16_t count = size) {
+  set(Vec256<int16_t> a, Vec256<int16_t> b, int16_t count = size()) {
     switch (count) {
       case 0:
         return a;
@@ -352,15 +358,15 @@ struct Vec256<int16_t> : public Vec256i {
     return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
   }
   static Vec256<int16_t> loadu(const void* ptr, int16_t count) {
-    __at_align32__ int16_t tmp_values[size];
+    __at_align32__ int16_t tmp_values[size()];
     std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
     return loadu(tmp_values);
   }
-  void store(void* ptr, int count = size) const {
-    if (count == size) {
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
       _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
     } else if (count > 0) {
-      __at_align32__ int16_t tmp_values[size];
+      __at_align32__ int16_t tmp_values[size()];
       _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
       std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
     }
@@ -462,11 +468,11 @@ Vec256<int16_t> inline operator*(const Vec256<int16_t>& a, const Vec256<int16_t>
 
 template <typename T>
 Vec256<T> inline intdiv_256(const Vec256<T>& a, const Vec256<T>& b) {
-  T values_a[Vec256<T>::size];
-  T values_b[Vec256<T>::size];
+  T values_a[Vec256<T>::size()];
+  T values_b[Vec256<T>::size()];
   a.store(values_a);
   b.store(values_b);
-  for (int i = 0; i != Vec256<T>::size; i++) {
+  for (int i = 0; i != Vec256<T>::size(); i++) {
     values_a[i] /= values_b[i];
   }
   return Vec256<T>::loadu(values_a);
index a6239be..620537f 100644 (file)
@@ -149,7 +149,7 @@ struct PDist {
   }
 
   template <typename F>
-  inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size) {
+  inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) {
     for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) {
 
       const Vec self_vec_i = Vec::loadu(self_i, count);
@@ -187,15 +187,15 @@ struct PDist {
     // The only way to parallelize and avoid locking requires parallelizing
     // over the columns of the input, i.e. we compute the gradient for the
     // first section of each vector independentaly of the second section, etc.
-    at::parallel_for(0, m / Vec::size, internal::GRAIN_SIZE / (8 * n * n), [=, &pvec](int64_t l, int64_t end) {
-      const scalar_t * self_l = self_start + l * Vec::size;
-      scalar_t * res_l = res_start + l * Vec::size;
+    at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (8 * n * n), [=, &pvec](int64_t l, int64_t end) {
+      const scalar_t * self_l = self_start + l * Vec::size();
+      scalar_t * res_l = res_start + l * Vec::size();
 
-      for (const scalar_t * const res_end = res_start + end * Vec::size; res_l != res_end; self_l += Vec::size, res_l += Vec::size) {
+      for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; self_l += Vec::size(), res_l += Vec::size()) {
         backward_down_column<F>(self_l, res_l, grad_start, dist_start, pvec, n, m, gs);
       }
     });
-    const int64_t remainder = m % Vec::size;
+    const int64_t remainder = m % Vec::size();
     if (remainder) {
       backward_down_column<F>(self_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, pvec, n, m, gs, remainder);
     }
index ef1d4b6..7faff9b 100644 (file)
@@ -484,25 +484,25 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bilinear, padding>
     // So we store the necessary vectors to temporary arrays and use the helper
     // mask_scatter_add defined above.
 
-    integer_t i_gInp_nw_offset_arr[iVec::size];
-    integer_t i_gInp_ne_offset_arr[iVec::size];
-    integer_t i_gInp_sw_offset_arr[iVec::size];
-    integer_t i_gInp_se_offset_arr[iVec::size];
+    integer_t i_gInp_nw_offset_arr[iVec::size()];
+    integer_t i_gInp_ne_offset_arr[iVec::size()];
+    integer_t i_gInp_sw_offset_arr[iVec::size()];
+    integer_t i_gInp_se_offset_arr[iVec::size()];
     i_gInp_nw_offset.store(i_gInp_nw_offset_arr);
     i_gInp_ne_offset.store(i_gInp_ne_offset_arr);
     i_gInp_sw_offset.store(i_gInp_sw_offset_arr);
     i_gInp_se_offset.store(i_gInp_se_offset_arr);
 
-    integer_t i_nw_mask_arr[iVec::size];
-    integer_t i_ne_mask_arr[iVec::size];
-    integer_t i_sw_mask_arr[iVec::size];
-    integer_t i_se_mask_arr[iVec::size];
+    integer_t i_nw_mask_arr[iVec::size()];
+    integer_t i_ne_mask_arr[iVec::size()];
+    integer_t i_sw_mask_arr[iVec::size()];
+    integer_t i_se_mask_arr[iVec::size()];
     nw_mask.store(i_nw_mask_arr);
     ne_mask.store(i_ne_mask_arr);
     sw_mask.store(i_sw_mask_arr);
     se_mask.store(i_se_mask_arr);
 
-    scalar_t gInp_corner_arr[Vec::size];
+    scalar_t gInp_corner_arr[Vec::size()];
 
     auto gx = Vec(0), gy = Vec(0);
     #ifndef _MSC_VER  
@@ -539,7 +539,7 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bilinear, padding>
     gx = gx * gx_mult;
     gy = gy * gy_mult;
 
-    constexpr int64_t step = Vec::size;
+    constexpr int64_t step = Vec::size();
     auto interleaved_gGrid = interleave2(gx, gy);
     auto gGrid_ptr = gGrid_slice.data() + offset * 2;
     std::get<0>(interleaved_gGrid).store(gGrid_ptr,
@@ -630,9 +630,9 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Nearest, padding>
 
     auto i_gInp_offset = i_y_nearest * iVec(inp_W) + i_x_nearest;  // gInp is contiguous
 
-    integer_t mask_arr[iVec::size];
+    integer_t mask_arr[iVec::size()];
     i_mask.store(mask_arr);
-    integer_t gInp_offset_arr[iVec::size];
+    integer_t gInp_offset_arr[iVec::size()];
     i_gInp_offset.store(gInp_offset_arr);
 
     #ifndef _MSC_VER  
@@ -666,7 +666,7 @@ static inline void grid_sample_2d_grid_slice_iterator(
 
   using Vec = Vec256<scalar_t>;
   using iVec = Vec256<int_same_size_t<scalar_t>>;
-  constexpr int64_t step = Vec::size;
+  constexpr int64_t step = Vec::size();
 
   // Loop over each output pixel in grid.
   // We consider the following three cases (after slicing out the batch
index 5417604..62b7f56 100644 (file)
@@ -80,15 +80,15 @@ template <typename func_t, typename vec_func_t>
 static inline void vectorized_binary_loop(char** data, int64_t n, func_t op, vec_func_t vop) {
   VEC_LOOP_HEADER(func_t, data)
   int64_t i = 0;
-  for (; i <= n - 2 * Vec::size; i += 2 * Vec::size) {
+  for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
     auto a1 = Vec::loadu(in1_ptr + i * sizeof(scalar_t));
-    auto a2 = Vec::loadu(in1_ptr + (i + Vec::size) * sizeof(scalar_t));
+    auto a2 = Vec::loadu(in1_ptr + (i + Vec::size()) * sizeof(scalar_t));
     auto b1 = Vec::loadu(in2_ptr + i * sizeof(scalar_t));
-    auto b2 = Vec::loadu(in2_ptr + (i + Vec::size) * sizeof(scalar_t));
+    auto b2 = Vec::loadu(in2_ptr + (i + Vec::size()) * sizeof(scalar_t));
     auto out1 = vop(a1, b1);
     auto out2 = vop(a2, b2);
     out1.store(out_ptr + i * sizeof(scalar_t));
-    out2.store(out_ptr + (i + Vec::size) * sizeof(scalar_t));
+    out2.store(out_ptr + (i + Vec::size()) * sizeof(scalar_t));
   }
   int64_t strides[] = { sizeof(scalar_t), sizeof(scalar_t), sizeof(scalar_t) };
   binary_loop(data, strides, i, n, op);
@@ -100,13 +100,13 @@ static inline void vectorized_binary_loop_s1(char** data, int64_t n, func_t op,
   VEC_LOOP_HEADER(func_t, data)
   int64_t i = 0;
   auto a = Vec(*(scalar_t*)in1_ptr);
-  for (; i <= n - 2 * Vec::size; i += 2 * Vec::size) {
+  for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
     auto b1 = Vec::loadu(in2_ptr + i * sizeof(scalar_t));
-    auto b2 = Vec::loadu(in2_ptr + (i + Vec::size) * sizeof(scalar_t));
+    auto b2 = Vec::loadu(in2_ptr + (i + Vec::size()) * sizeof(scalar_t));
     auto out1 = vop(a, b1);
     auto out2 = vop(a, b2);
     out1.store(out_ptr + i * sizeof(scalar_t));
-    out2.store(out_ptr + (i + Vec::size) * sizeof(scalar_t));
+    out2.store(out_ptr + (i + Vec::size()) * sizeof(scalar_t));
   }
   int64_t strides[] = { sizeof(scalar_t), 0, sizeof(scalar_t) };
   binary_loop(data, strides, i, n, op);
@@ -118,13 +118,13 @@ static inline void vectorized_binary_loop_s2(char** data, int64_t n, func_t op,
   VEC_LOOP_HEADER(func_t, data)
   int64_t i = 0;
   auto b = Vec(*(scalar_t*)in2_ptr);
-  for (; i <= n - 2 * Vec::size; i += 2 * Vec::size) {
+  for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
     auto a1 = Vec::loadu(in1_ptr + i * sizeof(scalar_t));
-    auto a2 = Vec::loadu(in1_ptr + (i + Vec::size) * sizeof(scalar_t));
+    auto a2 = Vec::loadu(in1_ptr + (i + Vec::size()) * sizeof(scalar_t));
     auto out1 = vop(a1, b);
     auto out2 = vop(a2, b);
     out1.store(out_ptr + i * sizeof(scalar_t));
-    out2.store(out_ptr + (i + Vec::size) * sizeof(scalar_t));
+    out2.store(out_ptr + (i + Vec::size()) * sizeof(scalar_t));
   }
   int64_t strides[] = { sizeof(scalar_t), sizeof(scalar_t), 0 };
   binary_loop(data, strides, i, n, op);
@@ -137,27 +137,27 @@ static inline void reduction128(char** data, int64_t n, int64_t stride, func_t o
   char* in_ptr = data[1];
   Vec acc[4];
   for  (int j = 0; j < 4; j++) {
-    acc[j] = Vec::loadu(in_ptr + j * Vec::size * sizeof(scalar_t));
+    acc[j] = Vec::loadu(in_ptr + j * Vec::size() * sizeof(scalar_t));
   }
   for (int64_t i = 1; i < n; i++) {
     const char* ptr = in_ptr + stride * i;
-    acc[0] = vop(acc[0], Vec::loadu(ptr + (0 * Vec::size * sizeof(scalar_t))));
-    acc[1] = vop(acc[1], Vec::loadu(ptr + (1 * Vec::size * sizeof(scalar_t))));
-    acc[2] = vop(acc[2], Vec::loadu(ptr + (2 * Vec::size * sizeof(scalar_t))));
-    acc[3] = vop(acc[3], Vec::loadu(ptr + (3 * Vec::size * sizeof(scalar_t))));
+    acc[0] = vop(acc[0], Vec::loadu(ptr + (0 * Vec::size() * sizeof(scalar_t))));
+    acc[1] = vop(acc[1], Vec::loadu(ptr + (1 * Vec::size() * sizeof(scalar_t))));
+    acc[2] = vop(acc[2], Vec::loadu(ptr + (2 * Vec::size() * sizeof(scalar_t))));
+    acc[3] = vop(acc[3], Vec::loadu(ptr + (3 * Vec::size() * sizeof(scalar_t))));
   }
   if (reduce) {
-    scalar_t buffer[Vec::size];
+    scalar_t buffer[Vec::size()];
     acc[0] = vop(vop(acc[0], acc[1]), vop(acc[2], acc[3]));
     acc[0].store(buffer);
-    for (int j = 1; j < Vec::size; j++) {
+    for (int j = 1; j < Vec::size(); j++) {
       buffer[0] = op(buffer[0], buffer[j]);
     }
     auto dst = (scalar_t*)out_ptr;
     *dst = op(*dst, buffer[0]);
   } else {
     for (int j = 0; j < 4; j++) {
-      auto dst = out_ptr + j * Vec::size * sizeof(scalar_t);
+      auto dst = out_ptr + j * Vec::size() * sizeof(scalar_t);
       acc[j] = vop(acc[j], Vec::loadu(dst));
       acc[j].store(dst);
     }
@@ -177,14 +177,14 @@ static inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int
 template <typename func_t, typename vec_func_t>
 static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) {
   VEC_HEADER(func_t)
-  int64_t vector_stride = 4 * Vec::size * sizeof(scalar_t);
-  int64_t count = n / (4 * Vec::size);
+  int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
+  int64_t count = n / (4 * Vec::size());
   if (count > 0) {
     reduction128(data, count, vector_stride, op, vop, /*reduce=*/true);
   }
   char* ptrs[3] = { data[0], data[0], data[1] };
   int64_t strides[] = { 0, 0, sizeof(scalar_t) };
-  binary_loop(ptrs, strides, count * 4 * Vec::size, n, op);
+  binary_loop(ptrs, strides, count * 4 * Vec::size(), n, op);
 }
 
 // computes the reduction out = op(out, in)
@@ -192,15 +192,15 @@ template <typename func_t, typename vec_func_t>
 static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) {
   VEC_HEADER(func_t)
 
-  // reduce down each column of 4 * Vec::size elements (128 bytes)
+  // reduce down each column of 4 * Vec::size() elements (128 bytes)
   int64_t outer_stride[2] = { 128, 128 };
-  UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size), [&] {
+  UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] {
     reduction128(data, size0, inner_stride, op, vop, /*reduce=*/false);
   });
 
   // reduce down the remaining columns
   int64_t step[] = { sizeof(scalar_t), sizeof(scalar_t) };
-  int64_t remaining = size1 % (4 * Vec::size);
+  int64_t remaining = size1 % (4 * Vec::size());
   UNARY_OUTER_LOOP(data, step, remaining, [&] {
     char* ptrs[3] = { data[0], data[0], data[1] };
     int64_t strides[] = { 0, 0, inner_stride };
index 983e894..0881912 100644 (file)
@@ -216,7 +216,7 @@ struct NormReduction {
     if (pval == 1){
       for (int row = 0; row < rows; row ++) {
         for (int j = 0; j != 4; j++) {
-          auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size]);
+          auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size()]);
           acc[j] = acc[j] + val.abs();
         }
       }
@@ -224,7 +224,7 @@ struct NormReduction {
     else if (pval == 2) {
       for (int row = 0; row < rows; row ++) {
         for (int j = 0; j != 4; j++) {
-          auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size]);
+          auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size()]);
           acc[j] = acc[j] + val * val;
         }
       }
@@ -232,14 +232,14 @@ struct NormReduction {
     else if (pval == 3) {
       for (int row = 0; row < rows; row ++) {
         for (int j = 0; j != 4; j++) {
-          auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size]);
+          auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size()]);
           acc[j] = acc[j] + (val * val * val).abs();
         }
       }
     }
     scalar_t buf[WIDTH] = {0};
     for (int j = 0; j != 4; j++) {
-      acc[j].store(&buf[j * Vec::size]);
+      acc[j].store(&buf[j * Vec::size()]);
     }
     for (int i = 0; i < WIDTH; i++) {
       result += buf[i];
index d0ae993..83f0232 100644 (file)
@@ -29,7 +29,7 @@ inline void _vec_log_softmax_lastdim(
     int64_t outer_size,
     int64_t dim_size) {
   using Vec = vec256::Vec256<scalar_t>;
-  static constexpr int64_t CHUNK_SIZE = (128 / sizeof(scalar_t)) * Vec::size;
+  static constexpr int64_t CHUNK_SIZE = (128 / sizeof(scalar_t)) * Vec::size();
   int64_t grain_size = internal::GRAIN_SIZE / (16 * dim_size * CHUNK_SIZE);
   if (grain_size < CHUNK_SIZE)
     grain_size = CHUNK_SIZE;
index 668d393..31263fd 100644 (file)
@@ -37,9 +37,9 @@ template <>
 int64_t _sigmoid(float* x, float* y, int64_t size) {
   using Vec = Vec256<float>;
   int64_t i = 0;
-  for (; i < size - (size % (2 * Vec::size)); i += 2 * Vec::size) {
+  for (; i < size - (size % (2 * Vec::size())); i += 2 * Vec::size()) {
     Vec ret = Vec::loadu(y + i);
-    Vec ret2 = Vec::loadu(y + i + Vec::size);
+    Vec ret2 = Vec::loadu(y + i + Vec::size());
     ret = ret.neg();
     ret2 = ret2.neg();
 #if defined(__AVX2__) && !defined(_MSC_VER)
@@ -54,7 +54,7 @@ int64_t _sigmoid(float* x, float* y, int64_t size) {
     ret = ret.reciprocal();
     ret2 = ret2.reciprocal();
     ret.store(x + i);
-    ret2.store(x + i + Vec::size);
+    ret2.store(x + i + Vec::size());
   }
   return i;
 }
@@ -63,9 +63,9 @@ template <>
 int64_t _sigmoid(double* x, double* y, int64_t size) {
   using Vec = Vec256<double>;
   int64_t i = 0;
-  for (; i < size - (size % (2 * Vec::size)); i += 2 * Vec::size) {
+  for (; i < size - (size % (2 * Vec::size())); i += 2 * Vec::size()) {
     Vec ret = Vec::loadu(y + i);
-    Vec ret2 = Vec::loadu(y + i + Vec::size);
+    Vec ret2 = Vec::loadu(y + i + Vec::size());
     ret = ret.neg();
     ret2 = ret2.neg();
     ret = ret.exp();
@@ -75,7 +75,7 @@ int64_t _sigmoid(double* x, double* y, int64_t size) {
     ret = ret.reciprocal();
     ret2 = ret2.reciprocal();
     ret.store(x + i);
-    ret2.store(x + i + Vec::size);
+    ret2.store(x + i + Vec::size());
   }
   return i;
 }
@@ -95,9 +95,9 @@ static void sigmoid_kernel(Tensor& result, const Tensor& self) {
           if (stridex == 1 && stridey == 1) {
             i = _sigmoid(x, y, size);
           }
-          for (; i < size; i += Vec::size) {
-            scalar_t buffer[Vec::size];
-            int64_t width = Vec::size;
+          for (; i < size; i += Vec::size()) {
+            scalar_t buffer[Vec::size()];
+            int64_t width = Vec::size();
             width = std::min(width, size - i);
             for (int64_t j = 0; j < width; j++) {
               buffer[j] = y[stridey * (i + j)];