add BFloat16 operators on CPU: cummax, cummin (#63307)
author: Jiayi Sun <jiayi.sun@intel.com>
Mon, 13 Sep 2021 14:59:00 +0000 (07:59 -0700)
committer: Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 13 Sep 2021 15:00:17 +0000 (08:00 -0700)
Summary:
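Enables BFloat16 inputs for the CPU implementations of `cummax` and `cummin` by adding `kBFloat16` to the dtype dispatch in `cummax_helper_cpu` and `cummin_helper_cpu`, and adds `torch.bfloat16` to `dtypesIfCPU` in the matching OpInfo test entries.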

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63307

Reviewed By: nikithamalgifb

Differential Revision: D30342002

Pulled By: anjali411

fbshipit-source-id: eee6e640da996ef0e983960119608d9c12405336

aten/src/ATen/native/ReduceOps.cpp
torch/testing/_internal/common_methods_invocations.py

index 4bef219..c11de20 100644 (file)
@@ -642,7 +642,7 @@ void cummax_cummin_helper(const T1* self_data, T1* values_data, T2* indices_data
 }
 
 void cummax_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
+  AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16,
     self.scalar_type(), "cummax_cpu",
     [&] {
       at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::greater_equal<scalar_t>>);
@@ -677,7 +677,7 @@ std::tuple<Tensor, Tensor> cummax(const Tensor& self, int64_t dim) {
 }
 
 void cummin_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
+  AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16,
     self.scalar_type(), "cummin_cpu",
     [&] {
       at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::less_equal<scalar_t>>);
index da7b3a0..aaf3972 100644 (file)
@@ -6465,13 +6465,13 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_cumprod,
            gradcheck_fast_mode=False),
     OpInfo('cummax',
-           dtypesIfCPU=all_types_and(torch.bool),
+           dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
            dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
            sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
            supports_forward_ad=True,
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
     OpInfo('cummin',
-           dtypesIfCPU=all_types_and(torch.bool),
+           dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
            dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
            sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
            supports_forward_ad=True,