/// OffsetCalculator calculates the offset in bytes of a linear index for NARGS
/// operands that share the same shape, but may have different strides.
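+/// index_t selects the integer type used for sizes and offsets: uint32_t is
+/// cheaper on the GPU, while uint64_t is needed when the iterator does not fit
+/// in 32-bit indexing (gpu_reduce_kernel below instantiates both).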
-template <int NARGS>
+template <int NARGS, typename index_t = uint32_t>
struct OffsetCalculator {
static constexpr int MAX_DIMS = 25;
// The offset for each argument (in bytes). Wrapper around fixed-size array.
- using offset_type = at::cuda::Array<uint32_t, NARGS>;
+ using offset_type = at::cuda::Array<index_t, NARGS>;
OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides) : dims(dims) {
AT_CHECK(dims <= MAX_DIMS, "tensor has too many (>25) dims");
for (int i = 0; i < MAX_DIMS; ++i) {
if (i < dims) {
- sizes_[i] = IntDivider<uint32_t>(sizes[i]);
+ sizes_[i] = IntDivider<index_t>(sizes[i]);
} else {
- sizes_[i] = IntDivider<uint32_t>(1);
+ sizes_[i] = IntDivider<index_t>(1);
}
for (int arg = 0; arg < NARGS; arg++) {
strides_[i][arg] = i < dims ? strides[arg][i] : 0;
}
}
- C10_HOST_DEVICE offset_type get(uint32_t linear_idx) const {
+ C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
offset_type offsets;
#pragma unroll
for (int arg = 0; arg < NARGS; arg++) {
}
int dims;
- IntDivider<uint32_t> sizes_[MAX_DIMS];
- uint32_t strides_[MAX_DIMS][NARGS];
+ IntDivider<index_t> sizes_[MAX_DIMS];
+ index_t strides_[MAX_DIMS][NARGS];
};
// Modified from https://stackoverflow.com/questions/7943525/is-it-possible-to-figure-out-the-parameter-type-and-return-type-of-a-lambda
-// For generic types, directly use the result of the signature of its 'operator()'
+// Fallback, anything with an operator()
template <typename T>
struct function_traits : public function_traits<decltype(&T::operator())> {
};
+// Pointers to class members that are themselves functors.
+// For example, in the following code:
+// template <typename func_t>
+// struct S {
+// func_t f;
+// };
+// template <typename func_t>
+// S<func_t> make_s(func_t f) {
+// return S<func_t> { .f = f };
+// }
+//
+// auto s = make_s([] (int, float) -> double { /* ... */ });
+//
+// function_traits<decltype(&decltype(s)::f)> traits;
+template <typename ClassType, typename T>
+struct function_traits<T ClassType::*> : public function_traits<T> {
+};
+
+// Const class member functions
template <typename ClassType, typename ReturnType, typename... Args>
-struct function_traits<ReturnType(ClassType::*)(Args...) const> {
+struct function_traits<ReturnType(ClassType::*)(Args...) const> : public function_traits<ReturnType(Args...)> {
+};
+
+// Free functions
+template <typename ReturnType, typename... Args>
+struct function_traits<ReturnType(Args...)> {
// arity is the number of arguments.
enum { arity = sizeof...(Args) };
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"std only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
AT_CHECK(at::isFloatingType(self.type().scalarType()), "std only supports floating-point dtypes");
- if (self.type().backend() != Backend::CPU) {
- // TODO(btv): implement multi-dim `std` and `var` on CUDA.
- AT_CHECK(dim.size() == 1, "`std` across arbitrarily many dimensions is not yet supported for CUDA.")
- int64_t one_dim = maybe_wrap_dim(dim[0], self.dim());
- if (_dimreduce_return_trivial(result, self, std::numeric_limits<double>::quiet_NaN(), one_dim, keepdim)) {
- return result;
- } else {
- return at::legacy::th::_th_std_out(result, self, one_dim, unbiased, keepdim);
- }
- }
ScalarType dtype = get_dtype(result, self, {}, true);
auto iter = make_reduction("std", result, self, dim, keepdim, dtype);
- std_stub(iter->device_type(), *iter, unbiased);
+ if (iter->numel() == 0) {
+ result.fill_(NAN);
+ } else {
+ std_stub(iter->device_type(), *iter, unbiased);
+ }
return result;
}
--- /dev/null
+#pragma once
+// Please note that this file is
+// used across both CPU and GPU.
+
+#include <c10/macros/Macros.h>
+#if defined(__CUDACC__)
+#include <THC/THCDeviceUtils.cuh>
+#include <ATen/native/cuda/DeviceSqrt.cuh>
+#elif defined(__HIPCC__)
+#include <THH/THHDeviceUtils.cuh>
+#include <ATen/native/hip/DeviceSqrt.cuh>
+#else
+#include <cmath>
+#define device_sqrt std::sqrt
+#endif
+
+namespace at { namespace native {
+
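+// WelfordData carries the running state of Welford's online variance
+// algorithm: the current mean, the sum of squared deviations from the mean
+// (m2), and the number of samples seen so far (n).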
+template <typename scalar_t>
+struct WelfordData {
+ scalar_t mean;
+ scalar_t m2;
+ int64_t n;
+ C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0) {}
+ C10_DEVICE WelfordData(scalar_t mean, scalar_t m2, int64_t n) : mean(mean), m2(m2), n(n) {}
+};
+
+
+template <typename scalar_t, typename acc_scalar_t>
+struct WelfordOps {
+ bool unbiased;
+ public:
+ using acc_t = WelfordData<acc_scalar_t>;
+ inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data) const {
+ acc_scalar_t delta = data - acc.mean;
+ acc_scalar_t new_mean = acc.mean + delta / (acc.n + 1);
+ acc_scalar_t new_delta = data - new_mean;
+ return {
+ new_mean,
+ acc.m2 + delta * new_delta,
+ acc.n + 1
+ };
+ }
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+ if (a.n == 0) {
+ return b;
+ }
+ if (b.n == 0) {
+ return a;
+ }
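+ // Merge two partial results (Chan et al.'s parallel variance update): the
+ // combined mean is the count-weighted average, and m2 gains the between-chunk
+ // correction term delta^2 * a.n * b.n / (a.n + b.n).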
+ acc_scalar_t delta = b.mean - a.mean;
+ int64_t new_count = a.n + b.n;
+ acc_scalar_t nb_over_n = (acc_scalar_t)b.n / new_count;
+ return {
+ a.mean + delta * nb_over_n,
+ a.m2 + b.m2 + delta * delta * a.n * nb_over_n,
+ new_count
+ };
+ }
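+ // project finishes the reduction: variance is m2 / n (or m2 / (n - 1) with
+ // Bessel's correction when unbiased) and the result is its square root.
+ // Empty reductions (and size-1 unbiased reductions) yield NaN.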
+ inline C10_DEVICE scalar_t project(acc_t acc) const {
+ int64_t divisor = unbiased ? (acc.n - 1) : acc.n;
+ return (divisor > 0) ? (scalar_t)device_sqrt(acc.m2 / divisor) : (scalar_t)NAN;
+ }
+#if defined(__CUDACC__) || defined(__HIPCC__)
+ inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
+ return {
+ WARP_SHFL_DOWN(acc.mean, offset)
+ , WARP_SHFL_DOWN(acc.m2, offset)
+ , WARP_SHFL_DOWN(acc.n, offset)
+ };
+ }
+#endif
+ WelfordOps(bool unbiased) : unbiased(unbiased) {
+ }
+};
+
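+// MeanOps reduces by summation and multiplies by a precomputed factor (the
+// reciprocal of the number of elements reduced per output) in project.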
+template <typename acc_t, typename factor_t>
+struct MeanOps {
+ factor_t factor;
+
+ inline C10_DEVICE acc_t reduce(acc_t a, acc_t b) const {
+ return a + b;
+ }
+
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+ return reduce(a, b);
+ }
+
+ inline C10_DEVICE acc_t project(acc_t a) const {
+ return a * factor;
+ }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
+ return WARP_SHFL_DOWN(data, offset);
+ }
+#endif
+
+ MeanOps(factor_t factor): factor(factor) {
+ }
+};
+
+
+}} // namespace at::native
// acc_t is a type that contains all the necessary data
// to continue reducing.
//
-// Then:
+// ops_t is such that &ops_t::reduce, &ops_t::combine, and &ops_t::project exist and satisfy
+// the following.
// reduce: (acc_t, data_t) -> acc_t adds one data point to the accumulated value.
// combine: (acc_t, acc_t) -> acc_t combines two accumulated values into one.
// project: acc_t -> data_t finishes the reduction, getting the required output.
// If, on the other hand, there is only one, then we split the input into
// several pieces, reduce each separately, and then combine them.
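+//
+// As an illustration of the contract only (a sketch, not a type defined in
+// this file), a plain sum over floats could be expressed as:
+//
+//   struct SumOps {
+//     float reduce(float acc, float data) const { return acc + data; }
+//     float combine(float a, float b) const { return a + b; }
+//     float project(float acc) const { return acc; }
+//   };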
-template <typename rf_t,
- typename cf_t,
- typename pf_t,
- typename init_t>
-void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &combine, pf_t const &project, init_t init) {
+template <typename ops_t, typename init_t>
+void binary_kernel_reduce(TensorIterator& iter, ops_t ops, init_t init) {
+ using rf_t = decltype(&ops_t::reduce);
+ using cf_t = decltype(&ops_t::combine);
+ using pf_t = decltype(&ops_t::project);
using r_traits = binary_function_traits<rf_t>;
using c_traits = binary_function_traits<cf_t>;
using p_traits = unary_function_traits<pf_t>;
at::parallel_for(0, numel, serial ? (1 + numel) : internal::GRAIN_SIZE,
[&](int64_t begin, int64_t end) {
auto &acc = buffer[at::get_thread_num()];
- sub_iter.serial_for_each([&acc, &reduce, &init](int ntensors, char** data, const int64_t* strides, int64_t size) {
+ sub_iter.serial_for_each([&acc, &ops, &init](int ntensors, char** data, const int64_t* strides, int64_t size) {
AT_ASSERT(ntensors == 2);
char *in = data[1];
int64_t stride = strides[1];
acc = init;
}
for (int64_t i = 0; i < size; ++i) {
- acc = reduce(*acc, *(data_t*)in);
+ acc = ops.reduce(*acc, *(data_t*)in);
in += stride;
}
}, {begin, end});
acc_t acc = init;
for (int i = 0; i < max_threads; ++i) {
if (buffer[i]) {
- acc = combine(acc, *buffer[i]);
+ acc = ops.combine(acc, *buffer[i]);
}
}
char *out = (char *)sub_iter.data_ptr(0);
- *(data_t*)out = project(acc);
+ *(data_t*)out = ops.project(acc);
});
}
#include <ATen/cpu/vec256/vec256.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/TensorIterator.h>
+#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/cpu/Reduce.h>
#include <c10/util/Optional.h>
scalar_t factor = scalar_t(iter.num_output_elements()) / iter.numel();
binary_kernel_reduce(
iter,
- [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
- [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
- [factor](scalar_t a) -> scalar_t { return a*factor; }, scalar_t(0));
+ MeanOps<scalar_t, scalar_t> {factor},
+ scalar_t(0)
+ );
});
}
-struct WelfordData {
- double mean;
- double m2;
- int64_t n;
- WelfordData() : mean(0), m2(0), n(0) {}
- WelfordData(double mean, double m2, int64_t n) : mean(mean), m2(m2), n(n) {}
-};
-
static void std_kernel_impl(TensorIterator &iter, bool unbiased) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "std", [&] {
binary_kernel_reduce(
iter,
- [](WelfordData acc, scalar_t data) -> WelfordData {
- double delta = data - acc.mean;
- double new_mean = acc.mean + delta / (acc.n + 1);
- double new_delta = data - new_mean;
- return {
- new_mean,
- acc.m2 + delta * new_delta,
- acc.n + 1
- };
- },
- [](WelfordData a, WelfordData b) -> WelfordData {
- if (a.n == 0) {
- return b;
- }
- if (b.n == 0) {
- return a;
- }
- double delta = b.mean - a.mean;
- int64_t new_count = a.n + b.n;
- double nb_over_n = (double)b.n / new_count;
- return {
- a.mean + delta * nb_over_n,
- a.m2 + b.m2 + delta * delta * a.n * nb_over_n,
- new_count
- };
- },
- [unbiased](WelfordData acc) -> scalar_t {
- int64_t divisor = unbiased ? (acc.n - 1) : acc.n;
- return (divisor > 0) ? std::sqrt(acc.m2 / divisor) : NAN;
- },
- WelfordData()
+ WelfordOps<scalar_t, double> { unbiased },
+ WelfordData<double>()
);
});
}
--- /dev/null
+#pragma once
+
+#if defined(__HIP_PLATFORM_HCC__)
+#include <math.h>
+#endif
+
+namespace at { namespace native {
+#if defined(__HIP_PLATFORM_HCC__)
+// take these out when ROCm implements std:: math functions
+template <typename scalar_t>
+static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);
+
+template <>
+__forceinline__ __device__ float device_sqrt(float val) {
+ return ::sqrtf(val);
+}
+
+template <>
+__forceinline__ __device__ double device_sqrt(double val) {
+ return ::sqrt(val);
+}
+#else
+template<typename scalar_t>
+__forceinline__ __device__ double device_sqrt(scalar_t val) {
+ return std::sqrt(val);
+}
+#endif
+}}
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/native/cuda/DeviceSqrt.cuh>
#include <ATen/native/cuda/Normalization.cuh>
#if defined(__HIP_PLATFORM_HCC__)
constexpr int WARP_SIZE = 64;
-
-// take these out when ROCm implements std:: math functions
-#include <math.h>
-template <typename scalar_t>
-static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);
-
-template <>
-__forceinline__ __device__ float device_sqrt(float val) {
- return ::sqrtf(val);
-}
-
-template <>
-__forceinline__ __device__ double device_sqrt(double val) {
- return ::sqrt(val);
-}
-
#else
constexpr int WARP_SIZE = 32;
-
-template<typename scalar_t>
-__forceinline__ __device__ double device_sqrt(scalar_t val) {
- return std::sqrt(val);
-}
#endif
// The maximum number of threads in a block
#pragma once
+#include <assert.h>
#include <ATen/ATen.h>
#include <ATen/cuda/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCGeneral.hpp>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
+#include <functional>
#include <iosfwd>
+#include <tuple>
+#include <type_traits>
+#include <utility>
namespace at { namespace native {
return element_size_bytes * NUM_THREADS;
}
- int global_memory_size() const {
+ int64_t global_memory_size() const {
if (!should_global_reduce()) {
return 0;
}
- int size = element_size_bytes * num_outputs * ctas_per_output;
+ auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
if (!should_warp_reduce()) {
size *= block().x;
}
reduction.run();
}
-static OffsetCalculator<2> make_output_calculator(const TensorIterator& iter) {
+template <typename index_t>
+static OffsetCalculator<2, index_t> make_output_calculator(const TensorIterator& iter) {
int num_reduce_dims = iter.num_reduce_dims();
int num_output_dims = iter.ndim() - num_reduce_dims;
std::array<const int64_t*, 2> strides = {
iter.strides(1).data() + num_reduce_dims,
};
auto shape = iter.shape().data() + num_reduce_dims;
- return OffsetCalculator<2>(num_output_dims, shape, strides.data());
+ return OffsetCalculator<2, index_t>(num_output_dims, shape, strides.data());
}
-static OffsetCalculator<1> make_input_calculator(const TensorIterator& iter) {
+template <typename index_t>
+static OffsetCalculator<1, index_t> make_input_calculator(const TensorIterator& iter) {
int num_reduce_dims = iter.num_reduce_dims();
std::array<const int64_t*, 1> strides = {
iter.strides(1).data(),
};
- return OffsetCalculator<1>(num_reduce_dims, iter.shape().data(), strides.data());
+ return OffsetCalculator<1, index_t>(num_reduce_dims, iter.shape().data(), strides.data());
}
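+// strided_iterate calls f(i, idx) for up to vt indices idx = begin, begin +
+// stride, ...; when all vt indices are in range, the loop is fully unrolled
+// without per-element bounds checks.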
-template <int vt, typename func_t>
-__device__ void strided_iterate(func_t f, int begin, int end, int stride) {
+template <int vt, typename index_t, typename func_t>
+__device__ void strided_iterate(func_t f, index_t begin, index_t end, index_t stride) {
if (begin + (vt - 1) * stride < end) {
#pragma unroll
- for (int i = 0; i < vt; i++) {
+ for (index_t i = 0; i < vt; i++) {
f(i, begin + i * stride);
}
} else {
#pragma unroll
- for (int i = 0; i < vt; i++) {
- int idx = begin + i * stride;
+ for (index_t i = 0; i < vt; i++) {
+ index_t idx = begin + i * stride;
if (idx < end) {
f(i, idx);
}
}
}
-template <int vt, typename type_t, typename foo_t>
-__device__ Array<type_t, vt> load_memory(const type_t* in, int begin, int end, int stride, foo_t foo) {
+template <int vt, typename index_t, typename type_t, typename foo_t>
+__device__ Array<type_t, vt> load_memory(const type_t* in, index_t begin, index_t end, index_t stride, foo_t foo) {
Array<type_t, vt> res;
- strided_iterate<vt>([&](int i, int idx) {
+ strided_iterate<vt>([&](index_t i, index_t idx) {
res[i] = in[foo(idx)];
}, begin, end, stride);
return res;
}
-template <int vt, typename type_t>
-__device__ Array<type_t, vt> load_memory(const type_t* in, int begin, int end, int stride) {
- return load_memory<vt>(in, begin, end, stride, [](int idx) { return idx; });
+template <int vt, typename index_t, typename type_t>
+__device__ Array<type_t, vt> load_memory(const type_t* in, index_t begin, index_t end, index_t stride) {
+ return load_memory<vt, index_t>(in, begin, end, stride, [](index_t idx) { return idx; });
}
-template <typename scalar_t, typename func_t, typename pre_func_t,
- typename post_func_t, typename out_scalar_t=scalar_t>
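+// func_wrapper_t adapts a plain binary functor (such as a sum lambda) to the
+// ops interface expected by ReduceOp: the functor serves as both reduce and
+// combine, project is just a cast to out_scalar_t, and warp_shfl_down shuffles
+// the accumulator directly.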
+template <typename out_scalar_t, typename func_t>
+struct func_wrapper_t {
+ using arg_t = typename binary_function_traits<func_t>::arg2_t;
+ func_t reduce;
+ func_t combine;
+ static inline __device__ out_scalar_t project(arg_t arg) {
+ return (out_scalar_t) arg;
+ }
+ static inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
+ return WARP_SHFL_DOWN(arg, offset);
+ }
+
+ func_wrapper_t(const func_t& op) : reduce(op), combine(op) {
+ }
+};
+
+template <typename scalar_t, typename func_t>
+func_wrapper_t<scalar_t, func_t> func_wrapper(const func_t& op) {
+ using arg_t = typename binary_function_traits<func_t>::arg2_t;
+ return func_wrapper_t<scalar_t, func_t> { op };
+}
+
+template <typename scalar_t, typename ops_t, typename index_t, typename out_scalar_t=scalar_t>
struct ReduceOp {
- using traits = binary_function_traits<func_t>;
- using arg_t = typename traits::arg2_t;
+ using traits = binary_function_traits<decltype(&ops_t::reduce)>;
+ using arg_t = typename std::remove_const<typename std::remove_reference<typename traits::arg1_t>::type>::type;
- using InputCalculator = OffsetCalculator<1>;
- using OutputCalculator = OffsetCalculator<2>;
+ using InputCalculator = OffsetCalculator<1, index_t>;
+ using OutputCalculator = OffsetCalculator<2, index_t>;
static constexpr int vt0 = 4;
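+ // Partial results may only be accumulated into (and re-read from) the output
+ // tensor when the accumulator type converts to the output type; the host-side
+ // gpu_reduce_kernel relies on the same condition when deciding whether a
+ // reduction can be split into 32-bit-indexable sub-iterators.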
+ static constexpr bool can_accumulate_in_output =
+ std::is_convertible<arg_t, out_scalar_t>::value;
+
- func_t op;
- pre_func_t pre_op;
- post_func_t post_op;
+ ops_t ops;
arg_t ident;
ReduceConfig config;
InputCalculator input_calc;
int* semaphores;
bool accumulate;
- ReduceOp(func_t op, ReduceConfig config, InputCalculator input_calc, OutputCalculator output_calc,
- const void* src, void* dst, void* buffer, int* semaphores, pre_func_t pre_op,
- post_func_t post_op)
- : op(op)
- , pre_op(pre_op)
- , post_op(post_op)
+ ReduceOp(ops_t ops, ReduceConfig config, InputCalculator input_calc, OutputCalculator output_calc,
+ const void* src, void* dst, void* buffer, int* semaphores, arg_t ident)
+ : ops(ops)
, config(config)
, input_calc(input_calc)
, output_calc(output_calc)
, src(src)
, dst(dst)
, buffer(buffer)
- , semaphores(semaphores) {
+ , semaphores(semaphores)
+ , ident(ident) {
}
C10_DEVICE void run() const {
- int output_idx = config.output_idx();
- int input_idx = config.input_idx();
+ index_t output_idx = config.output_idx();
+ index_t input_idx = config.input_idx();
auto base_offsets = output_calc.get(output_idx);
arg_t value = ident;
if (config.should_global_reduce()) {
value = global_reduce(value, out);
} else if (config.should_store(output_idx)) {
- value = post_op(value);
if (accumulate) {
- value = op(*out, value);
+ value = accumulate_in_output<can_accumulate_in_output>(out, value);
}
- *out = value;
+ *out = ops.project(value);
}
}
- C10_DEVICE Array<scalar_t, vt0> load_inputs(const scalar_t* data, int offset) const {
- int end = config.num_inputs;
- int stride = input_calc.strides_[0][0] / sizeof(scalar_t);
+ C10_DEVICE Array<scalar_t, vt0> load_inputs(const scalar_t* data, index_t offset) const {
+ index_t end = config.num_inputs;
+ index_t stride = input_calc.strides_[0][0] / sizeof(scalar_t);
if (input_calc.dims == 1) {
- return load_memory<vt0>(data, offset, end, config.step_input, [&](int idx) {
+ return load_memory<vt0, index_t>(data, offset, end, config.step_input, [&](index_t idx) {
return idx * stride;
});
} else {
- return load_memory<vt0>(data, offset, end, config.step_input, [&](int idx) {
+ return load_memory<vt0, index_t>(data, offset, end, config.step_input, [&](index_t idx) {
return input_calc.get(idx)[0] / sizeof(scalar_t);
});
}
}
- C10_DEVICE arg_t thread_reduce_once(const scalar_t* data, int offset) const {
+ C10_DEVICE arg_t thread_reduce_once(const scalar_t* data, index_t offset) const {
auto values = load_inputs(data, offset);
- arg_t value;
- strided_iterate<vt0>([&](int i, int idx) {
- value = i == 0 ? pre_op(values[0]) : op(value, pre_op(values[i]));
+ arg_t value = ident;
+ strided_iterate<vt0, index_t>([&](index_t i, index_t idx) {
+ value = ops.reduce(value, values[i]);
}, offset, config.num_inputs, config.step_input);
return value;
C10_DEVICE arg_t thread_reduce(const scalar_t* data) const {
arg_t value = ident;
- int idx = config.input_idx();
+ index_t idx = config.input_idx();
while (idx < config.num_inputs) {
arg_t next = thread_reduce_once(data, idx);
- value = op(value, next);
+ value = ops.combine(value, next);
idx += config.step_input * vt0;
}
return value;
C10_DEVICE arg_t warp_reduce(arg_t value) const {
for (int offset = 1; offset < warpSize; offset <<= 1) {
- arg_t other = WARP_SHFL_DOWN(value, offset);
- value = op(value, other);
+ arg_t other = ops.warp_shfl_down(value, offset);
+ value = ops.combine(value, other);
}
return value;
}
__syncthreads();
if (threadIdx.y < offset && threadIdx.y + offset < num_warps) {
arg_t other = shared[config.shared_memory_offset(offset)];
- value = op(value, other);
+ value = ops.combine(value, other);
shared[config.shared_memory_offset(0)] = value;
}
}
return is_last_block_done;
}
+
+ template <bool can_acc>
+ C10_DEVICE arg_t accumulate_in_output(
+ out_scalar_t* out, arg_t value,
+ typename std::enable_if<can_acc>::type* = nullptr
+ ) const {
+ return ops.reduce(*out, value);
+ }
+
+ // This function should never be called --
+ // it's the version of `accumulate_in_output`
+ // when accumulation in the output is not possible.
+ template <bool can_acc>
+ C10_DEVICE arg_t accumulate_in_output(
+ out_scalar_t*, arg_t,
+ typename std::enable_if<!can_acc>::type* = nullptr
+ ) const {
+ assert(false); // can't use AT_ASSERT in CUDA device code.
+ return arg_t {};
+ }
C10_DEVICE arg_t global_reduce(arg_t value, out_scalar_t* out) const {
arg_t* reduce_buffer = (arg_t*)buffer;
bool should_store = config.should_store(config.output_idx());
if (should_store) {
- int offset = config.staging_memory_offset(blockIdx.y);
+ index_t offset = config.staging_memory_offset(blockIdx.y);
reduce_buffer[offset] = value;
}
bool is_last_block_done = mark_block_finished();
if (is_last_block_done) {
- value = 0;
+ value = arg_t {};
if (config.should_warp_reduce()) {
- int input_offset = threadIdx.x + threadIdx.y * blockDim.x;
- int step = blockDim.x * blockDim.y;
+ index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
+ index_t step = blockDim.x * blockDim.y;
for (; input_offset < config.ctas_per_output; input_offset += step) {
- int idx = config.staging_memory_offset(input_offset);
+ index_t idx = config.staging_memory_offset(input_offset);
arg_t next = reduce_buffer[idx];
- value = op(value, next);
+ value = ops.combine(value, next);
}
} else {
- int input_offset = threadIdx.y;
- int step = blockDim.y;
+ index_t input_offset = threadIdx.y;
+ index_t step = blockDim.y;
for (; input_offset < config.ctas_per_output; input_offset += step) {
- int idx = config.staging_memory_offset(input_offset);
+ index_t idx = config.staging_memory_offset(input_offset);
arg_t next = reduce_buffer[idx];
- value = op(value, next);
+ value = ops.combine(value, next);
}
}
value = block_reduce(value);
if (config.should_warp_reduce()) {
value = warp_reduce(value);
}
- value = post_op(value);
if (should_store) {
if (accumulate) {
- value = op(*out, value);
+ value = accumulate_in_output<can_accumulate_in_output>(out, value);
}
- *out = value;
+ *out = ops.project(value);
}
}
AT_CUDA_CHECK(cudaGetLastError());
}
-template <typename scalar_t, typename out_scalar_t, typename func_t, typename pre_func_t,
- typename post_func_t, typename ident_t=double>
-inline void gpu_reduce_kernel(TensorIterator& iter, const pre_func_t &pre_op,
- const post_func_t &post_op, const func_t& op,
- ident_t ident=0) {
- ASSERT_HOST_DEVICE_LAMBDA(func_t);
+template <typename scalar_t, typename out_scalar_t, typename ops_t, typename ident_t=double>
+inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0) {
AT_ASSERT(iter.numel() > 0 && iter.ntensors() == 2);
- if (!iter.can_use_32bit_indexing()) {
+ using traits = binary_function_traits<decltype(&ops_t::reduce)>;
+ using arg_t = typename traits::arg1_t;
+ static constexpr bool can_accumulate_in_output =
+ std::is_convertible<arg_t, out_scalar_t>::value;
+
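+ // When the iterator is too large for 32-bit indexing, prefer splitting it
+ // into 32-bit-indexable sub-iterators (which requires accumulating partial
+ // results in the output); otherwise fall back to a single 64-bit-indexed pass
+ // below.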
+ bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
+ if (can_accumulate_in_output && !can_use_32bit_indexing) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
- gpu_reduce_kernel<scalar_t, out_scalar_t>(sub_iter, pre_op, post_op, op);
+ gpu_reduce_kernel<scalar_t, out_scalar_t>(sub_iter, ops, ident);
}
return;
}
char* out_data = (char*)iter.data_ptr(0);
const char* in_data = (char*)iter.data_ptr(1);
- using traits = binary_function_traits<func_t>;
- using arg_t = typename traits::arg2_t;
int warp_size = at::cuda::warp_size();
int warps_per_cta = ReduceConfig::NUM_THREADS / warp_size;
config.input_mult[2] = config.split_input(config.ctas_per_output);
}
- auto output_calc = make_output_calculator(iter);
- auto input_calc = make_input_calculator(iter);
-
at::DataPtr buffer;
at::DataPtr semaphores;
if (config.should_global_reduce()) {
auto stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
}
- auto reduce = ReduceOp<scalar_t, func_t, pre_func_t, post_func_t, out_scalar_t>(
- op,
- config,
- input_calc,
- output_calc,
- in_data,
- out_data,
- buffer.get(),
- (int*)semaphores.get(),
- pre_op,
- post_op);
- reduce.ident = ident;
- reduce.accumulate = iter.should_accumulate();
-
- launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+
+ if (can_use_32bit_indexing) {
+ auto output_calc = make_output_calculator<uint32_t>(iter);
+ auto input_calc = make_input_calculator<uint32_t>(iter);
+ auto reduce = ReduceOp<scalar_t, ops_t, uint32_t, out_scalar_t>(
+ ops,
+ config,
+ input_calc,
+ output_calc,
+ in_data,
+ out_data,
+ buffer.get(),
+ (int*)semaphores.get(),
+ ident);
+ reduce.accumulate = iter.should_accumulate();
+
+ launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+ } else {
+ auto output_calc = make_output_calculator<uint64_t>(iter);
+ auto input_calc = make_input_calculator<uint64_t>(iter);
+ auto reduce = ReduceOp<scalar_t, ops_t, uint64_t, out_scalar_t>(
+ ops,
+ config,
+ input_calc,
+ output_calc,
+ in_data,
+ out_data,
+ buffer.get(),
+ (int*)semaphores.get(),
+ ident);
+ AT_ASSERT(!iter.should_accumulate());
+ reduce.accumulate = false;
+
+ launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+ }
}
}} // namespace at::native
+#include <ATen/native/SharedReduceOps.h>
#include <ATen/AccumulateType.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
+#include <ATen/native/cuda/DeviceSqrt.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Reduce.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/ReduceOps.h>
#include <limits>
+#include <tuple>
namespace at { namespace native {
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
void sum_kernel_impl(TensorIterator& iter) {
- gpu_reduce_kernel<scalar_t, out_t>(iter, SimpleCopy<acc_t>(), SimpleCopy<acc_t>(),
- []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
+ gpu_reduce_kernel<scalar_t, out_t>(iter, func_wrapper<out_t> ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
- });
+ }));
+}
+
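+// For at::Half inputs, accumulate the Welford state in float to avoid the
+// precision loss and early overflow of half-precision accumulation.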
+template <typename scalar_t>
+void std_kernel_impl(TensorIterator& iter, bool unbiased) {
+ gpu_reduce_kernel<scalar_t, scalar_t>(iter, WelfordOps<scalar_t, scalar_t> { unbiased }, WelfordData<scalar_t> {});
+}
+
+template <>
+void std_kernel_impl<at::Half>(TensorIterator& iter, bool unbiased) {
+ gpu_reduce_kernel<at::Half, at::Half>(iter, WelfordOps<at::Half, float> { unbiased }, WelfordData<float> {});
}
#ifdef __HIPCC__
// compiler segfaults:
// https://bugs.llvm.org/show_bug.cgi?id=39602
// To work around it, use int32 as the accumulate type.
- gpu_reduce_kernel<int16_t, int16_t>(iter, SimpleCopy<int32_t>(), SimpleCopy<int32_t>(),
- []GPU_LAMBDA(int32_t a, int32_t b) -> int32_t {
+ gpu_reduce_kernel<int16_t, int16_t>(iter, func_wrapper<int16_t> ([]GPU_LAMBDA(int32_t a, int32_t b) -> int32_t {
return a + b;
- });
+ }));
}
#endif
template <typename scalar_t, typename acc_t=scalar_t>
void prod_kernel_impl(TensorIterator& iter) {
- gpu_reduce_kernel<scalar_t, scalar_t>(iter, SimpleCopy<acc_t>(), SimpleCopy<acc_t>(),
- []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
+ gpu_reduce_kernel<scalar_t, scalar_t>(iter, func_wrapper<scalar_t> ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a * b;
- }, 1);
+ }), 1);
+}
+
+static void std_kernel_cuda(TensorIterator& iter, bool unbiased) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "std", [&]() {
+ std_kernel_impl<scalar_t>(iter, unbiased);
+ });
}
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
void mean_kernel_impl(TensorIterator& iter) {
float factor = float(iter.num_output_elements()) / iter.numel();
- gpu_reduce_kernel<scalar_t, out_t>(iter, SimpleCopy<acc_t>(),
- [factor]GPU_LAMBDA(acc_t a) -> acc_t { return a*factor; },
- []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { return a + b; });
+ gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<acc_t, float> {factor});
}
#ifdef __HIPCC__
// https://bugs.llvm.org/show_bug.cgi?id=39602
// To work around it, use int32 as the accumulate type.
float factor = float(iter.num_output_elements()) / iter.numel();
- gpu_reduce_kernel<int16_t, int16_t>(iter, SimpleCopy<int32_t>(),
- [factor]GPU_LAMBDA(int32_t a) -> int32_t { return a*factor; },
- []GPU_LAMBDA(int32_t a, int32_t b) -> int32_t { return a + b; });
+ gpu_reduce_kernel<int16_t, int16_t>(iter, MeanOps<int32_t, float> {factor});
}
#endif // __HIPCC__
});
}
+REGISTER_DISPATCH(std_stub, &std_kernel_cuda);
REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda);
REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda);
REGISTER_DISPATCH(mean_stub, &mean_kernel_cuda);
expected = numpy_op(tensor.numpy(), dim)
actual = pytorch_op(tensor, dim)
self._assert_matches_numpy(actual, expected)
+ if torch.cuda.is_available():
+     self._assert_matches_numpy(pytorch_op(tensor.cuda(), dim).cpu(), expected)
do_one(self._make_tensors((5, 400000), use_floating=use_floating,
use_integral=use_integral), 1)
do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
Args:
input (Tensor): the input tensor
- dim (int): the dimension to reduce
+ dim (int or tuple of ints): the dimension or dimensions to reduce
keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not
out (Tensor): the output tensor
.. function:: std(input, dim, keepdim=False, unbiased=True, out=None) -> Tensor
Returns the standard-deviation of each row of the :attr:`input` tensor in the
-given dimension :attr:`dim`.
+dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
+reduce over all of them.
-If :attr:`keepdim` is ``True``, the output tensor is of the same size as
-:attr:`input` except in the dimension :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting
-in the output tensor having 1 fewer dimension than :attr:`input`.
+If :attr:`keepdim` is ``True``, the output tensor is of the same size
+as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
+Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
+output tensor having 1 (or ``len(dim)``) fewer dimension(s).
If :attr:`unbiased` is ``False``, then the standard-deviation will be calculated
via the biased estimator. Otherwise, Bessel's correction will be used.
Args:
input (Tensor): the input tensor
- dim (int): the dimension to reduce
+ dim (int or tuple of ints): the dimension or dimensions to reduce
keepdim (bool): whether the output tensor has :attr:`dim` retained or not
unbiased (bool): whether to use the unbiased estimation or not
out (Tensor, optional): the output tensor