Port `norm` kernel to structured kernels. (#62711)
authorYukio Siraichi <yukio.siraichi@gmail.com>
Fri, 13 Aug 2021 15:20:19 +0000 (08:20 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 13 Aug 2021 15:27:48 +0000 (08:27 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62711

Tracking issue: #55070

Test Plan: Imported from OSS

Reviewed By: anjali411

Differential Revision: D30109866

Pulled By: ezyang

fbshipit-source-id: 894c9496894d059c7690a174b75bbd4db7ed6016

aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/native_functions.yaml
torch/csrc/jit/runtime/static/ops.cpp

index d0800f7..6e5a153 100644 (file)
@@ -48,7 +48,6 @@ inline ScalarType get_dtype_from_self(
 } // namespace native
 
 namespace meta {
-
 void resize_reduction(
     impl::MetaBase& meta,
     const Tensor& self,
@@ -84,19 +83,22 @@ IntArrayRef optional_to_arrayref(const c10::optional<int64_t>& opt) {
   return opt.has_value() ? opt.value() : IntArrayRef{};
 }
 
-ScalarType check_allany_and_get_output_dtype(
-    const char* name,
-    const Tensor& self,
-    const Tensor& result,
-    bool keepdim) {
+ScalarType get_result_or_bytebool_dtype(const Tensor& self, const Tensor& result) {
+  // Refer [all, any : uint8 compatibility]
+  if (result.defined()) {
+    return result.scalar_type();
+  } else {
+    return (self.scalar_type() == kByte) ? kByte : kBool;
+  }
+}
+
+void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
   // Refer [all, any : uint8 compatibility]
   TORCH_CHECK(
       self.layout() == Layout::Strided,
       name, " only supports strided layout, got: ",
       self.layout());
 
-  ScalarType out_dtype;
-
   if (result.defined()) {
     // Refer [all, any : uint8 compatibility]
     TORCH_CHECK(
@@ -104,25 +106,18 @@ ScalarType check_allany_and_get_output_dtype(
             result.scalar_type() == ScalarType::Byte,
         name, " only supports bool tensor for result, got: ",
         result.scalar_type());
-    out_dtype = result.scalar_type();
-  } else {
-    if (self.scalar_type() == ScalarType::Byte) {
-      out_dtype = self.scalar_type();
-    } else {
-      out_dtype = ScalarType::Bool;
-    }
   }
-
-  return out_dtype;
 }
 
 TORCH_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
-  auto out_dtype = check_allany_and_get_output_dtype("all", self, maybe_get_output(), keepdim);
+  check_all_any("all", self, maybe_get_output());
+  auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
   resize_reduction(*this, self, dim, keepdim, out_dtype);
 }
 
 TORCH_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
-  auto out_dtype = check_allany_and_get_output_dtype("any", self, maybe_get_output(), keepdim);
+  check_all_any("any", self, maybe_get_output());
+  auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
   resize_reduction(*this, self, dim, keepdim, out_dtype);
 }
 
@@ -206,20 +201,50 @@ TORCH_META_FUNC2(prod, dim_int)
   resize_reduction(*this, self, dim, keepdim, out_dtype);
 }
 
+void check_floating_or_complex_dtype(const char* name, ScalarType dtype) {
+  TORCH_CHECK(
+      at::isFloatingType(dtype) || at::isComplexType(dtype),
+      name, "(): input dtype should be either floating point or complex dtypes. "
+      "Got ", toString(dtype), " instead.");
+}
+
 TORCH_META_FUNC2(mean, dim)
 (const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
-  auto self_dtype = self.scalar_type();
-  TORCH_CHECK(
-      at::isFloatingType(self_dtype) || at::isComplexType(self_dtype),
-      "Can only calculate the mean of floating types. Got ",
-      toString(self_dtype), " instead.");
+  check_floating_or_complex_dtype("mean", self.scalar_type());
   auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output());
   resize_reduction(*this, self, dim, keepdim, out_dtype);
 }
 
-} // namespace meta
+ScalarType get_result_or_self_value_dtype(
+    const Tensor& self,
+    const Tensor& result,
+    const c10::optional<ScalarType>& dtype) {
+  if (result.defined()) {
+    return result.scalar_type();
+  } else {
+    return dtype.value_or(toValueType(self.scalar_type()));
+  }
+}
 
-namespace meta {
+
+
+TORCH_META_FUNC2(norm, ScalarOpt_dim)
+(const Tensor& self, const OptionalScalarRef p, IntArrayRef dim, bool keepdim) {
+  check_floating_or_complex_dtype("norm", self.scalar_type());
+  auto out_dtype = get_result_or_self_value_dtype(self, maybe_get_output(), c10::nullopt);
+  resize_reduction(*this, self, dim, keepdim, out_dtype);
+}
+
+TORCH_META_FUNC2(norm, ScalarOpt_dim_dtype)
+(const Tensor& self,
+ const OptionalScalarRef p,
+ IntArrayRef dim,
+ bool keepdim,
+ ScalarType dtype) {
+  check_floating_or_complex_dtype("norm", dtype);
+  auto out_dtype = get_result_or_self_value_dtype(self, maybe_get_output(), dtype);
+  resize_reduction(*this, self, dim, keepdim, out_dtype);
+}
 
 TORCH_META_FUNC(aminmax)
 (const Tensor& self, c10::optional<int64_t> dim_opt, bool keepdim) {
@@ -1198,91 +1223,72 @@ Tensor& special_logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim
   return at::logsumexp_out(result, self, dims, keepdim);
 }
 
-static Tensor& norm_out(Tensor &result, const Tensor &self, const optional<Scalar>& opt_p,
-                               IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
-  auto p = opt_p.value_or(2.0).to<double>();
-  TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
-              "norm only supports CPU and CUDA device types, but got: ", self.device().type());
-  TORCH_CHECK(self.layout() == Layout::Strided,
-              "norm only supports strided layout, but got: ", self.layout());
-
-  ScalarType in_dtype = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
-  TORCH_CHECK(
-      at::isFloatingType(in_dtype) || at::isComplexType(in_dtype),
-      "Can only calculate the norm of floating point and complex dtypes. Got ",
-      toString(in_dtype),
-      " instead.");
-
-  ScalarType out_dtype = result.defined() ? result.scalar_type() : (opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type()));
+void impl_func_norm(
+    const Tensor& self,
+    const OptionalScalarRef& opt_p,
+    IntArrayRef dim,
+    bool keepdim,
+    optional<ScalarType> opt_dtype,
+    const Tensor& result) {
+  auto p = opt_p.has_value() ? opt_p.get() : Scalar(2.0).to<double>();
+  auto in_dtype = opt_dtype.value_or(self.scalar_type());
+  auto out_dtype = result.scalar_type();
 
-// omit in_dtype in the following call, to avoid make_reduction explicitly casting input to out_dtype
-  auto iter = isComplexType(self.scalar_type()) ?
-      make_reduction("norm", result, self, dim, keepdim, in_dtype, out_dtype) :
-      make_reduction("norm", result, self, dim, keepdim, out_dtype);
+  // omit in_dtype in the following call, to avoid make_reduction explicitly
+  // casting input to out_dtype
+  auto iter = isComplexType(self.scalar_type())
+      ? meta::make_reduction(self, result, dim, keepdim, in_dtype)
+      : meta::make_reduction_from_out_ty(self, result, dim, keepdim, out_dtype);
 
   if (iter.numel() == 0) {
     result.zero_();
   } else {
     norm_stub(iter.device_type(), iter, p);
   }
-  return result;
-}
-
-static inline Tensor _norm(const Tensor &self, const Scalar& p) {
-  if (self.is_sparse()) {
-    // Sparse tensors need a different implementation because their values
-    // are accessed with a different API than strided tensors
-    return at::native_norm(self, p);
-  } else {
-    TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
-                "norm only supports CPU AND CUDA device type, got: ", self.device().type());
-    TORCH_CHECK(self.layout() == Layout::Strided,
-                "norm only supports strided layout, got: ", self.layout());
-    TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
-                "norm only supports floating-point dtypes");
-
-    ScalarType dtype = toValueType(self.scalar_type());
-    Tensor result = create_reduction_result(self, IntArrayRef{}, false, dtype);
-    return at::native::norm_out(result, self, p, IntArrayRef{}, false, c10::nullopt);
-  }
 }
 
