} // namespace native
namespace meta {
-
void resize_reduction(
impl::MetaBase& meta,
const Tensor& self,
return opt.has_value() ? opt.value() : IntArrayRef{};
}
-ScalarType check_allany_and_get_output_dtype(
- const char* name,
- const Tensor& self,
- const Tensor& result,
- bool keepdim) {
+ScalarType get_result_or_bytebool_dtype(const Tensor& self, const Tensor& result) {
+ // Refer [all, any : uint8 compatibility]
+ if (result.defined()) {
+ return result.scalar_type();
+ } else {
+ return (self.scalar_type() == kByte) ? kByte : kBool;
+ }
+}
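+// Hedged illustration (not part of the kernel itself; assumes the public
+// ATen API from <ATen/ATen.h>): the byte path preserves the legacy contract
+// that all/any on a kByte input also yields kByte, while every other input
+// dtype reduces to kBool:
+//
+//   at::Tensor b = at::ones({4}, at::kByte);
+//   b.all().scalar_type();              // kByte, for backwards compatibility
+//   at::ones({4}).all().scalar_type();  // kBool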
+
+void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
// Refer [all, any : uint8 compatibility]
TORCH_CHECK(
self.layout() == Layout::Strided,
name, " only supports strided layout, got: ",
self.layout());
- ScalarType out_dtype;
-
if (result.defined()) {
// Refer [all, any : uint8 compatibility]
    TORCH_CHECK(
        result.scalar_type() == ScalarType::Bool ||
            result.scalar_type() == ScalarType::Byte,
        name, " only supports bool/uint8 tensor for result, got: ",
        result.scalar_type());
- out_dtype = result.scalar_type();
- } else {
- if (self.scalar_type() == ScalarType::Byte) {
- out_dtype = self.scalar_type();
- } else {
- out_dtype = ScalarType::Bool;
- }
}
-
- return out_dtype;
}
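+// For example (a hedged sketch, assuming the generated at::all_out overload
+// for all.dim): supplying a float result tensor trips the dtype check above:
+//
+//   at::Tensor out = at::empty({}, at::kFloat);
+//   at::all_out(out, at::ones({4}, at::kBool), /*dim=*/0);
+//   // -> c10::Error: all only supports bool/uint8 tensor for result, got: Float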
TORCH_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
- auto out_dtype = check_allany_and_get_output_dtype("all", self, maybe_get_output(), keepdim);
+ check_all_any("all", self, maybe_get_output());
+ auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
}
TORCH_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
- auto out_dtype = check_allany_and_get_output_dtype("any", self, maybe_get_output(), keepdim);
+ check_all_any("any", self, maybe_get_output());
+ auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
}
+void check_floating_or_complex_dtype(const char* name, ScalarType dtype) {
+ TORCH_CHECK(
+ at::isFloatingType(dtype) || at::isComplexType(dtype),
+ name, "(): input dtype should be either floating point or complex dtypes. "
+ "Got ", toString(dtype), " instead.");
+}
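+// A hedged usage sketch: callers such as the mean meta function below get a
+// uniform error message for non-floating, non-complex inputs, e.g.:
+//
+//   at::mean(at::ones({3}, at::kLong), /*dim=*/{0});
+//   // -> c10::Error: mean(): input dtype should be either floating point
+//   //    or complex. Got Long instead.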
+
TORCH_META_FUNC2(mean, dim)
(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
- auto self_dtype = self.scalar_type();
- TORCH_CHECK(
- at::isFloatingType(self_dtype) || at::isComplexType(self_dtype),
- "Can only calculate the mean of floating types. Got ",
- toString(self_dtype), " instead.");
+ check_floating_or_complex_dtype("mean", self.scalar_type());
auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
}
-} // namespace meta
+ScalarType get_result_or_self_value_dtype(
+ const Tensor& self,
+ const Tensor& result,
+ const c10::optional<ScalarType>& dtype) {
+ if (result.defined()) {
+ return result.scalar_type();
+ } else {
+ return dtype.value_or(toValueType(self.scalar_type()));
+ }
+}
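+// Hedged illustration: c10::toValueType maps a complex dtype to its real
+// value type, so a complex input defaults to a real-valued norm unless an
+// explicit dtype overrides it:
+//
+//   auto t = at::ones({3}, at::kComplexFloat);
+//   at::norm(t, 2, /*dim=*/{0}, /*keepdim=*/false).scalar_type();  // kFloat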
-namespace meta {
+
+TORCH_META_FUNC2(norm, ScalarOpt_dim)
+(const Tensor& self, const OptionalScalarRef p, IntArrayRef dim, bool keepdim) {
+ check_floating_or_complex_dtype("norm", self.scalar_type());
+ auto out_dtype = get_result_or_self_value_dtype(self, maybe_get_output(), c10::nullopt);
+ resize_reduction(*this, self, dim, keepdim, out_dtype);
+}
+
+TORCH_META_FUNC2(norm, ScalarOpt_dim_dtype)
+(const Tensor& self,
+ const OptionalScalarRef p,
+ IntArrayRef dim,
+ bool keepdim,
+ ScalarType dtype) {
+ check_floating_or_complex_dtype("norm", dtype);
+ auto out_dtype = get_result_or_self_value_dtype(self, maybe_get_output(), dtype);
+ resize_reduction(*this, self, dim, keepdim, out_dtype);
+}
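+// Hedged sketch of how the two meta functions above differ: the dtype
+// variant validates the requested out dtype rather than self's dtype, so an
+// integral input that upcasts is accepted while the plain variant rejects it:
+//
+//   at::norm(at::ones({3}, at::kLong), 2, {0}, false, at::kFloat);  // ok
+//   at::norm(at::ones({3}, at::kLong), 2, {0}, false);              // throws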
TORCH_META_FUNC(aminmax)
(const Tensor& self, c10::optional<int64_t> dim_opt, bool keepdim) {
return at::logsumexp_out(result, self, dims, keepdim);
}
-static Tensor& norm_out(Tensor &result, const Tensor &self, const optional<Scalar>& opt_p,
- IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
- auto p = opt_p.value_or(2.0).to<double>();
- TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
- "norm only supports CPU and CUDA device types, but got: ", self.device().type());
- TORCH_CHECK(self.layout() == Layout::Strided,
- "norm only supports strided layout, but got: ", self.layout());
-
- ScalarType in_dtype = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
- TORCH_CHECK(
- at::isFloatingType(in_dtype) || at::isComplexType(in_dtype),
- "Can only calculate the norm of floating point and complex dtypes. Got ",
- toString(in_dtype),
- " instead.");
-
- ScalarType out_dtype = result.defined() ? result.scalar_type() : (opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type()));
+void impl_func_norm(
+ const Tensor& self,
+ const OptionalScalarRef& opt_p,
+ IntArrayRef dim,
+ bool keepdim,
+ optional<ScalarType> opt_dtype,
+ const Tensor& result) {
+  auto p = opt_p.has_value() ? opt_p.get() : Scalar(2.0);
+ auto in_dtype = opt_dtype.value_or(self.scalar_type());
+ auto out_dtype = result.scalar_type();
-// omit in_dtype in the following call, to avoid make_reduction explicitly casting input to out_dtype
- auto iter = isComplexType(self.scalar_type()) ?
- make_reduction("norm", result, self, dim, keepdim, in_dtype, out_dtype) :
- make_reduction("norm", result, self, dim, keepdim, out_dtype);
+ // omit in_dtype in the following call, to avoid make_reduction explicitly
+ // casting input to out_dtype
+ auto iter = isComplexType(self.scalar_type())
+ ? meta::make_reduction(self, result, dim, keepdim, in_dtype)
+ : meta::make_reduction_from_out_ty(self, result, dim, keepdim, out_dtype);
if (iter.numel() == 0) {
result.zero_();
} else {
norm_stub(iter.device_type(), iter, p);
}
- return result;
-}
-
-static inline Tensor _norm(const Tensor &self, const Scalar& p) {
- if (self.is_sparse()) {
- // Sparse tensors need a different implementation because their values
- // are accessed with a different API than strided tensors
- return at::native_norm(self, p);
- } else {
- TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
- "norm only supports CPU AND CUDA device type, got: ", self.device().type());
- TORCH_CHECK(self.layout() == Layout::Strided,
- "norm only supports strided layout, got: ", self.layout());
- TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
- "norm only supports floating-point dtypes");
-
- ScalarType dtype = toValueType(self.scalar_type());
- Tensor result = create_reduction_result(self, IntArrayRef{}, false, dtype);
- return at::native::norm_out(result, self, p, IntArrayRef{}, false, c10::nullopt);
- }
}
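+// Hedged note on the early return above (assuming the public API): reducing
+// over zero elements skips the stub and zero-fills the output, matching the
+// convention that the norm of an empty set is 0:
+//
+//   at::norm(at::empty({0, 3}), 2, /*dim=*/{0}, /*keepdim=*/false);
+//   // -> tensor of zeros with shape {3}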
-Tensor &norm_out(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim, ScalarType dtype, Tensor& result) {
- return at::native::norm_out(result, self, p, dim, keepdim, optional<ScalarType>(dtype));
+TORCH_IMPL_FUNC(norm_out)
+(const Tensor& self,
+ const OptionalScalarRef p,
+ IntArrayRef dim,
+ bool keepdim,
+ const Tensor& result) {
+ impl_func_norm(self, p, dim, keepdim, c10::nullopt, result);
}
-Tensor &norm_out(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim, Tensor& result) {
- return at::native::norm_out(result, self, p, dim, keepdim, c10::nullopt);
+TORCH_IMPL_FUNC(norm_dtype_out)
+(const Tensor& self,
+ const OptionalScalarRef p,
+ IntArrayRef dim,
+ bool keepdim,
+ ScalarType dtype,
+ const Tensor& result) {
+ impl_func_norm(self, p, dim, keepdim, dtype, result);
}
-static Tensor norm(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim,
- optional<ScalarType> opt_dtype) {
- if (self.is_sparse()) {
- // Sparse tensors need a different implementation because their values
- // are accessed with a different API than strided tensors
- return at::native_norm(self, p, dim, keepdim, opt_dtype);
- } else {
- ScalarType out_dtype = value_or_else(opt_dtype, [&] {return toValueType(self.scalar_type());});
- Tensor result = create_reduction_result(self, dim, keepdim, out_dtype);
- return at::native::norm_out(result, self, p, dim, keepdim, opt_dtype);
- }
+Tensor sparse_norm(
+ const Tensor& self,
+ const optional<Scalar>& p,
+ IntArrayRef dim,
+ bool keepdim) {
+ return at::native_norm(self, p, dim, keepdim, c10::nullopt);
}
-Tensor norm(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim, ScalarType dtype) {
- return at::native::norm(self, p, dim, keepdim, optional<ScalarType>(dtype));
+Tensor sparse_dtype_norm(
+ const Tensor& self,
+ const optional<Scalar>& p,
+ IntArrayRef dim,
+ bool keepdim,
+ ScalarType dtype) {
+ return at::native_norm(self, p, dim, keepdim, dtype);
}
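+// Hedged usage sketch: sparse inputs dispatch to these wrappers because
+// their values cannot go through the strided TensorIterator path, e.g.:
+//
+//   auto d = at::eye(4);
+//   at::norm(d.to_sparse(), 2);  // full reduction via at::native_norm -> 2.0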
Tensor norm(const Tensor& self, const optional<Scalar>& p, ScalarType dtype) {
- return at::native::norm(self, p, IntArrayRef{}, false, optional<ScalarType>(dtype));
-}
-
-Tensor norm(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim) {
- return at::native::norm(self, p, dim, keepdim, c10::nullopt);
+ return at::norm(self, p, IntArrayRef{}, false, dtype);
}
-// leave it so we support sparse tensors
Tensor norm(const Tensor& self, const Scalar& p) {
- return at::native::_norm(self, p);
+ return at::norm(self, p, IntArrayRef{}, false);
}
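+// Hedged sketch: these overloads are now thin composites that re-enter the
+// dispatcher via at::norm, so backend selection (strided vs. the sparse
+// kernels above) happens per call instead of being hardcoded here:
+//
+//   at::norm(at::ones({2, 2}), 2);              // dense: sqrt(4) = 2.0
+//   at::norm(at::ones({2, 2}).to_sparse(), 2);  // same entry point, sparse kernel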
// Note [all, any : uint8 compatibility]:
Tensor all(const Tensor& self) {
Tensor result;
- auto out_dtype =
- meta::check_allany_and_get_output_dtype("all", self, result, false);
+ meta::check_all_any("all", self, result);
+ auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
auto shape = meta::get_reduction_shape(self, {}, false);
result = at::empty(shape, self.options().dtype(out_dtype));
Tensor any(const Tensor& self) {
Tensor result;
- auto out_dtype =
- meta::check_allany_and_get_output_dtype("any", self, result, false);
+ meta::check_all_any("any", self, result);
+ auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
auto shape = meta::get_reduction_shape(self, {}, false);
result = at::empty(shape, self.options().dtype(out_dtype));