//
// 3. sequential RAM + transpose: create an n X 2 Tensor, fill the Tensor
// sequentially, and then transpose it.
- AT_DISPATCH_ALL_TYPES(result.scalar_type(), "tril_indices", [&]() -> void {
+ AT_DISPATCH_ALL_TYPES_AND(kBFloat16, result.scalar_type(), "tril_indices", [&]() -> void {
// fill the Tensor with correct values
scalar_t* result_data = result.data_ptr<scalar_t>();
int64_t i = 0;
// create an empty Tensor with correct size
auto result = at::native::empty_cpu({2, triu_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
- AT_DISPATCH_ALL_TYPES(result.scalar_type(), "triu_indices", [&]() -> void {
+ AT_DISPATCH_ALL_TYPES_AND(kBFloat16, result.scalar_type(), "triu_indices", [&]() -> void {
// fill the Tensor with correct values
scalar_t* result_data = result.data_ptr<scalar_t>();
int64_t i = 0;
Tensor self_c;
std::tie(inplace, self_c) = checkTrilTriuBatchContiguous(self, true);
Tensor result = inplace ? self : at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
apply_triu_tril<scalar_t, false>(result, self_c, inplace, k);
});
if (!inplace) self.copy_(result);
}
Tensor self_c;
std::tie(std::ignore, self_c) = checkTrilTriuBatchContiguous(self, false);
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
apply_triu_tril<scalar_t, false>(result, self_c, false, k);
});
return result;
Tensor self_c;
std::tie(inplace, self_c) = checkTrilTriuBatchContiguous(self, true);
Tensor result = inplace ? self : at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
apply_triu_tril<scalar_t, true>(result, self_c, inplace, k);
});
if (!inplace) self.copy_(result);
}
Tensor self_c;
std::tie(std::ignore, self_c) = checkTrilTriuBatchContiguous(self, false);
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
apply_triu_tril<scalar_t, true>(result, self_c, false, k);
});
return result;
}
static void cross_kernel_impl(Tensor& result, const Tensor& a, const Tensor& b, const int64_t dim) {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX(result.scalar_type(), "cross", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, result.scalar_type(), "cross", [&]() {
apply_cross<scalar_t>(result, a, b, dim);
});
}
for s, d, dtype in product(shapes, diagonals, dtypes):
run_test(s, device, d, dtype)
+ @onlyCPU
+ def test_triu_tril_bfloat16(self, device):
+ # Verify torch.tril/torch.triu support bfloat16 on CPU, forward and backward,
+ # by comparing against a float32 reference computed from the same values.
+ op_funcs = [torch.tril, torch.triu]
+ for op_fun in op_funcs:
+ # bfloat16 input under autograd, plus a detached float32 twin as reference.
+ input = torch.randn(3, 3, dtype=torch.float32, device=device).bfloat16().requires_grad_(True)
+ input2 = input.detach().clone().float().requires_grad_(True)
+ out = op_fun(input)
+ out.sum().backward()
+ out2 = op_fun(input2)
+ out2.sum().backward()
+ # The op must not silently promote: output and gradient stay bfloat16.
+ self.assertEqual(out.dtype, torch.bfloat16)
+ self.assertEqual(input.grad.dtype, torch.bfloat16)
+ # Values must match the float32 reference after downcasting; the gradient
+ # comparison uses a loose absolute tolerance for bfloat16 rounding.
+ self.assertEqual(out, out2.bfloat16())
+ self.assertEqual(input.grad, input2.grad.bfloat16(), atol=0.01, rtol=0)
+
def test_diagflat(self, device):
dtype = torch.float32
# Basic sanity test
self.assertEqual(b.triu(2), output)
self.assertRaises(RuntimeError, lambda: b.triu_(2))
+ @onlyCPU
+ def test_triu_tril_indices_bfloat16(self, device):
+ # Verify torch.tril_indices/torch.triu_indices accept dtype=torch.bfloat16
+ # and produce the same index values as the float path.
+ # NOTE(review): `device` is not forwarded to op_fun here, so the indices are
+ # created on the default device — presumably intentional under @onlyCPU.
+ op_funcs = [torch.tril_indices, torch.triu_indices]
+ for op_fun in op_funcs:
+ # 4x3 matrix with diagonal offset 1; all indices are small integers, which
+ # bfloat16 represents exactly, so the comparison below can be exact.
+ out = op_fun(4, 3, 1, dtype=torch.bfloat16)
+ out2 = op_fun(4, 3, 1, dtype=torch.float)
+ self.assertEqual(out.dtype, torch.bfloat16)
+ self.assertEqual(out, out2.bfloat16())
+
# TODO: update to work on CUDA, too
@onlyCPU
def test_stack(self, device):
skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),)),
OpInfo('cross',
dtypes=all_types_and_complex(),
+ dtypesIfCPU=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.half),
sample_inputs_func=sample_inputs_cross,
supports_forward_ad=True,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_transpose_swapdims),
OpInfo('tril',
+ dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypes=all_types_and_complex_and(torch.bool, torch.half),
supports_forward_ad=True,
sample_inputs_func=sample_inputs_tril_triu),
OpInfo('triu',
+ dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypes=all_types_and_complex_and(torch.bool, torch.half),
supports_forward_ad=True,
sample_inputs_func=sample_inputs_tril_triu),