template <typename scalar_t, typename Op>
inline scalar_t vec_reduce_all(
const Op& vec_fun,
vec256::Vec256<scalar_t> acc_vec,
int64_t size) {
using Vec = vec256::Vec256<scalar_t>;
- scalar_t acc_arr[Vec::size];
+ scalar_t acc_arr[Vec::size()];
acc_vec.store(acc_arr);
for (int64_t i = 1; i < size; i++) {
- scalar_t acc_arr_next[Vec::size];
+ scalar_t acc_arr_next[Vec::size()];
acc_arr_next[0] = acc_arr[i];
Vec acc_vec_next = Vec::loadu(acc_arr_next);
acc_vec = vec_fun(acc_vec, acc_vec_next);
template <typename scalar_t, typename Op>
inline scalar_t reduce_all(const Op& vec_fun, scalar_t* data, int64_t size) {
using Vec = vec256::Vec256<scalar_t>;
- if (size < Vec::size)
+ if (size < Vec::size())
return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
- int64_t d = Vec::size;
+ int64_t d = Vec::size();
Vec acc_vec = Vec::loadu(data);
- for (; d < size - (size % Vec::size); d += Vec::size) {
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec = Vec::loadu(data + d);
acc_vec = vec_fun(acc_vec, data_vec);
}
if (size - d > 0) {
Vec data_vec = Vec::loadu(data + d, size - d);
acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
}
- return vec_reduce_all(vec_fun, acc_vec, Vec::size);
+ return vec_reduce_all(vec_fun, acc_vec, Vec::size());
}
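
A minimal usage sketch of reduce_all (assuming these helpers live in at::vec256 like the surrounding headers): a horizontal max over a buffer whose length is not a multiple of the lane count.

    float data[19];
    for (int i = 0; i < 19; i++) data[i] = float(i % 7);
    // vec_fun only ever sees Vec256<float> pairs, so maximum() from this
    // header suffices; the sub-vector tail is folded in via Vec::set above.
    float result = at::vec256::reduce_all(
        [](const auto& a, const auto& b) { return at::vec256::maximum(a, b); },
        data, 19);
    // result == 6.0f
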
template <typename scalar_t, typename MapOp, typename ReduceOp>
inline scalar_t map_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
scalar_t* data,
int64_t size) {
using Vec = vec256::Vec256<scalar_t>;
- if (size < Vec::size)
+ if (size < Vec::size())
return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size);
- int64_t d = Vec::size;
+ int64_t d = Vec::size();
Vec acc_vec = map_fun(Vec::loadu(data));
- for (; d < size - (size % Vec::size); d += Vec::size) {
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec = Vec::loadu(data + d);
data_vec = map_fun(data_vec);
acc_vec = red_fun(acc_vec, data_vec);
}
if (size - d > 0) {
Vec data_vec = Vec::loadu(data + d, size - d);
data_vec = map_fun(data_vec);
acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
}
- return vec_reduce_all(red_fun, acc_vec, Vec::size);
+ return vec_reduce_all(red_fun, acc_vec, Vec::size());
}
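
For map_reduce_all, a sum-of-squares sketch (reusing data and its length from the sketch above): map_fun squares each lane, red_fun adds lane-wise, and the final horizontal add happens in vec_reduce_all.

    float sum_sq = at::vec256::map_reduce_all(
        [](const auto& x) { return x * x; },                 // map: square lanes
        [](const auto& a, const auto& b) { return a + b; },  // reduce: lane-wise add
        data, 19);
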
template <typename scalar_t, typename MapOp, typename ReduceOp>
inline scalar_t map2_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,
const scalar_t* data2,
int64_t size) {
using Vec = vec256::Vec256<scalar_t>;
- if (size < Vec::size) {
+ if (size < Vec::size()) {
Vec data_vec = Vec::loadu(data, size);
Vec data2_vec = Vec::loadu(data2, size);
data_vec = map_fun(data_vec, data2_vec);
return vec_reduce_all(red_fun, data_vec, size);
}
- int64_t d = Vec::size;
+ int64_t d = Vec::size();
Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2));
- for (; d < size - (size % Vec::size); d += Vec::size) {
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec = Vec::loadu(data + d);
Vec data2_vec = Vec::loadu(data2 + d);
data_vec = map_fun(data_vec, data2_vec);
acc_vec = red_fun(acc_vec, data_vec);
}
if (size - d > 0) {
Vec data_vec = Vec::loadu(data + d, size - d);
Vec data2_vec = Vec::loadu(data2 + d, size - d);
data_vec = map_fun(data_vec, data2_vec);
acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
}
- return vec_reduce_all(red_fun, acc_vec, Vec::size);
+ return vec_reduce_all(red_fun, acc_vec, Vec::size());
}
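
map2_reduce_all composes a binary map with a reduction, which makes a dot product a one-liner (a sketch, with x and y as hypothetical pointers to n floats each):

    float dot = at::vec256::map2_reduce_all(
        [](const auto& a, const auto& b) { return a * b; },  // lane-wise products
        [](const auto& a, const auto& b) { return a + b; },  // lane-wise sums
        x, y, n);
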
template <typename scalar_t, typename Op>
inline void map(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* input_data,
int64_t size) {
using Vec = vec256::Vec256<scalar_t>;
int64_t d = 0;
- for (; d < size - (size % Vec::size); d += Vec::size) {
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec output_vec = vec_fun(Vec::loadu(input_data + d));
output_vec.store(output_data + d);
}
template <typename scalar_t, typename Op>
inline void map2(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* input_data,
const scalar_t* input_data2,
int64_t size) {
using Vec = vec256::Vec256<scalar_t>;
int64_t d = 0;
- for (; d < size - (size % Vec::size); d += Vec::size) {
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec = Vec::loadu(input_data + d);
Vec data_vec2 = Vec::loadu(input_data2 + d);
Vec output_vec = vec_fun(data_vec, data_vec2);
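
And the unary map helper applies vec_fun a full vector at a time (a usage sketch, assuming the signature reconstructed above; the sub-vector tail is handled with the count-ed loadu/store overloads, as in reduce_all):

    float in[100], out[100];
    // ... fill in ...
    at::vec256::map(
        [](const auto& x) { return x * x; },
        out, in, 100);
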
namespace at {
namespace vec256 {
+
+// Note [Acceptable use of anonymous namespace in header]
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// Yes you saw right, this is an anonymous namespace in a header. This header,
+// and all of its subheaders, REQUIRE their code to be entirely inlined into
+// the compilation unit that uses them. It's important that these functions have
+// internal linkage so that kernels for different architectures don't get
+// combined during linking. It's sufficient to label functions "static", but
+// class methods must be in an unnamed namespace to have internal linkage (since
+// static means something different in the context of classes).
namespace {
template <typename T>
std::ostream& operator<<(std::ostream& stream, const Vec256<T>& vec) {
- T buf[Vec256<T>::size];
+ T buf[Vec256<T>::size()];
vec.store(buf);
stream << "vec[";
- for (int i = 0; i != Vec256<T>::size; i++) {
+ for (int i = 0; i != Vec256<T>::size(); i++) {
if (i != 0) {
stream << ", ";
}
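
A quick sketch of the printer in action (8 float lanes on a 32-byte vector; arange is defined in the base class below):

    auto v = at::vec256::Vec256<float>::arange(0.0f, 1.0f);
    std::cout << v << std::endl;  // vec[0, 1, 2, 3, 4, 5, 6, 7]
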
namespace at {
namespace vec256 {
+// See Note [Acceptable use of anonymous namespace in header]
namespace {
template<size_t n> struct int_of_size;
private:
T values[32 / sizeof(T)] = {0};
public:
- static constexpr int size = 32 / sizeof(T);
+ // Note [constexpr static function to avoid odr-usage compiler bug]
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // Why, you might ask, is size defined to be a static constexpr function,
+ // rather than a more ordinary 'static constexpr int size;' variable?
+// The problem lies in the ODR rules for static constexpr members versus
+// static constexpr functions. First, recall that this class (along with all
+// of its derivations) lives in an anonymous namespace: it is intended to be
+// *completely* inlined at its use-sites, because we need to compile it
+ // multiple times for different instruction sets.
+ //
+ // Because of this constraint, we CANNOT provide a single definition for
+ // any static members in this class; since we want to compile the class
+ // multiple times, there wouldn't actually be any good place to put the
+ // definition. Now here is the problem: if we ODR-use a static constexpr
+ // member, we are *obligated* to provide a definition. Without the
+// definition, you get a link-time error like:
+ //
+ // relocation R_X86_64_PC32 against undefined symbol
+ // `_ZN2at6vec25612_GLOBAL__N_16Vec256IdE4sizeE' can not be used when making
+ // a shared object; recompile with -fPIC
+ //
+ // If this were C++17, we could replace a static constexpr variable with
+ // an inline variable which doesn't require one definition. But we are not
+ // C++17. So the next best thing is to replace the member with a static
+// constexpr (and therefore inline) function, which does not require a
+// definition either.
+//
+// Also, technically according to the C++ standard, we don't have to define
+// a constexpr variable if we never odr-use it. But it seems that some
+// versions of GCC/Clang have buggy determinations of whether or not an
+// identifier is odr-used, and in any case it's hard to tell if
+// a variable is odr-used or not. So best to just cut the problem at the root.
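+//
+// To see the trap concretely (a sketch; std::min stands in for any
+// by-reference consumer): std::min takes its arguments by const reference,
+// so passing the old data member binds a reference to it, which is an
+// odr-use and so demands the definition we cannot provide:
+//
+//   int n = std::min(Vec256<float>::size, 8);    // undefined symbol at link time
+//
+// Calling the function form passes a prvalue instead, so no reference is
+// ever bound to an undefined variable:
+//
+//   int n = std::min(Vec256<float>::size(), 8);  // links fine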
+ static constexpr int size() {
+ return 32 / sizeof(T);
+ }
Vec256() {}
Vec256(T val) {
- for (int i = 0; i != size; i++) {
+ for (int i = 0; i != size(); i++) {
values[i] = val;
}
}
template<typename... Args,
- typename = c10::guts::enable_if_t<(sizeof...(Args) == size)>>
+ typename = c10::guts::enable_if_t<(sizeof...(Args) == size())>>
Vec256(Args... vals) {
values = { vals... };
}
static Vec256<T> blend(const Vec256<T>& a, const Vec256<T>& b) {
int64_t mask = mask_;
Vec256 vec;
- for (int64_t i = 0; i < size; i++) {
+ for (int64_t i = 0; i < size(); i++) {
if (mask & 0x01) {
vec[i] = b[i];
} else {
static Vec256<T> blendv(const Vec256<T>& a, const Vec256<T>& b,
const Vec256<T>& mask) {
Vec256 vec;
- int_same_size_t<T> buffer[size];
+ int_same_size_t<T> buffer[size()];
mask.store(buffer);
- for (int64_t i = 0; i < size; i++) {
+ for (int64_t i = 0; i < size(); i++) {
if (buffer[i] & 0x01)
{
vec[i] = b[i];
}
static Vec256<T> arange(T base = static_cast<T>(0), T step = static_cast<T>(1)) {
Vec256 vec;
- for (int64_t i = 0; i < size; i++) {
+ for (int64_t i = 0; i < size(); i++) {
vec.values[i] = base + i * step;
}
return vec;
}
- static Vec256<T> set(const Vec256<T>& a, const Vec256<T>& b, int64_t count = size) {
+ static Vec256<T> set(const Vec256<T>& a, const Vec256<T>& b, int64_t count = size()) {
Vec256 vec;
- for (int64_t i = 0; i < size; i++) {
+ for (int64_t i = 0; i < size(); i++) {
if (i < count) {
vec[i] = b[i];
} else {
std::memcpy(vec.values, ptr, count * sizeof(T));
return vec;
}
- void store(void* ptr, int count = size) const {
+ void store(void* ptr, int count = size()) const {
std::memcpy(ptr, values, count * sizeof(T));
}
const T& operator[](int idx) const {
}
Vec256<T> map(T (*f)(T)) const {
Vec256<T> ret;
- for (int64_t i = 0; i != size; i++) {
+ for (int64_t i = 0; i != size(); i++) {
ret[i] = f(values[i]);
}
return ret;
}
Vec256<T> abs() const {
Vec256<T> ret;
- for (int64_t i = 0; i < size; i++) {
+ for (int64_t i = 0; i < size(); i++) {
ret[i] = values[i] < 0 ? -values[i] : values[i];
}
return ret;
}
Vec256<T> pow(const Vec256<T> &exp) const {
Vec256<T> ret;
- for (int64_t i = 0; i < size; i++) {
+ for (int64_t i = 0; i < size(); i++) {
ret[i] = std::pow(values[i], exp[i]);
}
return ret;
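
A usage sketch of this scalar fallback (float gives 32 / 4 = 8 lanes):

    auto v = at::vec256::Vec256<float>::arange(-3.0f, 1.0f);  // -3, -2, -1, 0, 1, 2, 3, 4
    auto a = v.abs();                                         //  3,  2,  1, 0, 1, 2, 3, 4
    auto p = a.pow(at::vec256::Vec256<float>(2.0f));          //  9,  4,  1, 0, 1, 4, 9, 16
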
#define DEFINE_COMP(binary_pred) \
Vec256<T> operator binary_pred(const Vec256<T> &other) const { \
Vec256<T> vec; \
- for (int64_t i = 0; i != size; i++) { \
+ for (int64_t i = 0; i != size(); i++) { \
if (values[i] binary_pred other.values[i]) { \
std::memset(static_cast<void*>(vec.values + i), 0xFF, sizeof(T)); \
} else { \
template <class T> Vec256<T> inline operator+(const Vec256<T> &a, const Vec256<T> &b) {
Vec256<T> c = Vec256<T>();
- for (int i = 0; i != Vec256<T>::size; i++) {
+ for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = a[i] + b[i];
}
return c;
template <class T> Vec256<T> inline operator-(const Vec256<T> &a, const Vec256<T> &b) {
Vec256<T> c = Vec256<T>();
- for (int i = 0; i != Vec256<T>::size; i++) {
+ for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = a[i] - b[i];
}
return c;
template <class T> Vec256<T> inline operator*(const Vec256<T> &a, const Vec256<T> &b) {
Vec256<T> c = Vec256<T>();
- for (int i = 0; i != Vec256<T>::size; i++) {
+ for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = a[i] * b[i];
}
return c;
template <class T> Vec256<T> inline operator/(const Vec256<T> &a, const Vec256<T> &b) __ubsan_ignore_float_divide_by_zero__ {
Vec256<T> c = Vec256<T>();
- for (int i = 0; i != Vec256<T>::size; i++) {
+ for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = a[i] / b[i];
}
return c;
// either input is a NaN.
template <class T> Vec256<T> inline maximum(const Vec256<T> &a, const Vec256<T> &b) {
Vec256<T> c = Vec256<T>();
- for (int i = 0; i != Vec256<T>::size; i++) {
+ for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = (a[i] > b[i]) ? a[i] : b[i];
if (std::is_floating_point<T>::value && std::isnan(a[i])) {
// If either input is NaN, propagate a NaN.
// either input is a NaN.
template <class T> Vec256<T> inline minimum(const Vec256<T> &a, const Vec256<T> &b) {
Vec256<T> c = Vec256<T>();
- for (int i = 0; i != Vec256<T>::size; i++) {
+ for (int i = 0; i != Vec256<T>::size(); i++) {
c[i] = (a[i] < b[i]) ? a[i] : b[i];
if (std::is_floating_point<T>::value && std::isnan(a[i])) {
// If either input is NaN, propagate a NaN.
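
Note that this deliberately differs from std::fmax/std::fmin, which prefer the non-NaN operand; here a NaN in either input poisons the output lane. A sketch:

    auto a = at::vec256::Vec256<float>(std::nanf(""));  // needs <cmath>
    auto b = at::vec256::Vec256<float>(1.0f);
    auto m = at::vec256::maximum(a, b);                 // every lane of m is NaN
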
template <class T> \
Vec256<T> inline operator op(const Vec256<T> &a, const Vec256<T> &b) { \
using iT = int_same_size_t<T>; \
- iT buffer[Vec256<T>::size]; \
- for (int64_t i = 0; i != Vec256<T>::size; i++) { \
+ iT buffer[Vec256<T>::size()]; \
+ for (int64_t i = 0; i != Vec256<T>::size(); i++) { \
auto a_val = a[i]; \
auto b_val = b[i]; \
iT *i_a_ptr = reinterpret_cast<iT*>(&a_val); \
template <int64_t scale = 1, typename T = void>
c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<T>>
inline gather(T const* base_addr, const Vec256<int_same_size_t<T>>& vindex) {
- static constexpr int size = Vec256<T>::size;
+ static constexpr int size = Vec256<T>::size();
int_same_size_t<T> index_arr[size];
vindex.store(static_cast<void*>(index_arr));
T buffer[size];
c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<T>>
inline mask_gather(const Vec256<T>& src, T const* base_addr,
const Vec256<int_same_size_t<T>>& vindex, Vec256<T>& mask) {
- static constexpr int size = Vec256<T>::size;
+ static constexpr int size = Vec256<T>::size();
T src_arr[size];
int_same_size_t<T> mask_arr[size]; // use int type so we can logical and
int_same_size_t<T> index_arr[size];
template<typename dst_t, typename src_t>
struct CastImpl {
static inline Vec256<dst_t> apply(const Vec256<src_t>& src) {
- src_t src_arr[Vec256<src_t>::size];
+ src_t src_arr[Vec256<src_t>::size()];
src.store(static_cast<void*>(src_arr));
return Vec256<dst_t>::loadu(static_cast<const void*>(src_arr));
}
template <typename T>
inline Vec256<int_same_size_t<T>> convert_to_int_of_same_size(const Vec256<T>& src) {
- static constexpr int size = Vec256<T>::size;
+ static constexpr int size = Vec256<T>::size();
T src_arr[size];
src.store(static_cast<void*>(src_arr));
int_same_size_t<T> buffer[size];
// returns: Vec256<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
// Vec256<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
template <typename T>
-inline c10::guts::enable_if_t<Vec256<T>::size % 2 == 0, std::pair<Vec256<T>, Vec256<T>>>
+inline c10::guts::enable_if_t<Vec256<T>::size() % 2 == 0, std::pair<Vec256<T>, Vec256<T>>>
deinterleave2(const Vec256<T>& a, const Vec256<T>& b) {
- static constexpr int size = Vec256<T>::size;
+ static constexpr int size = Vec256<T>::size();
static constexpr int half_size = size / 2;
T a_arr[size];
T b_arr[size];
// returns: Vec256<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
// Vec256<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
template <typename T>
-inline c10::guts::enable_if_t<Vec256<T>::size % 2 == 0, std::pair<Vec256<T>, Vec256<T>>>
+inline c10::guts::enable_if_t<Vec256<T>::size() % 2 == 0, std::pair<Vec256<T>, Vec256<T>>>
interleave2(const Vec256<T>& a, const Vec256<T>& b) {
- static constexpr int size = Vec256<T>::size;
+ static constexpr int size = Vec256<T>::size();
static constexpr int half_size = size / 2;
T a_arr[size];
T b_arr[size];
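
A concrete instance of the documented shuffle (8-lane float case, values per the comments above):

    auto a = at::vec256::Vec256<float>::arange(0.0f, 1.0f);   // a0..a7 = 0..7
    auto b = at::vec256::Vec256<float>::arange(10.0f, 1.0f);  // b0..b7 = 10..17
    auto ab = interleave2(a, b);
    // std::get<0>(ab) == {0, 10, 1, 11, 2, 12, 3, 13}
    // std::get<1>(ab) == {4, 14, 5, 15, 6, 16, 7, 17}
    auto ba = deinterleave2(std::get<0>(ab), std::get<1>(ab));
    // ba == (a, b): the two transforms are mutual inverses
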
namespace at {
namespace vec256 {
+// See Note [Acceptable use of anonymous namespace in header]
namespace {
#if defined(__AVX__) && !defined(_MSC_VER)
private:
__m256d values;
public:
- static constexpr int size = 4;
+ static constexpr int size() {
+ return 4;
+ }
Vec256() {}
Vec256(__m256d v) : values(v) {}
Vec256(double val) {
return Vec256<double>(base, base + step, base + 2 * step, base + 3 * step);
}
static Vec256<double> set(const Vec256<double>& a, const Vec256<double>& b,
- int64_t count = size) {
+ int64_t count = size()) {
switch (count) {
case 0:
return a;
}
return b;
}
- static Vec256<double> loadu(const void* ptr, int64_t count = size) {
- if (count == size)
+ static Vec256<double> loadu(const void* ptr, int64_t count = size()) {
+ if (count == size())
return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));
- __at_align32__ double tmp_values[size];
+ __at_align32__ double tmp_values[size()];
std::memcpy(
tmp_values,
reinterpret_cast<const double*>(ptr),
count * sizeof(double));
return _mm256_load_pd(tmp_values);
}
- void store(void* ptr, int count = size) const {
- if (count == size) {
+ void store(void* ptr, int count = size()) const {
+ if (count == size()) {
_mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
} else if (count > 0) {
- double tmp_values[size];
+ double tmp_values[size()];
_mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(double));
}
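
The count parameter is what makes sub-vector tails safe: neither path reads or writes memory past count elements. A round-trip sketch for a 3-double tail:

    double src[3] = {1.0, 2.0, 3.0};
    double dst[3] = {0.0, 0.0, 0.0};
    auto v = at::vec256::Vec256<double>::loadu(src, 3);  // copies 3 doubles into an aligned
                                                         // temp; the 4th lane is indeterminate
    v.store(dst, 3);                                     // writes back exactly 3 doubles
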
void convert(const double* src, double* dst, int64_t n) {
int64_t i;
#pragma unroll
- for (i = 0; i <= (n - Vec256<double>::size); i += Vec256<double>::size) {
+ for (i = 0; i <= (n - Vec256<double>::size()); i += Vec256<double>::size()) {
_mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i));
}
#pragma unroll
namespace at {
namespace vec256 {
+// See Note [Acceptable use of anonymous namespace in header]
namespace {
#if defined(__AVX__) && !defined(_MSC_VER)
private:
__m256 values;
public:
- static constexpr int size = 8;
+ static constexpr int size() {
+ return 8;
+ }
Vec256() {}
Vec256(__m256 v) : values(v) {}
Vec256(float val) {
base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
}
static Vec256<float> set(const Vec256<float>& a, const Vec256<float>& b,
- int64_t count = size) {
+ int64_t count = size()) {
switch (count) {
case 0:
return a;
}
return b;
}
- static Vec256<float> loadu(const void* ptr, int64_t count = size) {
- if (count == size)
+ static Vec256<float> loadu(const void* ptr, int64_t count = size()) {
+ if (count == size())
return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
- __at_align32__ float tmp_values[size];
+ __at_align32__ float tmp_values[size()];
std::memcpy(
tmp_values, reinterpret_cast<const float*>(ptr), count * sizeof(float));
return _mm256_loadu_ps(tmp_values);
}
- void store(void* ptr, int64_t count = size) const {
- if (count == size) {
+ void store(void* ptr, int64_t count = size()) const {
+ if (count == size()) {
_mm256_storeu_ps(reinterpret_cast<float*>(ptr), values);
} else if (count > 0) {
- float tmp_values[size];
+ float tmp_values[size()];
_mm256_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(float));
}
void convert(const float* src, float* dst, int64_t n) {
int64_t i;
#pragma unroll
- for (i = 0; i <= (n - Vec256<float>::size); i += Vec256<float>::size) {
+ for (i = 0; i <= (n - Vec256<float>::size()); i += Vec256<float>::size()) {
_mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i));
}
#pragma unroll
template <>
struct Vec256<int64_t> : public Vec256i {
- static constexpr int size = 4;
+ static constexpr int size() {
+ return 4;
+ }
using Vec256i::Vec256i;
Vec256() {}
Vec256(int64_t v) { values = _mm256_set1_epi64x(v); }
}
template <int64_t mask>
static Vec256<int64_t> blend(Vec256<int64_t> a, Vec256<int64_t> b) {
- __at_align32__ int64_t tmp_values[size];
+ __at_align32__ int64_t tmp_values[size()];
a.store(tmp_values);
if (mask & 0x01)
tmp_values[0] = _mm256_extract_epi64(b.values, 0);
return Vec256<int64_t>(base, base + step, base + 2 * step, base + 3 * step);
}
static Vec256<int64_t>
- set(Vec256<int64_t> a, Vec256<int64_t> b, int64_t count = size) {
+ set(Vec256<int64_t> a, Vec256<int64_t> b, int64_t count = size()) {
switch (count) {
case 0:
return a;
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
}
static Vec256<int64_t> loadu(const void* ptr, int64_t count) {
- __at_align32__ int64_t tmp_values[size];
+ __at_align32__ int64_t tmp_values[size()];
std::memcpy(tmp_values, ptr, count * sizeof(int64_t));
return loadu(tmp_values);
}
- void store(void* ptr, int count = size) const {
- if (count == size) {
+ void store(void* ptr, int count = size()) const {
+ if (count == size()) {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
} else if (count > 0) {
- __at_align32__ int64_t tmp_values[size];
+ __at_align32__ int64_t tmp_values[size()];
_mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(int64_t));
}
template <>
struct Vec256<int32_t> : public Vec256i {
- static constexpr int size = 8;
+ static constexpr int size() {
+ return 8;
+ }
using Vec256i::Vec256i;
Vec256() {}
Vec256(int32_t v) { values = _mm256_set1_epi32(v); }
base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
}
static Vec256<int32_t>
- set(Vec256<int32_t> a, Vec256<int32_t> b, int32_t count = size) {
+ set(Vec256<int32_t> a, Vec256<int32_t> b, int32_t count = size()) {
switch (count) {
case 0:
return a;
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
}
static Vec256<int32_t> loadu(const void* ptr, int32_t count) {
- __at_align32__ int32_t tmp_values[size];
+ __at_align32__ int32_t tmp_values[size()];
std::memcpy(tmp_values, ptr, count * sizeof(int32_t));
return loadu(tmp_values);
}
- void store(void* ptr, int count = size) const {
- if (count == size) {
+ void store(void* ptr, int count = size()) const {
+ if (count == size()) {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
} else if (count > 0) {
- __at_align32__ int32_t tmp_values[size];
+ __at_align32__ int32_t tmp_values[size()];
_mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(int32_t));
}
#ifndef _MSC_VER
# pragma unroll
#endif
- for (i = 0; i <= (n - Vec256<int32_t>::size); i += Vec256<int32_t>::size) {
+ for (i = 0; i <= (n - Vec256<int32_t>::size()); i += Vec256<int32_t>::size()) {
auto input_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
auto output_vec = _mm256_cvtepi32_ps(input_vec);
_mm256_storeu_ps(reinterpret_cast<float*>(dst + i), output_vec);
#ifndef _MSC_VER
# pragma unroll
#endif
- for (i = 0; i <= (n - Vec256<double>::size); i += Vec256<double>::size) {
+ for (i = 0; i <= (n - Vec256<double>::size()); i += Vec256<double>::size()) {
auto input_128_vec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
auto output_vec = _mm256_cvtepi32_pd(input_128_vec);
_mm256_storeu_pd(reinterpret_cast<double*>(dst + i), output_vec);
template <>
struct Vec256<int16_t> : public Vec256i {
- static constexpr int size = 16;
+ static constexpr int size() {
+ return 16;
+ }
using Vec256i::Vec256i;
Vec256() {}
Vec256(int16_t v) { values = _mm256_set1_epi16(v); }
}
template <int64_t mask>
static Vec256<int16_t> blend(Vec256<int16_t> a, Vec256<int16_t> b) {
- __at_align32__ int16_t tmp_values[size];
+ __at_align32__ int16_t tmp_values[size()];
a.store(tmp_values);
if (mask & 0x01)
tmp_values[0] = _mm256_extract_epi16(b.values, 0);
base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
}
static Vec256<int16_t>
- set(Vec256<int16_t> a, Vec256<int16_t> b, int16_t count = size) {
+ set(Vec256<int16_t> a, Vec256<int16_t> b, int16_t count = size()) {
switch (count) {
case 0:
return a;
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
}
static Vec256<int16_t> loadu(const void* ptr, int16_t count) {
- __at_align32__ int16_t tmp_values[size];
+ __at_align32__ int16_t tmp_values[size()];
std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
return loadu(tmp_values);
}
- void store(void* ptr, int count = size) const {
- if (count == size) {
+ void store(void* ptr, int count = size()) const {
+ if (count == size()) {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
} else if (count > 0) {
- __at_align32__ int16_t tmp_values[size];
+ __at_align32__ int16_t tmp_values[size()];
_mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
}
template <typename T>
Vec256<T> inline intdiv_256(const Vec256<T>& a, const Vec256<T>& b) {
- T values_a[Vec256<T>::size];
- T values_b[Vec256<T>::size];
+ T values_a[Vec256<T>::size()];
+ T values_b[Vec256<T>::size()];
a.store(values_a);
b.store(values_b);
- for (int i = 0; i != Vec256<T>::size; i++) {
+ for (int i = 0; i != Vec256<T>::size(); i++) {
values_a[i] /= values_b[i];
}
return Vec256<T>::loadu(values_a);
}
template <typename F>
- inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size) {
+ inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) {
for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) {
const Vec self_vec_i = Vec::loadu(self_i, count);
// The only way to parallelize and avoid locking requires parallelizing
// over the columns of the input, i.e. we compute the gradient for the
// first section of each vector independently of the second section, etc.
- at::parallel_for(0, m / Vec::size, internal::GRAIN_SIZE / (8 * n * n), [=, &pvec](int64_t l, int64_t end) {
- const scalar_t * self_l = self_start + l * Vec::size;
- scalar_t * res_l = res_start + l * Vec::size;
- for (const scalar_t * const res_end = res_start + end * Vec::size; res_l != res_end; self_l += Vec::size, res_l += Vec::size) {
+ at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (8 * n * n), [=, &pvec](int64_t l, int64_t end) {
+ const scalar_t * self_l = self_start + l * Vec::size();
+ scalar_t * res_l = res_start + l * Vec::size();
+ for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; self_l += Vec::size(), res_l += Vec::size()) {
backward_down_column<F>(self_l, res_l, grad_start, dist_start, pvec, n, m, gs);
}
});
- const int64_t remainder = m % Vec::size;
+ const int64_t remainder = m % Vec::size();
if (remainder) {
backward_down_column<F>(self_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, pvec, n, m, gs, remainder);
}
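
A standalone sketch of the column blocking this sets up (sizes hypothetical: m = 21 columns, 8 float lanes), showing which columns each parallel task and the remainder call receive:

    #include <cstdint>
    #include <cstdio>
    int main() {
      const int64_t m = 21, lanes = 8;         // lanes stands in for Vec::size()
      for (int64_t l = 0; l < m / lanes; ++l)  // one backward_down_column per block
        std::printf("block %lld: columns [%lld, %lld)\n",
                    (long long)l, (long long)(l * lanes), (long long)((l + 1) * lanes));
      const int64_t rem = m % lanes;           // 5 trailing columns
      std::printf("remainder: columns [%lld, %lld)\n",
                  (long long)(m - rem), (long long)m);
      return 0;
    }
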
// So we store the necessary vectors to temporary arrays and use the helper
// mask_scatter_add defined above.
- integer_t i_gInp_nw_offset_arr[iVec::size];
- integer_t i_gInp_ne_offset_arr[iVec::size];
- integer_t i_gInp_sw_offset_arr[iVec::size];
- integer_t i_gInp_se_offset_arr[iVec::size];
+ integer_t i_gInp_nw_offset_arr[iVec::size()];
+ integer_t i_gInp_ne_offset_arr[iVec::size()];
+ integer_t i_gInp_sw_offset_arr[iVec::size()];
+ integer_t i_gInp_se_offset_arr[iVec::size()];
i_gInp_nw_offset.store(i_gInp_nw_offset_arr);
i_gInp_ne_offset.store(i_gInp_ne_offset_arr);
i_gInp_sw_offset.store(i_gInp_sw_offset_arr);
i_gInp_se_offset.store(i_gInp_se_offset_arr);
- integer_t i_nw_mask_arr[iVec::size];
- integer_t i_ne_mask_arr[iVec::size];
- integer_t i_sw_mask_arr[iVec::size];
- integer_t i_se_mask_arr[iVec::size];
+ integer_t i_nw_mask_arr[iVec::size()];
+ integer_t i_ne_mask_arr[iVec::size()];
+ integer_t i_sw_mask_arr[iVec::size()];
+ integer_t i_se_mask_arr[iVec::size()];
nw_mask.store(i_nw_mask_arr);
ne_mask.store(i_ne_mask_arr);
sw_mask.store(i_sw_mask_arr);
se_mask.store(i_se_mask_arr);
- scalar_t gInp_corner_arr[Vec::size];
+ scalar_t gInp_corner_arr[Vec::size()];
auto gx = Vec(0), gy = Vec(0);
#ifndef _MSC_VER
gx = gx * gx_mult;
gy = gy * gy_mult;
- constexpr int64_t step = Vec::size;
+ constexpr int64_t step = Vec::size();
auto interleaved_gGrid = interleave2(gx, gy);
auto gGrid_ptr = gGrid_slice.data() + offset * 2;
std::get<0>(interleaved_gGrid).store(gGrid_ptr,
auto i_gInp_offset = i_y_nearest * iVec(inp_W) + i_x_nearest; // gInp is contiguous
- integer_t mask_arr[iVec::size];
+ integer_t mask_arr[iVec::size()];
i_mask.store(mask_arr);
- integer_t gInp_offset_arr[iVec::size];
+ integer_t gInp_offset_arr[iVec::size()];
i_gInp_offset.store(gInp_offset_arr);
#ifndef _MSC_VER
using Vec = Vec256<scalar_t>;
using iVec = Vec256<int_same_size_t<scalar_t>>;
- constexpr int64_t step = Vec::size;
+ constexpr int64_t step = Vec::size();
// Loop over each output pixel in grid.
// We consider the following three cases (after slicing out the batch
static inline void vectorized_binary_loop(char** data, int64_t n, func_t op, vec_func_t vop) {
VEC_LOOP_HEADER(func_t, data)
int64_t i = 0;
- for (; i <= n - 2 * Vec::size; i += 2 * Vec::size) {
+ for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
auto a1 = Vec::loadu(in1_ptr + i * sizeof(scalar_t));
- auto a2 = Vec::loadu(in1_ptr + (i + Vec::size) * sizeof(scalar_t));
+ auto a2 = Vec::loadu(in1_ptr + (i + Vec::size()) * sizeof(scalar_t));
auto b1 = Vec::loadu(in2_ptr + i * sizeof(scalar_t));
- auto b2 = Vec::loadu(in2_ptr + (i + Vec::size) * sizeof(scalar_t));
+ auto b2 = Vec::loadu(in2_ptr + (i + Vec::size()) * sizeof(scalar_t));
auto out1 = vop(a1, b1);
auto out2 = vop(a2, b2);
out1.store(out_ptr + i * sizeof(scalar_t));
- out2.store(out_ptr + (i + Vec::size) * sizeof(scalar_t));
+ out2.store(out_ptr + (i + Vec::size()) * sizeof(scalar_t));
}
int64_t strides[] = { sizeof(scalar_t), sizeof(scalar_t), sizeof(scalar_t) };
binary_loop(data, strides, i, n, op);
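
Worked bounds for the 2x-unrolled loop above (assuming AVX floats, so Vec::size() == 8): each iteration consumes 16 elements, and at most 15 leftovers fall through to the scalar binary_loop.

    #include <cassert>
    #include <cstdint>
    int main() {
      const int64_t n = 37, lanes = 8;             // lanes stands in for Vec::size()
      int64_t i = 0;
      for (; i <= n - 2 * lanes; i += 2 * lanes) {
        // two Vec256 loads, two vop calls, two stores in the real kernel
      }
      assert(i == 32);                             // elements [32, 37) go to binary_loop
      return 0;
    }
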
VEC_LOOP_HEADER(func_t, data)
int64_t i = 0;
auto a = Vec(*(scalar_t*)in1_ptr);
- for (; i <= n - 2 * Vec::size; i += 2 * Vec::size) {
+ for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
auto b1 = Vec::loadu(in2_ptr + i * sizeof(scalar_t));
- auto b2 = Vec::loadu(in2_ptr + (i + Vec::size) * sizeof(scalar_t));
+ auto b2 = Vec::loadu(in2_ptr + (i + Vec::size()) * sizeof(scalar_t));
auto out1 = vop(a, b1);
auto out2 = vop(a, b2);
out1.store(out_ptr + i * sizeof(scalar_t));
- out2.store(out_ptr + (i + Vec::size) * sizeof(scalar_t));
+ out2.store(out_ptr + (i + Vec::size()) * sizeof(scalar_t));
}
int64_t strides[] = { sizeof(scalar_t), 0, sizeof(scalar_t) };
binary_loop(data, strides, i, n, op);
VEC_LOOP_HEADER(func_t, data)
int64_t i = 0;
auto b = Vec(*(scalar_t*)in2_ptr);
- for (; i <= n - 2 * Vec::size; i += 2 * Vec::size) {
+ for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
auto a1 = Vec::loadu(in1_ptr + i * sizeof(scalar_t));
- auto a2 = Vec::loadu(in1_ptr + (i + Vec::size) * sizeof(scalar_t));
+ auto a2 = Vec::loadu(in1_ptr + (i + Vec::size()) * sizeof(scalar_t));
auto out1 = vop(a1, b);
auto out2 = vop(a2, b);
out1.store(out_ptr + i * sizeof(scalar_t));
- out2.store(out_ptr + (i + Vec::size) * sizeof(scalar_t));
+ out2.store(out_ptr + (i + Vec::size()) * sizeof(scalar_t));
}
int64_t strides[] = { sizeof(scalar_t), sizeof(scalar_t), 0 };
binary_loop(data, strides, i, n, op);
char* in_ptr = data[1];
Vec acc[4];
for (int j = 0; j < 4; j++) {
- acc[j] = Vec::loadu(in_ptr + j * Vec::size * sizeof(scalar_t));
+ acc[j] = Vec::loadu(in_ptr + j * Vec::size() * sizeof(scalar_t));
}
for (int64_t i = 1; i < n; i++) {
const char* ptr = in_ptr + stride * i;
- acc[0] = vop(acc[0], Vec::loadu(ptr + (0 * Vec::size * sizeof(scalar_t))));
- acc[1] = vop(acc[1], Vec::loadu(ptr + (1 * Vec::size * sizeof(scalar_t))));
- acc[2] = vop(acc[2], Vec::loadu(ptr + (2 * Vec::size * sizeof(scalar_t))));
- acc[3] = vop(acc[3], Vec::loadu(ptr + (3 * Vec::size * sizeof(scalar_t))));
+ acc[0] = vop(acc[0], Vec::loadu(ptr + (0 * Vec::size() * sizeof(scalar_t))));
+ acc[1] = vop(acc[1], Vec::loadu(ptr + (1 * Vec::size() * sizeof(scalar_t))));
+ acc[2] = vop(acc[2], Vec::loadu(ptr + (2 * Vec::size() * sizeof(scalar_t))));
+ acc[3] = vop(acc[3], Vec::loadu(ptr + (3 * Vec::size() * sizeof(scalar_t))));
}
if (reduce) {
- scalar_t buffer[Vec::size];
+ scalar_t buffer[Vec::size()];
acc[0] = vop(vop(acc[0], acc[1]), vop(acc[2], acc[3]));
acc[0].store(buffer);
- for (int j = 1; j < Vec::size; j++) {
+ for (int j = 1; j < Vec::size(); j++) {
buffer[0] = op(buffer[0], buffer[j]);
}
auto dst = (scalar_t*)out_ptr;
*dst = op(*dst, buffer[0]);
} else {
for (int j = 0; j < 4; j++) {
- auto dst = out_ptr + j * Vec::size * sizeof(scalar_t);
+ auto dst = out_ptr + j * Vec::size() * sizeof(scalar_t);
acc[j] = vop(acc[j], Vec::loadu(dst));
acc[j].store(dst);
}
template <typename func_t, typename vec_func_t>
static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) {
VEC_HEADER(func_t)
- int64_t vector_stride = 4 * Vec::size * sizeof(scalar_t);
- int64_t count = n / (4 * Vec::size);
+ int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
+ int64_t count = n / (4 * Vec::size());
if (count > 0) {
reduction128(data, count, vector_stride, op, vop, /*reduce=*/true);
}
char* ptrs[3] = { data[0], data[0], data[1] };
int64_t strides[] = { 0, 0, sizeof(scalar_t) };
- binary_loop(ptrs, strides, count * 4 * Vec::size, n, op);
+ binary_loop(ptrs, strides, count * 4 * Vec::size(), n, op);
}
// computes the reduction out = op(out, in)
static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) {
VEC_HEADER(func_t)
- // reduce down each column of 4 * Vec::size elements (128 bytes)
+ // reduce down each column of 4 * Vec::size() elements (128 bytes)
int64_t outer_stride[2] = { 128, 128 };
- UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size), [&] {
+ UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] {
reduction128(data, size0, inner_stride, op, vop, /*reduce=*/false);
});
// reduce down the remaining columns
int64_t step[] = { sizeof(scalar_t), sizeof(scalar_t) };
- int64_t remaining = size1 % (4 * Vec::size);
+ int64_t remaining = size1 % (4 * Vec::size());
UNARY_OUTER_LOOP(data, step, remaining, [&] {
char* ptrs[3] = { data[0], data[0], data[1] };
int64_t strides[] = { 0, 0, inner_stride };
if (pval == 1){
for (int row = 0; row < rows; row ++) {
for (int j = 0; j != 4; j++) {
- auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size]);
+ auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size()]);
acc[j] = acc[j] + val.abs();
}
}
else if (pval == 2) {
for (int row = 0; row < rows; row ++) {
for (int j = 0; j != 4; j++) {
- auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size]);
+ auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size()]);
acc[j] = acc[j] + val * val;
}
}
else if (pval == 3) {
for (int row = 0; row < rows; row ++) {
for (int j = 0; j != 4; j++) {
- auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size]);
+ auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size()]);
acc[j] = acc[j] + (val * val * val).abs();
}
}
}
scalar_t buf[WIDTH] = {0};
for (int j = 0; j != 4; j++) {
- acc[j].store(&buf[j * Vec::size]);
+ acc[j].store(&buf[j * Vec::size()]);
}
for (int i = 0; i < WIDTH; i++) {
result += buf[i];
int64_t outer_size,
int64_t dim_size) {
using Vec = vec256::Vec256<scalar_t>;
- static constexpr int64_t CHUNK_SIZE = (128 / sizeof(scalar_t)) * Vec::size;
+ static constexpr int64_t CHUNK_SIZE = (128 / sizeof(scalar_t)) * Vec::size();
int64_t grain_size = internal::GRAIN_SIZE / (16 * dim_size * CHUNK_SIZE);
if (grain_size < CHUNK_SIZE)
grain_size = CHUNK_SIZE;
int64_t _sigmoid(float* x, float* y, int64_t size) {
using Vec = Vec256<float>;
int64_t i = 0;
- for (; i < size - (size % (2 * Vec::size)); i += 2 * Vec::size) {
+ for (; i < size - (size % (2 * Vec::size())); i += 2 * Vec::size()) {
Vec ret = Vec::loadu(y + i);
- Vec ret2 = Vec::loadu(y + i + Vec::size);
+ Vec ret2 = Vec::loadu(y + i + Vec::size());
ret = ret.neg();
ret2 = ret2.neg();
#if defined(__AVX2__) && !defined(_MSC_VER)
ret = ret.reciprocal();
ret2 = ret2.reciprocal();
ret.store(x + i);
- ret2.store(x + i + Vec::size);
+ ret2.store(x + i + Vec::size());
}
return i;
}
int64_t _sigmoid(double* x, double* y, int64_t size) {
using Vec = Vec256<double>;
int64_t i = 0;
- for (; i < size - (size % (2 * Vec::size)); i += 2 * Vec::size) {
+ for (; i < size - (size % (2 * Vec::size())); i += 2 * Vec::size()) {
Vec ret = Vec::loadu(y + i);
- Vec ret2 = Vec::loadu(y + i + Vec::size);
+ Vec ret2 = Vec::loadu(y + i + Vec::size());
ret = ret.neg();
ret2 = ret2.neg();
ret = ret.exp();
ret = ret.reciprocal();
ret2 = ret2.reciprocal();
ret.store(x + i);
- ret2.store(x + i + Vec::size);
+ ret2.store(x + i + Vec::size());
}
return i;
}
if (stridex == 1 && stridey == 1) {
i = _sigmoid(x, y, size);
}
- for (; i < size; i += Vec::size) {
- scalar_t buffer[Vec::size];
- int64_t width = Vec::size;
+ for (; i < size; i += Vec::size()) {
+ scalar_t buffer[Vec::size()];
+ int64_t width = Vec::size();
width = std::min(width, size - i);
for (int64_t j = 0; j < width; j++) {
buffer[j] = y[stridey * (i + j)];