template<typename RNG>
void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(generator->mutex_);
     using self_t = scalar_t;
     ...
     } else {
-      AT_DISPATCH_FLOATING_TYPES(p_.scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, p_.scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
         using p_t = scalar_t;
         cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
           at::bernoulli_distribution<float> bernoulli(p_val);
           return static_cast<self_t>(bernoulli(generator));
         });
       });
     }
   });
}
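With BFloat16 added to both dispatch lists, the tensor-probability overload accepts BFloat16 for the output tensor and for `p_` on CPU. A minimal sketch of what this enables, assuming a build that includes this patch (previously the call failed with a "bernoulli_tensor_cpu_self_" not implemented for 'BFloat16' dispatch error):

```python
import torch

# Tensor-probability overload: both the output and the probabilities are BFloat16.
p = torch.full((10000,), 0.25, dtype=torch.bfloat16)
out = torch.empty(10000, dtype=torch.bfloat16).bernoulli_(p)

print(((out == 0) | (out == 1)).all())  # draws are exactly 0 or 1
print(out.float().mean())               # approximately 0.25
```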
template<typename RNG>
void bernoulli_kernel(Tensor& self, double p, RNG generator) {
- AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(generator->mutex_);
auto iter = TensorIterator::borrowing_nullary_op(self);
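The scalar-probability overload gets the same dispatch widening. From Python, the effect is simply that in-place `bernoulli_(p)` now works on CPU BFloat16 tensors; a small sketch, assuming a patched build:

```python
import torch

# Scalar-probability overload on a CPU BFloat16 tensor.
t = torch.empty(8, dtype=torch.bfloat16).bernoulli_(0.3)
print(t.dtype)  # torch.bfloat16
print(t)        # a mix of 0s and 1s
```

The next hunk updates the separate vectorized CPU fast path for the same kernel, which stages generation through an int tensor.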
int64_t n = self.numel();
bool contig = self.is_contiguous();
- AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
at::Tensor tmp_int_tensor;
if (std::is_same<scalar_t, int>::value && contig) {
tmp_int_tensor = self;
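Because this fast path draws into an int buffer and converts to the destination dtype afterwards (reusing `self` directly only when it is already a contiguous int tensor, as the hunk shows), widening the dispatch list is enough to cover new output dtypes. A sketch of the resulting coverage, assuming a build with this patch:

```python
import torch

# The ALL_TYPES_AND2(Bool, BFloat16) dispatch means the scalar overload now
# accepts every one of these output dtypes on CPU:
for dtype in (torch.bool, torch.int32, torch.int64, torch.float32, torch.bfloat16):
    t = torch.empty(5, dtype=dtype).bernoulli_(0.5)
    print(dtype, t.tolist())
```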
self._test_dropout_stride_mean_preserve(nn.Dropout, device)
- if self.device_type == 'cuda':
+ if self.device_type == 'cuda' or self.device_type == 'cpu':
    input = input.bfloat16()
    self._test_dropout(nn.Dropout, device, input)
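On the test side, the dropout test previously exercised BFloat16 inputs only on CUDA; the widened condition runs it on CPU as well. The core of what it now checks, as a standalone sketch against a patched build:

```python
import torch
import torch.nn as nn

# Dropout on a CPU BFloat16 input, as the widened test condition now exercises.
x = torch.randn(1000).bfloat16()
out = nn.Dropout(p=0.5)(x)

print(out.dtype)                  # torch.bfloat16
print((out == 0).float().mean())  # roughly p = 0.5
```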
@dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False)))
+ @dtypesIfCPU(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=True)))
@dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
def test_bernoulli_p(self, device, dtype):
for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]):
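`test_bernoulli_p` keeps half and bfloat16 out of its base dtype list; the added `@dtypesIfCPU` override opts CPU back into bfloat16, while the CUDA list still excludes it. The `trivial_p` loop reduces to a determinism check that now also runs for CPU BFloat16; a sketch:

```python
import torch

# With probabilities of exactly 0 or 1, bernoulli is deterministic, so the
# draw must reproduce the probability vector itself.
p = torch.tensor([1, 0, 1, 1, 0, 1], dtype=torch.bfloat16)
print(p.bernoulli().tolist())  # [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
```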