-Tensor &norm_out(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim, ScalarType dtype, Tensor& result) {
-  return at::native::norm_out(result, self, p, dim, keepdim, optional<ScalarType>(dtype));
+TORCH_IMPL_FUNC(norm_out)
+(const Tensor& self,
+ const OptionalScalarRef p,
+ IntArrayRef dim,
+ bool keepdim,
+ const Tensor& result) {
+  impl_func_norm(self, p, dim, keepdim, c10::nullopt, result);
 }
 
-Tensor &norm_out(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim, Tensor& result) {
-  return at::native::norm_out(result, self, p, dim, keepdim, c10::nullopt);
+TORCH_IMPL_FUNC(norm_dtype_out)
+(const Tensor& self,
+ const OptionalScalarRef p,
+ IntArrayRef dim,
+ bool keepdim,
+ ScalarType dtype,
+ const Tensor& result) {
+  impl_func_norm(self, p, dim, keepdim, dtype, result);
 }
 
-static Tensor norm(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim,
-            optional<ScalarType> opt_dtype) {
-  if (self.is_sparse()) {
-    // Sparse tensors need a different implementation because their values
-    // are accessed with a different API than strided tensors
-    return at::native_norm(self, p, dim, keepdim, opt_dtype);
-  } else {
-    ScalarType out_dtype = value_or_else(opt_dtype, [&] {return toValueType(self.scalar_type());});
-    Tensor result = create_reduction_result(self, dim, keepdim, out_dtype);
-    return at::native::norm_out(result, self, p, dim, keepdim, opt_dtype);
-  }
+Tensor sparse_norm(
+    const Tensor& self,
+    const optional<Scalar>& p,
+    IntArrayRef dim,
+    bool keepdim) {
+  return at::native_norm(self, p, dim, keepdim, c10::nullopt);
 }
 
-Tensor norm(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim, ScalarType dtype) {
-  return at::native::norm(self, p, dim, keepdim, optional<ScalarType>(dtype));
+Tensor sparse_dtype_norm(
+    const Tensor& self,
+    const optional<Scalar>& p,
+    IntArrayRef dim,
+    bool keepdim,
+    ScalarType dtype) {
+  return at::native_norm(self, p, dim, keepdim, dtype);
 }
 
 Tensor norm(const Tensor& self, const optional<Scalar>& p, ScalarType dtype) {
-  return at::native::norm(self, p, IntArrayRef{}, false, optional<ScalarType>(dtype));
-}
-
-Tensor norm(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim) {
-  return at::native::norm(self, p, dim, keepdim, c10::nullopt);
+  return at::norm(self, p, IntArrayRef{}, false, dtype);
 }
 
-// leave it so we support sparse tensors
 Tensor norm(const Tensor& self, const Scalar& p) {
-  return at::native::_norm(self, p);
+  return at::norm(self, p, IntArrayRef{}, false);
 }
 
 // Note [all, any : uint8 compatibility]:
@@ -1320,8 +1326,8 @@ inline TensorIterator get_allany_iter(
 Tensor all(const Tensor& self) {
   Tensor result;
 
-  auto out_dtype =
-      meta::check_allany_and_get_output_dtype("all", self, result, false);
+  meta::check_all_any("all", self, result);
+  auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
   auto shape = meta::get_reduction_shape(self, {}, false);
 
   result = at::empty(shape, self.options().dtype(out_dtype));
@@ -1353,8 +1359,8 @@ inline const Tensor & _any(const Tensor & result, TensorIterator & iter) {
 Tensor any(const Tensor& self) {
   Tensor result;
 
-  auto out_dtype =
-      meta::check_allany_and_get_output_dtype("any", self, result, false);
+  meta::check_all_any("any", self, result);
+  auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
   auto shape = meta::get_reduction_shape(self, {}, false);
 
   result = at::empty(shape, self.options().dtype(out_dtype));
index 00b7e6d..663e8df 100644 (file)
     CPU, CUDA, SparseCPU, SparseCUDA: norm
 
 - func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+  structured_delegate: norm.dtype_out
   device_check: NoCheck   # TensorIterator
   variants: function, method
   dispatch:
-    CPU, CUDA, SparseCPU, SparseCUDA: norm
+    SparseCPU, SparseCUDA: sparse_dtype_norm
 
 - func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
+  structured_delegate: norm.out
   device_check: NoCheck   # TensorIterator
   variants: function, method
   dispatch:
-    CPU, CUDA, SparseCPU, SparseCUDA: norm
+    SparseCPU, SparseCUDA: sparse_norm
 
 - func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
+  structured: True
   device_check: NoCheck   # TensorIterator
   dispatch:
-    CPU, CUDA: norm_out
+    CPU, CUDA: norm_dtype_out
 
 - func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
   device_check: NoCheck   # TensorIterator
   dispatch:
     CPU, CUDA: norm_out
index 9e48c30..bd74666 100644 (file)
@@ -1437,7 +1437,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::norm, aten_norm, [](Node* n) -> SROperator {
     const size_t num_inp = p_node->inputs().size();
     const auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
     if (num_inp == 3) {
-      at::native::norm_out(
+      at::cpu::norm_outf(
           in0_t,
           in1_s,
           c10::IntArrayRef{},
@@ -1448,7 +1448,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::norm, aten_norm, [](Node* n) -> SROperator {
     }
 
     if (num_inp > 4) {
-      at::native::norm_out(
+      at::cpu::norm_outf(
           in0_t,
           in1_s,
           p_node->Input(2).toIntVector(), // dim
@@ -1457,7 +1457,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::norm, aten_norm, [](Node* n) -> SROperator {
           out_t);
       return;
     }
-    at::native::norm_out(
+    at::cpu::norm_outf(
         in0_t,
         in1_s,
         p_node->Input(2).toIntVector(), // dim