add operation list for AutocastCPU (#63534)
authorleslie-fang-intel <leslie.fang@intel.com>
Tue, 31 Aug 2021 02:28:59 +0000 (19:28 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 02:30:33 +0000 (19:30 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63534

In this PR:
* We have changed the default dtype of `AutocastCPU` from `float16` to `bfloat16` as discussed here `https://github.com/pytorch/pytorch/pull/61002`
* We also update the operation list which needs casting to `lower_precision_fp` or `float32`.

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D30644914

Pulled By: ezyang

fbshipit-source-id: 8b93485ba452b3759611e3f0ac88e920fe495ac1

aten/src/ATen/autocast_mode.cpp
test/run_test.py
torch/cpu/amp/autocast_mode.py
torch/testing/_internal/autocast_test_lists.py

index 1ac5ad1..9f5f486 100644 (file)
@@ -461,22 +461,22 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
   KERNEL_CPU(ADD_NS(conv1d), "conv1d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp)
   KERNEL_CPU(ADD_NS(conv2d), "conv2d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp)
   KERNEL_CPU(ADD_NS(conv3d), "conv3d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp)
-  KERNEL_CPU(ADD_NS(_log_softmax), "_log_softmax", Tensor (const Tensor &, int64_t, bool), lower_precision_fp)
   KERNEL_CPU(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), lower_precision_fp)
   KERNEL_CPU(ADD_NS(mm), "mm", Tensor (const Tensor &, const Tensor &), lower_precision_fp)
   KERNEL_CPU(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp)
   KERNEL_CPU(ADD_NS(addmm), "addmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp)
   KERNEL_CPU(ADD_NS(addbmm), "addbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp)
   KERNEL_CPU(ADD_NS(linear), "linear", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor> &), lower_precision_fp)
+  KERNEL_CPU(ADD_NS(_convolution), "_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), lower_precision_fp)
 
   // fp32 cast policy
+  KERNEL_CPU(ADD_NS(conv_transpose1d), "conv_transpose1d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor> &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp32)
+  KERNEL_CPU(ADD_NS(conv_transpose2d), "conv_transpose2d.input", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor> &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp32)
   KERNEL_CPU(ADD_NS(conv_transpose3d), "conv_transpose3d.input", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor> &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp32)
   KERNEL_CPU(ADD_NS(batch_norm), "batch_norm", Tensor (const Tensor &, const c10::optional<Tensor> &, const c10::optional<Tensor> &, const c10::optional<Tensor> &, const c10::optional<Tensor> &, bool, double, double, bool), fp32)
-  KERNEL_CPU(ADD_NS(max_pool2d), "max_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool), fp32)
-  KERNEL_CPU(ADD_NS(adaptive_avg_pool2d), "adaptive_avg_pool2d", Tensor (const Tensor &, IntArrayRef), fp32)
 
-  KERNEL_CPU(ADD_NS(convolution), "convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t), fp32)
   KERNEL_CPU(ADD_NS(dropout), "dropout", Tensor (const Tensor &, double, bool), fp32)
+  KERNEL_CPU(ADD_NS(avg_pool1d), "avg_pool1d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool), fp32)
   KERNEL_CPU(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
   KERNEL_CPU(ADD_NS(avg_pool3d), "avg_pool3d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
   KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &), fp32)
@@ -492,45 +492,285 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
   KERNEL_CPU(ADD_NS(upsample_bilinear2d), "upsample_bilinear2d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, bool, c10::optional<ArrayRef<double>>), fp32)
   KERNEL_CPU(ADD_NS(upsample_trilinear3d), "upsample_trilinear3d", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<double>, c10::optional<double>, c10::optional<double>), fp32)
   KERNEL_CPU(ADD_NS(upsample_trilinear3d), "upsample_trilinear3d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, bool, c10::optional<ArrayRef<double>>), fp32)
+
   KERNEL_CPU(ADD_NS(binary_cross_entropy), "binary_cross_entropy", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, int64_t), fp32)
   KERNEL_CPU(ADD_NS(binary_cross_entropy_with_logits), "binary_cross_entropy_with_logits", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t), fp32)
-  KERNEL_CPU(ADD_NS(pow), "pow.Tensor_Scalar", Tensor (const Tensor &, const Scalar &), fp32)
-  KERNEL_CPU(ADD_NS(pow), "pow.Tensor_Tensor", Tensor (const Tensor &, const Tensor &), fp32)
-  KERNEL_CPU(ADD_NS(pow), "pow.Scalar", Tensor (const Scalar&, const Tensor &), fp32)
-  KERNEL_CPU(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32)
-  KERNEL_CPU(ADD_NS(reflection_pad1d), "reflection_pad1d", Tensor (const Tensor &, IntArrayRef), fp32)
-  KERNEL_CPU(ADD_NS(std), "std", Tensor (const Tensor &, bool), fp32)
-  KERNEL_CPU(ADD_NS(std), "std.dim", Tensor (const Tensor &, IntArrayRef, bool, bool), fp32)
   KERNEL_CPU(ADD_NS(instance_norm), "instance_norm", Tensor (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, const c10::optional<Tensor>&, const c10::optional<Tensor>&, bool, double, double, bool), fp32)
+  KERNEL_CPU(ADD_NS(grid_sampler), "grid_sampler", Tensor(const Tensor &, const Tensor &, int64_t, int64_t, bool), fp32)
+  KERNEL_CPU(ADD_NS(polar), "polar", Tensor(const Tensor &, const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(multinomial), "multinomial", Tensor(const Tensor &, int64_t, bool, c10::optional<at::Generator>), fp32)
+  KERNEL_CPU(ADD_NS(poisson), "poisson", Tensor(const Tensor &, c10::optional<at::Generator>), fp32)
+  KERNEL_CPU(ADD_NS(fmod), "fmod.Tensor", Tensor(const Tensor &, const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(fmod), "fmod.Scalar", Tensor(const Tensor &, const Scalar &), fp32)
+  KERNEL_CPU(ADD_NS(prod), "prod", Tensor(const Tensor &, c10::optional<at::ScalarType>), fp32)
+  KERNEL_CPU(ADD_NS(prod), "prod.dim_int", Tensor(const Tensor &, int64_t, bool, c10::optional<at::ScalarType>), fp32)
+  KERNEL_CPU(ADD_NS(prod), "prod.dim_Dimname", Tensor(const Tensor &, at::Dimname, bool, c10::optional<at::ScalarType>), fp32)
+  KERNEL_CPU(ADD_NS(quantile), "quantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool), fp32)
+  KERNEL_CPU(ADD_NS(quantile), "quantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool), fp32)
+  KERNEL_CPU(ADD_NS(quantile), "quantile.new", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
+  KERNEL_CPU(ADD_NS(quantile), "quantile.new_scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
+  KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool), fp32)
+  KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool), fp32)
+  KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.new", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
+  KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.new_scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
+  KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::optional<bool>, c10::optional<bool>), fp32)
+  KERNEL_CPU(ADD_NS(cdist), "cdist", Tensor(const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
+  KERNEL_CPU(ADD_NS(cross), "cross", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>), fp32)
+  KERNEL_CPU(ADD_NS(cumprod), "cumprod", Tensor(const Tensor &, int64_t, c10::optional<at::ScalarType>), fp32)
+  KERNEL_CPU(ADD_NS(cumprod), "cumprod.dimname", Tensor(const Tensor &, at::Dimname, c10::optional<at::ScalarType>), fp32)
+  KERNEL_CPU(ADD_NS(cumsum), "cumsum", Tensor(const Tensor &, int64_t, c10::optional<at::ScalarType>), fp32)
+  KERNEL_CPU(ADD_NS(cumsum), "cumsum.dimname", Tensor(const Tensor &, at::Dimname, c10::optional<at::ScalarType>), fp32)
+  KERNEL_CPU(ADD_NS(diag), "diag", Tensor(const Tensor &, int64_t), fp32)
+  KERNEL_CPU(ADD_NS(diagflat), "diagflat", Tensor(const Tensor &, int64_t), fp32)
+  KERNEL_CPU(ADD_NS(histc), "histc", Tensor(const Tensor &, int64_t, const at::Scalar &, const at::Scalar &), fp32)
+  KERNEL_CPU(ADD_NS(logcumsumexp), "logcumsumexp", Tensor(const Tensor &, int64_t), fp32)
+  KERNEL_CPU(ADD_NS(searchsorted), "searchsorted.Tensor", Tensor(const Tensor &, const Tensor &, bool, bool), fp32)
+  KERNEL_CPU(ADD_NS(searchsorted), "searchsorted.Scalar", Tensor(const Tensor &, const at::Scalar &, bool, bool), fp32)
+  KERNEL_CPU(ADD_NS(trace), "trace", Tensor(const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(tril), "tril", Tensor(const Tensor &, int64_t), fp32)
+  KERNEL_CPU(ADD_NS(triu), "triu", Tensor(const Tensor &, int64_t), fp32)
+  KERNEL_CPU(ADD_NS(vander), "vander", Tensor(const Tensor &, c10::optional<int64_t>, bool), fp32)
+  KERNEL_CPU(ADD_NS(view_as_complex), "view_as_complex", Tensor(const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(cholesky), "cholesky", Tensor(const Tensor &, bool), fp32)
+  KERNEL_CPU(ADD_NS(cholesky_inverse), "cholesky_inverse", Tensor(const Tensor &, bool), fp32)
+  KERNEL_CPU(ADD_NS(cholesky_solve), "cholesky_solve", Tensor(const Tensor &, const Tensor &, bool), fp32)
+  KERNEL_CPU(ADD_NS(dot), "dot", Tensor(const Tensor &, const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(inverse), "inverse", Tensor(const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(lu_solve), "lu_solve", Tensor(const Tensor &, const Tensor &, const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(matrix_rank), "matrix_rank", Tensor(const Tensor &, bool), fp32)
+  KERNEL_CPU(ADD_NS(orgqr), "orgqr", Tensor(const Tensor &, const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(ormqr), "ormqr", Tensor(const Tensor &, const Tensor &, const Tensor &, bool, bool), fp32)
+  KERNEL_CPU(ADD_NS(pinverse), "pinverse", Tensor(const Tensor &, double), fp32)
+  KERNEL_CPU(ADD_NS(vdot), "vdot", Tensor(const Tensor &, const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(im2col), "im2col", Tensor(const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef), fp32)
+  KERNEL_CPU(ADD_NS(col2im), "col2im", Tensor(const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef), fp32)
+  KERNEL_CPU(ADD_NS(max_pool3d), "max_pool3d", Tensor(const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool), fp32)
+  KERNEL_CPU(ADD_NS(max_unpool2d), "max_unpool2d", Tensor(const Tensor &, const Tensor &, IntArrayRef), fp32)
+  KERNEL_CPU(ADD_NS(max_unpool3d), "max_unpool3d", Tensor(const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef), fp32)
+  KERNEL_CPU(ADD_NS(adaptive_avg_pool3d), "adaptive_avg_pool3d", Tensor(const Tensor &, IntArrayRef), fp32)
+  KERNEL_CPU(ADD_NS(reflection_pad1d), "reflection_pad1d", Tensor(const Tensor &, IntArrayRef), fp32)
+  KERNEL_CPU(ADD_NS(reflection_pad2d), "reflection_pad2d", Tensor(const Tensor &, IntArrayRef), fp32)
+  KERNEL_CPU(ADD_NS(replication_pad1d), "replication_pad1d", Tensor(const Tensor &, IntArrayRef), fp32)
+  KERNEL_CPU(ADD_NS(replication_pad2d), "replication_pad2d", Tensor(const Tensor &, IntArrayRef), fp32)
+  KERNEL_CPU(ADD_NS(replication_pad3d), "replication_pad3d", Tensor(const Tensor &, IntArrayRef), fp32)
+  KERNEL_CPU(ADD_NS(elu), "elu", Tensor(const Tensor &, const Scalar &, const Scalar &, const Scalar &), fp32)
+  KERNEL_CPU(ADD_NS(hardshrink), "hardshrink", Tensor(const Tensor &, const Scalar &), fp32)
+  KERNEL_CPU(ADD_NS(hardsigmoid), "hardsigmoid", Tensor(const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(hardswish), "hardswish", Tensor(const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(log_sigmoid), "log_sigmoid", Tensor(const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(prelu), "prelu", Tensor(const Tensor &, const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(selu), "selu", Tensor(const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(celu), "celu", Tensor(const Tensor &, const Scalar &), fp32)
+  KERNEL_CPU(ADD_NS(softplus), "softplus", Tensor(const Tensor &, const Scalar &, const Scalar &), fp32)
+  KERNEL_CPU(ADD_NS(softshrink), "softshrink", Tensor(const Tensor &, const Scalar &), fp32)
+  KERNEL_CPU(ADD_NS(group_norm), "group_norm", Tensor(const Tensor &, int64_t, const c10::optional<Tensor> &, const c10::optional<Tensor> &, double, bool), fp32)
+  KERNEL_CPU(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32)
+  KERNEL_CPU(ADD_NS(mse_loss), "mse_loss", Tensor(const Tensor &, const Tensor &, int64_t), fp32)
+  KERNEL_CPU(ADD_NS(ctc_loss), "ctc_loss.IntList", Tensor(const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t, int64_t, bool), fp32)
+  KERNEL_CPU(ADD_NS(ctc_loss), "ctc_loss.Tensor", Tensor(const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t, bool), fp32)
+  KERNEL_CPU(ADD_NS(kl_div), "kl_div", Tensor(const Tensor &, const Tensor &, int64_t, bool), fp32)
+  KERNEL_CPU(ADD_NS(multilabel_margin_loss), "multilabel_margin_loss", Tensor(const Tensor &, const Tensor &, int64_t), fp32)
+  KERNEL_CPU(ADD_NS(fft_fft), "fft_fft", Tensor(const Tensor &, c10::optional<int64_t>, int64_t, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_ifft), "fft_ifft", Tensor(const Tensor &, c10::optional<int64_t>, int64_t, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_fft2), "fft_fft2", Tensor(const Tensor &, c10::optional<at::IntArrayRef>, at::IntArrayRef, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_ifft2), "fft_ifft2", Tensor(const Tensor &, c10::optional<at::IntArrayRef>, at::IntArrayRef, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_fftn), "fft_fftn", Tensor(const Tensor &, c10::optional<at::IntArrayRef>, c10::optional<at::IntArrayRef>, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_ifftn), "fft_ifftn", Tensor(const Tensor &, c10::optional<at::IntArrayRef>, c10::optional<at::IntArrayRef>, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_rfft), "fft_rfft", Tensor(const Tensor &, c10::optional<int64_t>, int64_t, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_irfft), "fft_irfft", Tensor(const Tensor &, c10::optional<int64_t>, int64_t, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_rfft2), "fft_rfft2", Tensor(const Tensor &, c10::optional<at::IntArrayRef>, at::IntArrayRef, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_irfft2), "fft_irfft2", Tensor(const Tensor &, c10::optional<at::IntArrayRef>, at::IntArrayRef, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_rfftn), "fft_rfftn", Tensor(const Tensor &, c10::optional<at::IntArrayRef>, c10::optional<at::IntArrayRef>, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_irfftn), "fft_irfftn", Tensor(const Tensor &, c10::optional<at::IntArrayRef>, c10::optional<at::IntArrayRef>, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_hfft), "fft_hfft", Tensor(const Tensor &, c10::optional<int64_t>, int64_t, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(fft_ihfft), "fft_ihfft", Tensor(const Tensor &, c10::optional<int64_t>, int64_t, c10::optional<c10::string_view>), fp32)
+  KERNEL_CPU(ADD_NS(conv_tbc), "conv_tbc", Tensor(const Tensor &, const Tensor &, const Tensor &, int64_t), fp32)
+  KERNEL_CPU(ADD_NS(linalg_matrix_norm), "linalg_matrix_norm", Tensor(const Tensor &, const at::Scalar &, at::IntArrayRef, bool, c10::optional<at::ScalarType>), fp32)
+  KERNEL_CPU(ADD_NS(linalg_matrix_norm), "linalg_matrix_norm.str_ord", Tensor(const Tensor &, c10::string_view, at::IntArrayRef, bool, c10::optional<at::ScalarType>), fp32)
+  KERNEL_CPU(ADD_NS(linalg_cond), "linalg_cond", Tensor(const Tensor &, const c10::optional<at::Scalar> &), fp32)
+  KERNEL_CPU(ADD_NS(linalg_cond), "linalg_cond.p_str", Tensor(const Tensor &, c10::string_view), fp32)
+  KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank", Tensor(const Tensor &, const c10::optional<double>, bool), fp32)
+  KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank.tol_tensor", Tensor(const Tensor &, const Tensor &, bool), fp32)
+  KERNEL_CPU(ADD_NS(linalg_solve), "linalg_solve", Tensor(const Tensor &, const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(linalg_cholesky), "linalg_cholesky", Tensor(const Tensor &, bool), fp32)
+  KERNEL_CPU(ADD_NS(linalg_svdvals), "linalg_svdvals", Tensor(const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(linalg_eigvals), "linalg_eigvals", Tensor(const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(linalg_eigvalsh), "linalg_eigvalsh", Tensor(const Tensor &, c10::string_view), fp32)
+  KERNEL_CPU(ADD_NS(linalg_inv), "linalg_inv", Tensor(const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(linalg_householder_product), "linalg_householder_product", Tensor(const Tensor &, const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(linalg_tensorinv), "linalg_tensorinv", Tensor(const Tensor &, int64_t), fp32)
+  KERNEL_CPU(ADD_NS(linalg_tensorsolve), "linalg_tensorsolve", Tensor(const Tensor &, const Tensor &, c10::optional<at::IntArrayRef>), fp32)
   KERNEL_CPU(ADD_NS(fake_quantize_per_tensor_affine), "fake_quantize_per_tensor_affine", Tensor (const Tensor &, double, int64_t, int64_t, int64_t), fp32)
+  KERNEL_CPU(ADD_NS(glu), "glu", Tensor (const Tensor &, int64_t), fp32)
 
-  // promote
-  KERNEL_CPU(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote)
-  KERNEL_CPU(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
+  m.impl(TORCH_SELECTIVE_NAME("aten::cummax"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, int64_t),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, int64_t),
+                                 &ADD_NS(cummax)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::cummax.dimname"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, at::Dimname),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, at::Dimname),
+                                 &ADD_NS(cummax)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::cummin"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, int64_t),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, int64_t),
+                                 &ADD_NS(cummin)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::cummin.dimname"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, at::Dimname),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, at::Dimname),
+                                 &ADD_NS(cummin)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::eig"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, bool),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, bool),
+                                 &ADD_NS(eig)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::geqrf"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &),
+                                 std::tuple<Tensor, Tensor> (const Tensor &),
+                                 &ADD_NS(geqrf)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::lstsq"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, const Tensor &),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, const Tensor &),
+                                 &ADD_NS(lstsq)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::_lu_with_info"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor, Tensor> (const Tensor &, bool, bool),
+                                 std::tuple<Tensor, Tensor, Tensor> (const Tensor &, bool, bool),
+                                 &ADD_NS(_lu_with_info)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::lu_unpack"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor, Tensor> (const Tensor &, const Tensor &, bool, bool),
+                                 std::tuple<Tensor, Tensor, Tensor> (const Tensor &, const Tensor &, bool, bool),
+                                 &ADD_NS(lu_unpack)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::qr"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, bool),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, bool),
+                                 &ADD_NS(qr)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::solve"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, const Tensor &),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, const Tensor &),
+                                 &ADD_NS(solve)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::svd"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor, Tensor> (const Tensor &, bool, bool),
+                                 std::tuple<Tensor, Tensor, Tensor> (const Tensor &, bool, bool),
+                                 &ADD_NS(svd)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::symeig"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, bool, bool),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, bool, bool),
+                                 &ADD_NS(symeig)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::triangular_solve"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, const Tensor &, bool, bool, bool),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, const Tensor &, bool, bool, bool),
+                                 &ADD_NS(triangular_solve)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::fractional_max_pool2d"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef, IntArrayRef, const Tensor &),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef, IntArrayRef, const Tensor &),
+                                 &ADD_NS(fractional_max_pool2d)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::fractional_max_pool3d"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef, IntArrayRef, const Tensor &),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef, IntArrayRef, const Tensor &),
+                                 &ADD_NS(fractional_max_pool3d)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::adaptive_max_pool1d"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef),
+                                 &ADD_NS(adaptive_max_pool1d)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::adaptive_max_pool2d"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef),
+                                 &ADD_NS(adaptive_max_pool2d)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::adaptive_max_pool3d"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef),
+                                 &ADD_NS(adaptive_max_pool3d)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::multilabel_margin_loss_forward"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, const Tensor &, int64_t),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, const Tensor &, int64_t),
+                                 &ADD_NS(multilabel_margin_loss_forward)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::linalg_qr"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, c10::string_view),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, c10::string_view),
+                                 &ADD_NS(linalg_qr)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::linalg_cholesky_ex"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, bool, bool),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, bool, bool),
+                                 &ADD_NS(linalg_cholesky_ex)>::type::call)));
 
-  m.impl(TORCH_SELECTIVE_NAME("aten::topk"),
+  m.impl(TORCH_SELECTIVE_NAME("aten::linalg_svd"),
          TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
-                                 std::tuple<Tensor,Tensor> (const Tensor &, int64_t, int64_t, bool, bool),
-                                 std::tuple<Tensor,Tensor> (const Tensor &, int64_t, int64_t, bool, bool),
-                                 &ADD_NS(topk)>::type::call)));
+                                 std::tuple<Tensor, Tensor, Tensor> (const Tensor &, bool),
+                                 std::tuple<Tensor, Tensor, Tensor> (const Tensor &, bool),
+                                 &ADD_NS(linalg_svd)>::type::call)));
 
-  m.impl(TORCH_SELECTIVE_NAME("aten::sort"),
+  m.impl(TORCH_SELECTIVE_NAME("aten::linalg_eig"),
          TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
-                                 std::tuple<Tensor,Tensor> (const Tensor &, int64_t, bool),
-                                 std::tuple<Tensor,Tensor> (const Tensor &, int64_t, bool),
-                                 &ADD_NS(sort)>::type::call)));
+                                 std::tuple<Tensor, Tensor> (const Tensor &),
+                                 std::tuple<Tensor, Tensor> (const Tensor &),
+                                 &ADD_NS(linalg_eig)>::type::call)));
 
-   m.impl(TORCH_SELECTIVE_NAME("aten::kthvalue"),
+  m.impl(TORCH_SELECTIVE_NAME("aten::linalg_eigh"),
          TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
-                                 std::tuple<Tensor,Tensor> (const Tensor &, int64_t, int64_t, bool),
-                                 std::tuple<Tensor,Tensor> (const Tensor &, int64_t, int64_t, bool),
-                                 &ADD_NS(kthvalue)>::type::call)));
+                                 std::tuple<Tensor, Tensor> (const Tensor &, c10::string_view),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, c10::string_view),
+                                 &ADD_NS(linalg_eigh)>::type::call)));
 
-   m.impl(TORCH_SELECTIVE_NAME("aten::kthvalue.dimname"),
+  m.impl(TORCH_SELECTIVE_NAME("aten::linalg_lstsq"),
          TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
-                                 std::tuple<Tensor,Tensor> (const Tensor &, int64_t, at::Dimname, bool),
-                                 std::tuple<Tensor,Tensor> (const Tensor &, int64_t, at::Dimname, bool),
-                                 &ADD_NS(kthvalue)>::type::call)));
+                                 std::tuple<Tensor, Tensor, Tensor, Tensor> (const Tensor &, const Tensor &, c10::optional<double>, c10::optional<c10::string_view>),
+                                 std::tuple<Tensor, Tensor, Tensor, Tensor> (const Tensor &, const Tensor &, c10::optional<double>, c10::optional<c10::string_view>),
+                                 &ADD_NS(linalg_lstsq)>::type::call)));
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::linalg_inv_ex"),
+         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
+                                 std::tuple<Tensor, Tensor> (const Tensor &, bool),
+                                 std::tuple<Tensor, Tensor> (const Tensor &, bool),
+                                 &ADD_NS(linalg_inv_ex)>::type::call)));
+
+  // promote
+  KERNEL_CPU(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote)
+  KERNEL_CPU(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
+  KERNEL_CPU(ADD_NS(index_copy), "index_copy", Tensor (const Tensor &, int64_t, const Tensor &, const Tensor &), promote)
+  KERNEL_CPU(ADD_NS(index_copy), "index_copy.dimname", Tensor (const Tensor &, at::Dimname, const Tensor &, const Tensor &), promote)
+
 }
 } // namespace
 } // namespace autocast
index 615aaf9..77e7f15 100755 (executable)
@@ -75,6 +75,7 @@ TESTS = [
     "distributed/test_pg_wrapper",
     "distributed/algorithms/test_join",
     "test_cuda",
+    "test_autocast",
     "test_jit_cuda_fuser",
     "test_cuda_primary_ctx",
     "test_dataloader",
index 08ea200..8c65f72 100644 (file)
@@ -5,5 +5,5 @@ class autocast(torch.autocast_mode.autocast):
     See :class:`torch.autocast`.
     ``torch.cpu.amp.autocast(args...)`` is equivalent to ``torch.autocast("cpu", args...)``
     """
-    def __init__(self, enabled=True, dtype=torch.float16):
+    def __init__(self, enabled=True, dtype=torch.bfloat16):
         super().__init__("cpu", enabled=enabled, dtype=dtype)
index 754ccca..8350845 100644 (file)
@@ -307,7 +307,6 @@ class AutocastCPUTestLists(object):
             ("conv1d", conv_args_fp32[0]),
             ("conv2d", conv_args_fp32[1]),
             ("conv3d", conv_args_fp32[2]),
-            ("log_softmax", pointwise0_fp32 + (0,)),
             ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
                      torch.randn((n, n, n), device=dev, dtype=torch.float32))),
             ("mm", mat0_fp32 + mat1_fp32),
@@ -319,24 +318,22 @@ class AutocastCPUTestLists(object):
                                     torch.randn((n, n, n), device=dev, dtype=torch.float32))),
         ]
         self.torch_fp32 = [
+            ("conv_transpose1d", conv_args_bf16[0]),
+            ("conv_transpose2d", conv_args_bf16[1]),
             ("conv_transpose3d", conv_args_bf16[2]),
             ("batch_norm", dummy_bf16[2], {"weight": None, "bias": None, "running_mean": torch.rand((n), dtype=torch.float32),
                                            "running_var": torch.rand((n), dtype=torch.float32), "training": False,
                                            "momentum": 0.1, "eps": 1e-5, "cudnn_enabled": False}),
-            ("max_pool2d", dummy_bf16[2], {"kernel_size": (3, 2), "stride": (1, 1)}),
             ("dropout", dummy_bf16[2], {"p": 0.1, "train": False}),
             ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
-            ("pow", ((pointwise0_bf16[0] + 1.).clamp(0.0, 100.0),) + pointwise1_bf16),
-            ("pow", ((pointwise0_bf16[0] + 1.).clamp(0.0, 100.0),) + (1.7,)),
-            ("instance_norm", dummy_bf16[2], {"weight": None, "bias": None, "running_mean": torch.rand((n), dtype=torch.float32),
-                                              "running_var": torch.rand((n), dtype=torch.float32), "use_input_stats": False,
+            ("instance_norm", dummy_bf16[1], {"weight": None, "bias": None, "running_mean": None,
+                                              "running_var": None, "use_input_stats": True,
                                               "momentum": 0.1, "eps": 1e-5, "cudnn_enabled": False}),
         ]
         self.nn_bf16 = [
             ("linear", mat0_fp32 + mat1_fp32),
         ]
         self.nn_fp32 = [
-            ("adaptive_avg_pool2d", dummy_bf16[2], {"output_size": (3, 2)}),
             ("avg_pool2d", dummy_bf16[2], {"kernel_size": (3, 2), "stride": (1, 1)}),
             ("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}),
             ("gelu", dummy_bf16[3]),
@@ -348,9 +345,8 @@ class AutocastCPUTestLists(object):
             ("upsample_trilinear3d", dummy_bf16[4], {"output_size": (n, n, n), "align_corners": False}),
             ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) +
                                      (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
-            ("smooth_l1_loss", mat0_bf16 + mat1_bf16),
             ("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}),
-            ("std", dummy_bf16[2]),
+            ("smooth_l1_loss", mat0_bf16 + mat1_bf16),
         ]
         self.torch_need_autocast_promote = [
             ("cat", (pointwise0_bf16 + pointwise1_fp32,)),