DEFINE_DISPATCH(std_stub);
DEFINE_DISPATCH(prod_stub);
DEFINE_DISPATCH(norm_kernel);
+DEFINE_DISPATCH(mean_stub);
static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dtype) {
ScalarType scalarType = self.type().scalarType();
// ALL REDUCE #################################################################
-static inline Tensor mean(const Tensor &self, optional<ScalarType> dtype) {
- ScalarType scalarType = self.type().scalarType();
- AT_CHECK(
- at::isFloatingType(scalarType),
- "Can only calculate the mean of floating types. Got ",
- toString(scalarType),
- " instead.");
- if (self.numel() > 0) {
- Tensor result = at::native::sum(self);
- return result.div_(self.numel());
- } else {
- return at::scalar_tensor(std::numeric_limits<double>::quiet_NaN(), self.options());
- }
-}
-
-Tensor mean(const Tensor &self, ScalarType dtype) {
- return at::native::mean(self, optional<ScalarType>(dtype));
-}
-
-Tensor mean(const Tensor &self) {
- return at::native::mean(self, c10::nullopt);
-}
-
static ScalarType get_dtype(Tensor& result, const Tensor& self, optional<ScalarType> dtype,
bool promote_integers=false) {
if (dtype.has_value()) {
  return dtype.value();
}
-// \ALL REDUCE ################################################################
-
-// DIM REDUCE #################################################################
-
static inline Tensor &mean_out(Tensor &result, const Tensor &self, IntList dim,
- bool keepdim, optional<ScalarType> dtype) {
- ScalarType scalarType = result.type().scalarType();
+ bool keepdim, optional<ScalarType> opt_dtype) {
+ ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.type().scalarType();
AT_CHECK(
at::isFloatingType(scalarType),
"Can only calculate the mean of floating types. Got ",
toString(scalarType),
" instead.");
- at::native::sum_out(
- result, self.toType(result.type().scalarType()), dim, keepdim);
- if (result.numel() > 0 && self.ndimension() > 0) {
- int64_t numel = n_dim_size(self, dim);
- if (numel > 0) {
- result.div_(numel);
- } else {
- // NumPy equivalent
- result.fill_(std::numeric_limits<double>::quiet_NaN());
- }
+ ScalarType dtype = get_dtype(result, self, opt_dtype, true);
+ auto iter = make_reduction("mean", result, self, dim, keepdim, dtype);
+ if (iter->numel() == 0) {
+ result.fill_(std::numeric_limits<double>::quiet_NaN());
+ } else {
+ mean_stub(iter->device_type(), *iter);
}
return result;
}
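Note the `iter->numel() == 0` branch: it keeps NumPy's convention that the mean over zero elements is NaN, while `sum` keeps returning its identity. A minimal sketch of the observable behavior (hypothetical usage, assuming zero-size tensors are enabled in this build):

```cpp
// Not part of this diff; illustrates the empty-reduction convention.
Tensor t = at::empty({0, 4}, at::kFloat);
Tensor m = t.mean(0);  // shape {4}, filled with NaN (np.mean convention)
Tensor s = t.sum(0);   // shape {4}, filled with 0 (the sum identity)
```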
+// \ALL REDUCE ################################################################
+
+// DIM REDUCE #################################################################
+
Tensor& mean_out(Tensor& result, const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
  return at::native::mean_out(
      result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
}

Tensor& mean_out(Tensor& result, const Tensor& self, IntList dim, ScalarType dtype) {
  return at::native::mean_out(result, self, dim, false, dtype);
}
+static inline Tensor mean(const Tensor &self, IntList dim, bool keepdim, optional<ScalarType> dtype) {
+ Tensor result;
+ return at::native::mean_out(result, self, dim, keepdim, dtype);
+}
+
+static inline Tensor mean(const Tensor &self, optional<ScalarType> dtype) {
+ return at::native::mean(self, {}, false, dtype);
+}
+
+Tensor mean(const Tensor &self, ScalarType dtype) {
+ return at::native::mean(self, optional<ScalarType>(dtype));
+}
+
+Tensor mean(const Tensor &self) {
+ return at::native::mean(self, c10::nullopt);
+}
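Every public entry point now funnels into the single `mean_out` implementation above; a full reduction is just a dim-reduction over an empty dim list:

```cpp
// Delegation chain (all names are from this diff):
//   mean(self)                          -> mean(self, c10::nullopt)
//   mean(self, dtype)                   -> mean(self, optional<ScalarType>(dtype))
//   mean(self, opt_dtype)               -> mean(self, /*dim=*/{}, /*keepdim=*/false, opt_dtype)
//   mean(self, dim, keepdim, opt_dtype) -> mean_out(result, self, dim, keepdim, opt_dtype)
```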
+
Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
  return at::native::sum_out(
      result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
}

Tensor& prod_out(Tensor& result, const Tensor& self, IntList dim, ScalarType dtype) {
  return at::native::prod_out(result, self, dim, false, dtype);
}
-static inline Tensor mean(const Tensor &self, IntList dim, bool keepdim, optional<ScalarType> dtype) {
- ScalarType scalarType = self.type().scalarType();
- AT_CHECK(
- at::isFloatingType(scalarType),
- "Can only calculate the mean of floating types. Got ",
- toString(scalarType),
- " instead.");
- Tensor result = at::native::sum(self, dim, keepdim);
- if (result.numel() > 0 && self.ndimension() > 0) {
- int64_t numel = n_dim_size(self, dim);
- if (numel > 0) {
- result.div_(numel);
- } else {
- // NumPy equivalent
- result.fill_(std::numeric_limits<double>::quiet_NaN());
- }
- }
- return result;
-}
-
Tensor mean(const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
return at::native::mean(self, dim, keepdim, c10::optional<ScalarType>(dtype));
}
DECLARE_DISPATCH(reduce_fn, sum_stub);
DECLARE_DISPATCH(reduce_fn, prod_stub);
+DECLARE_DISPATCH(reduce_fn, mean_stub);
using reduce_std_function =
void (*)(TensorIterator&, bool unbiased);
if (op.tensor.defined() && op.tensor.type() != *op.type) {
if (op.is_output) {
AT_ERROR("output with type ", op.tensor.type().toString(),
- " doesn't match the desired type ", type().toString());
+ " doesn't match the desired type ", op.type->toString());
} else if (op.tensor.dim() == 0) {
op.tensor = op.tensor.to(*op.type);
} else {
- AT_ERROR("expected type ", type().toString(), " but got ",
+ AT_ERROR("expected type ", op.type->toString(), " but got ",
op.tensor.type().toString());
}
}
template <typename rf_t,
typename cf_t,
- typename pf_t>
-void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &combine, pf_t const &project) {
+ typename pf_t,
+ typename init_t>
+void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &combine, pf_t const &project, init_t init) {
using r_traits = binary_function_traits<rf_t>;
using c_traits = binary_function_traits<cf_t>;
using p_traits = unary_function_traits<pf_t>;
static_assert(
all_same<
acc_t,
+ init_t,
typename r_traits::arg1_t,
typename r_traits::result_type,
typename c_traits::arg1_t,
bool serial = numel < at::internal::GRAIN_SIZE || at::get_max_threads() == 1 || at::in_parallel_region();
int max_threads = serial ? 1 : at::get_max_threads();
AT_ASSERT(max_threads > 0);
- std::vector<optional<acc_t>> buffer{(unsigned)max_threads, optional<acc_t> {}};
+ std::vector<optional<acc_t>> buffer((unsigned)max_threads, optional<acc_t> {});
at::parallel_for(0, numel, serial ? (1 + numel) : internal::GRAIN_SIZE,
[&](int64_t begin, int64_t end) {
auto &acc = buffer[at::get_thread_num()];
- sub_iter.serial_for_each([&acc, &reduce](int ntensors, char** data, const int64_t* strides, int64_t size) {
+ sub_iter.serial_for_each([&acc, &reduce, &init](int ntensors, char** data, const int64_t* strides, int64_t size) {
AT_ASSERT(ntensors == 2);
char *in = data[1];
int64_t stride = strides[1];
if (!acc && size > 0) {
- acc = acc_t {};
+ acc = init;
}
for (int64_t i = 0; i < size; ++i) {
acc = reduce(*acc, *(data_t*)in);
}
}, {begin, end});
});
- acc_t acc;
+ acc_t acc = init;
for (int i = 0; i < max_threads; ++i) {
if (buffer[i]) {
acc = combine(acc, *buffer[i]);
});
}
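Two details in this hunk are worth spelling out. First, the brace-to-parenthesis change on the buffer is a real fix: with braces, `std::vector` prefers its `initializer_list` constructor whenever the arguments are convertible to the element type, so the old line could allocate a 2-element buffer instead of one slot per thread. A standalone illustration:

```cpp
#include <cassert>
#include <vector>

int main() {
  std::vector<int> a{5, 0};  // initializer_list ctor: two elements, 5 and 0
  std::vector<int> b(5, 0);  // (count, value) ctor: five zeros
  assert(a.size() == 2);
  assert(b.size() == 5);
}
```

Second, the new `init` argument lets callers whose identity is not `acc_t{}` seed the per-thread accumulators, e.g. a product's 1 or the zeroed `WelfordData` below. A hypothetical product reduction in the new calling convention (assuming `acc_t` is `float`):

```cpp
binary_kernel_reduce(
    iter,
    [](float acc, float x) -> float { return acc * x; },  // reduce: fold in one input
    [](float a, float b) -> float { return a * b; },      // combine: merge thread partials
    [](float a) -> float { return a; },                   // project: finalize the result
    /*init=*/1.0f);                                       // multiplicative identity
```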
+static void mean_kernel_impl(TensorIterator& iter) {
+ AT_DISPATCH_ALL_TYPES(iter.type(), "mean", [&] {
+ scalar_t factor = scalar_t(iter.num_output_elements()) / iter.numel();
+ binary_kernel_reduce(
+ iter,
+ [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
+ [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
+ [factor](scalar_t a) -> scalar_t { return a*factor; }, scalar_t(0));
+ });
+}
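Multiplying by `factor` in the projection is equivalent to dividing each output by the number of inputs it reduces, since that count is exactly `numel() / num_output_elements()`:

```cpp
// Worked example: a 4x3 float tensor reduced over dim 0.
//   iter.numel()               == 12             (total inputs)
//   iter.num_output_elements() == 3              (one per column)
//   factor                     == 3.0 / 12 == 0.25
// so project(column_sum) == column_sum * 0.25 == column_sum / 4,
// dividing each output by its 4 contributing rows.
```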
+
struct WelfordData {
double mean;
double m2;
[unbiased](WelfordData acc) -> scalar_t {
int64_t divisor = unbiased ? (acc.n - 1) : acc.n;
return (divisor > 0) ? std::sqrt(acc.m2 / divisor) : NAN;
- }
+ },
+ WelfordData()
);
});
}
REGISTER_DISPATCH(std_stub, &std_kernel_impl);
REGISTER_DISPATCH(prod_stub, &prod_kernel_impl);
REGISTER_DISPATCH(norm_kernel, &norm_kernel_impl);
+REGISTER_DISPATCH(mean_stub, &mean_kernel_impl);
}} // namespace at::native
-
return load_memory<vt>(in, begin, end, stride, [](int idx) { return idx; });
}
-template <typename scalar_t, typename func_t, typename out_scalar_t=scalar_t>
+template <typename scalar_t, typename func_t, typename pre_func_t,
+ typename post_func_t, typename out_scalar_t=scalar_t>
struct ReduceOp {
using traits = binary_function_traits<func_t>;
using arg_t = typename traits::arg2_t;
static constexpr int vt0 = 4;
func_t op;
+ pre_func_t pre_op;
+ post_func_t post_op;
arg_t ident;
ReduceConfig config;
InputCalculator input_calc;
bool accumulate;
ReduceOp(func_t op, ReduceConfig config, InputCalculator input_calc, OutputCalculator output_calc,
- const void* src, void* dst, void* buffer, int* semaphores)
+ const void* src, void* dst, void* buffer, int* semaphores, pre_func_t pre_op,
+ post_func_t post_op)
: op(op)
+ , pre_op(pre_op)
+ , post_op(post_op)
, config(config)
, input_calc(input_calc)
, output_calc(output_calc)
if (config.should_global_reduce()) {
value = global_reduce(value, out);
} else if (config.should_store(output_idx)) {
+ value = post_op(value);
if (accumulate) {
value = op(*out, value);
}
arg_t value;
strided_iterate<vt0>([&](int i, int idx) {
- value = i == 0 ? (arg_t)values[0] : op(value, values[i]);
+ value = i == 0 ? pre_op(values[0]) : op(value, pre_op(values[i]));
}, offset, config.num_inputs, config.step_input);
return value;
if (config.should_warp_reduce()) {
value = warp_reduce(value);
}
+ value = post_op(value);
if (should_store) {
if (accumulate) {
value = op(*out, value);
AT_CUDA_CHECK(cudaGetLastError());
}
-template <typename scalar_t, typename out_scalar_t, typename func_t, typename ident_t=double>
-inline void gpu_reduce_kernel(TensorIterator& iter, const func_t& op, ident_t ident=0) {
+template <typename scalar_t, typename out_scalar_t, typename func_t, typename pre_func_t,
+ typename post_func_t, typename ident_t=double>
+inline void gpu_reduce_kernel(TensorIterator& iter, const pre_func_t &pre_op,
+ const post_func_t &post_op, const func_t& op,
+ ident_t ident=0) {
ASSERT_HOST_DEVICE_LAMBDA(func_t);
AT_ASSERT(iter.numel() > 0 && iter.ntensors() == 2);
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
- gpu_reduce_kernel<scalar_t, out_scalar_t>(sub_iter, op);
+ gpu_reduce_kernel<scalar_t, out_scalar_t>(sub_iter, pre_op, post_op, op);
}
return;
}
auto stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
}
- auto reduce = ReduceOp<scalar_t, func_t, out_scalar_t>(
+ auto reduce = ReduceOp<scalar_t, func_t, pre_func_t, post_func_t, out_scalar_t>(
op,
config,
input_calc,
in_data,
out_data,
buffer.get(),
- (int*)semaphores.get());
+ (int*)semaphores.get(),
+ pre_op,
+ post_op);
reduce.ident = ident;
reduce.accumulate = iter.should_accumulate();
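With `pre_op` and `post_op` the kernel becomes a transform-reduce: conceptually it computes `post_op(pre_op(x_1) op ... op pre_op(x_n))`. Mean uses an identity `pre_op` and a scaling `post_op`; a 2-norm could be expressed in the same shape (hypothetical caller, not part of this diff):

```cpp
// sqrtf is the CUDA single-precision device intrinsic.
gpu_reduce_kernel<float, float>(
    iter,
    []GPU_LAMBDA(float x) -> float { return x * x; },           // pre_op: square each input
    []GPU_LAMBDA(float s) -> float { return sqrtf(s); },        // post_op: root of the sum
    []GPU_LAMBDA(float a, float b) -> float { return a + b; }); // op: plain sum
```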
namespace at { namespace native {
+namespace {
+
+template <typename scalar_t>
+struct SimpleCopy {
+ __device__ __forceinline__ scalar_t operator() (const scalar_t a) const {
+ return a;
+ }
+};
+
+} // namespace
+
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
void sum_kernel_impl(TensorIterator& iter) {
- gpu_reduce_kernel<scalar_t, out_t>(iter, []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
+ gpu_reduce_kernel<scalar_t, out_t>(iter, SimpleCopy<acc_t>(), SimpleCopy<acc_t>(),
+ []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
});
}
// compiler segfaults:
// https://bugs.llvm.org/show_bug.cgi?id=39602
// To work around it, use int32 as the accumulate type.
- gpu_reduce_kernel<int16_t, int16_t>(iter, []GPU_LAMBDA(int32_t a, int32_t b) -> int32_t {
+ gpu_reduce_kernel<int16_t, int16_t>(iter, SimpleCopy<int32_t>(), SimpleCopy<int32_t>(),
+ []GPU_LAMBDA(int32_t a, int32_t b) -> int32_t {
return a + b;
});
}
template <typename scalar_t, typename acc_t=scalar_t>
void prod_kernel_impl(TensorIterator& iter) {
- gpu_reduce_kernel<scalar_t, scalar_t>(iter, []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
+ gpu_reduce_kernel<scalar_t, scalar_t>(iter, SimpleCopy<acc_t>(), SimpleCopy<acc_t>(),
+ []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a * b;
}, 1);
}
+template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
+void mean_kernel_impl(TensorIterator& iter) {
+ float factor = float(iter.num_output_elements()) / iter.numel();
+ gpu_reduce_kernel<scalar_t, out_t>(iter, SimpleCopy<acc_t>(),
+ [factor]GPU_LAMBDA(acc_t a) -> acc_t { return a*factor; },
+ []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { return a + b; });
+}
+
+#ifdef __HIPCC__
+template <>
+void mean_kernel_impl<int16_t, int16_t, int16_t>(TensorIterator& iter) {
+ // There is a Register Coalescing bug in LLVM causing the hcc
+ // compiler segfaults:
+ // https://bugs.llvm.org/show_bug.cgi?id=39602
+ // To work around it, use int32 as the accumulate type.
+ float factor = float(iter.num_output_elements()) / iter.numel();
+ gpu_reduce_kernel<int16_t, int16_t>(iter, SimpleCopy<int32_t>(),
+ [factor]GPU_LAMBDA(int32_t a) -> int32_t { return a*factor; },
+ []GPU_LAMBDA(int32_t a, int32_t b) -> int32_t { return a + b; });
+}
+#endif // __HIPCC__
+
static void sum_kernel_cuda(TensorIterator& iter) {
if (iter.type().scalarType() == kHalf) {
return sum_kernel_impl<at::Half, float>(iter);
});
}
+static void mean_kernel_cuda(TensorIterator& iter) {
+ if (iter.type().scalarType() == kHalf) {
+ return mean_kernel_impl<at::Half, float>(iter);
+ } else if (iter.type(1).scalarType() == kHalf && iter.type().scalarType() == kFloat) {
+ // type promotion that does cast and reduction in a single kernel
+ return mean_kernel_impl<at::Half, float, float>(iter);
+ }
+ AT_DISPATCH_ALL_TYPES(iter.type(), "mean", [&]() {
+ mean_kernel_impl<scalar_t>(iter);
+ });
+}
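Accumulating `kHalf` inputs in `float` mirrors `sum_kernel_cuda` above and is what the fp16 tests below exercise:

```cpp
// Why acc_t = float for half inputs: binary16 carries an 11-bit significand,
// so above 2048 consecutive integers are no longer representable and a
// running half-precision sum of ones stalls (x + 1 == x). A float
// accumulator keeps integer sums exact up to 2^24 before the final
// mean divide or cast back to half.
```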
+
REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda);
REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda);
+REGISTER_DISPATCH(mean_stub, &mean_kernel_cuda);
}} // namespace at::native
self.assertEqual(x.sum(), 65504)
self.assertEqual(x.sum(dtype=torch.float32), 65504)
+ x = torch.ones(65536, device='cuda', dtype=torch.float16)
+ self.assertEqual(x.sum(dtype=torch.float32), 65536)
+
a = torch.zeros(1203611).bernoulli_(0.0005)
x = a.to(device='cuda', dtype=torch.float16)
self.assertEqual(x.sum().item(), a.sum().item())
x = a.to(device='cuda', dtype=torch.float16)
self.assertEqual(x.sum((0, 2)).float().cpu(), a.sum((0, 2)))
+ @skipIfRocm
+ def test_mean_fp16(self):
+ x = torch.ones(65536, device='cuda', dtype=torch.float16)
+ self.assertEqual(x.mean(), 1)
+
+ x = torch.ones(65536, device='cuda', dtype=torch.float16)
+ self.assertEqual(x.mean(dtype=torch.float32), 1)
+
@staticmethod
def _select_broadcastable_dims(dims_full=None):
return _TestTorchMixin._select_broadcastable_dims(dims_full)
self: grad.clone().masked_fill_(self <= other, 0)
other: grad.clone().masked_fill_(self > other, 0)
+- name: mean(Tensor self)
+ self: grad.expand(self.sizes()) / self.numel()
+
+- name: mean(Tensor self, ScalarType dtype)
+ self: grad.expand(self.sizes()).to(self.type().scalarType()) / self.numel()
+
- name: mean(Tensor self, IntList dim, bool keepdim)
self: sum_backward(grad, self.sizes(), dim, keepdim) / _safe_size(self.sizes(), dim)
-- name: mean(Tensor self)
- self: grad.expand(self.sizes()) / self.numel()
+- name: mean(Tensor self, IntList dim, ScalarType dtype)
+ self: sum_backward(grad, self.sizes(), dim, false).to(self.type().scalarType()) / _safe_size(self.sizes(), dim)
+
+- name: mean(Tensor self, IntList dim, bool keepdim, ScalarType dtype)
+ self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.type().scalarType()) / _safe_size(self.sizes(), dim)
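These entries encode the standard derivative of the mean: with N elements reduced into each output, every input receives `grad / N`. For the full reduction:

```latex
y = \frac{1}{N}\sum_{i=1}^{N} x_i
\quad\Longrightarrow\quad
\frac{\partial y}{\partial x_i} = \frac{1}{N}
\quad\Longrightarrow\quad
\frac{\partial L}{\partial x_i} = \frac{1}{N}\,\frac{\partial L}{\partial y}
```

For the `dim` variants, `_safe_size(self.sizes(), dim)` is that per-output element count, and `sum_backward` performs the broadcast that `grad.expand` does in the full-reduce case.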
- name: median(Tensor self)
self: select_equals_backward(grad, self, result)