Skip conjugate and negate fallback for view ops and their in-place versions (#64392)
authoranjali411 <chourdiaanjali123@gmail.com>
Fri, 10 Sep 2021 16:55:50 +0000 (09:55 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 16:57:27 +0000 (09:57 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64392

cc ezyang anjali411 dylanbespalko mruberry Lezcano nikitaved

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D30866330

Pulled By: anjali411

fbshipit-source-id: 7b2f51486bf1d610ad2b1472306bab608ee69c37

aten/src/ATen/ConjugateFallback.cpp
aten/src/ATen/native/MathBitFallThroughLists.h [new file with mode: 0644]
aten/src/ATen/native/NegateFallback.cpp
aten/src/ATen/native/UnaryOps.cpp

index 2cf9538..21c044e 100644 (file)
@@ -1,4 +1,5 @@
 #include <ATen/native/MathBitsFallback.h>
+#include <ATen/native/MathBitFallThroughLists.h>
 
 namespace at {
 
@@ -28,39 +29,22 @@ TORCH_LIBRARY_IMPL(_, Conjugate, m) {
 }
 
 TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
-  m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
   m.impl("set_.source_Storage_storage_offset", torch::CppFunction::makeFallthrough());
   m.impl("set_.source_Tensor", torch::CppFunction::makeFallthrough());
   m.impl("set_", torch::CppFunction::makeFallthrough());
   m.impl("copy_", torch::CppFunction::makeFallthrough());
   m.impl("clone", torch::CppFunction::makeFallthrough());
-  m.impl("conj", torch::CppFunction::makeFallthrough());
-  m.impl("_conj", torch::CppFunction::makeFallthrough());
   m.impl("_conj_physical", torch::CppFunction::makeFallthrough());
   m.impl("conj_physical", torch::CppFunction::makeFallthrough());
   m.impl("conj_physical_", torch::CppFunction::makeFallthrough());
   m.impl("resolve_conj", torch::CppFunction::makeFallthrough());
-  m.impl("empty_like", torch::CppFunction::makeFallthrough());
-  m.impl("empty.memory_format", torch::CppFunction::makeFallthrough());
-  m.impl("empty.out", torch::CppFunction::makeFallthrough());
-  m.impl("empty_strided", torch::CppFunction::makeFallthrough());
-  m.impl("full_like", torch::CppFunction::makeFallthrough());
-  m.impl("stride.int", torch::CppFunction::makeFallthrough());
-  m.impl("stride.Dimname", torch::CppFunction::makeFallthrough());
-  m.impl("size.int", torch::CppFunction::makeFallthrough());
-  m.impl("size.Dimname", torch::CppFunction::makeFallthrough());
-  m.impl("is_complex", torch::CppFunction::makeFallthrough());
-  m.impl("view_as_real", torch::CppFunction::makeFallthrough());
-  m.impl("imag", torch::CppFunction::makeFallthrough());
-  m.impl("real", torch::CppFunction::makeFallthrough());
-  m.impl("view", torch::CppFunction::makeFallthrough());
-  m.impl("_unsafe_view", torch::CppFunction::makeFallthrough());
-  m.impl("reshape", torch::CppFunction::makeFallthrough());
+  m.impl("resolve_neg", torch::CppFunction::makeFallthrough());
+
+  // linear algebra functions
   m.impl("dot", torch::CppFunction::makeFallthrough());
   m.impl("vdot", torch::CppFunction::makeFallthrough());
   m.impl("dot.out", torch::CppFunction::makeFallthrough());
   m.impl("vdot.out", torch::CppFunction::makeFallthrough());
-  m.impl("alias", torch::CppFunction::makeFallthrough());
   m.impl("mm", torch::CppFunction::makeFallthrough());
   m.impl("mm.out", torch::CppFunction::makeFallthrough());
   m.impl("addmm", torch::CppFunction::makeFallthrough());
@@ -71,6 +55,9 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
   m.impl("baddbmm", torch::CppFunction::makeFallthrough());
   m.impl("baddbmm_", torch::CppFunction::makeFallthrough());
   m.impl("baddbmm.out", torch::CppFunction::makeFallthrough());
+
+  TORCH_VIEW_FNS(m)
+  TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
 }
 
 } // namespace at
diff --git a/aten/src/ATen/native/MathBitFallThroughLists.h b/aten/src/ATen/native/MathBitFallThroughLists.h
new file mode 100644 (file)
index 0000000..052e730
--- /dev/null
@@ -0,0 +1,68 @@
+#pragma once
+
+namespace at {
+// view ops and their in-place versions
+#define TORCH_VIEW_FNS(m) \
+  m.impl("as_strided", torch::CppFunction::makeFallthrough()); \
+  m.impl("as_strided_", torch::CppFunction::makeFallthrough()); \
+  m.impl("detach", torch::CppFunction::makeFallthrough()); \
+  m.impl("detach_", torch::CppFunction::makeFallthrough()); \
+  m.impl("diagonal", torch::CppFunction::makeFallthrough()); \
+  m.impl("expand", torch::CppFunction::makeFallthrough()); \
+  m.impl("expand_as", torch::CppFunction::makeFallthrough()); \
+  m.impl("movedim.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("movedim.intlist", torch::CppFunction::makeFallthrough()); \
+  m.impl("narrow", torch::CppFunction::makeFallthrough()); \
+  m.impl("permute", torch::CppFunction::makeFallthrough()); \
+  m.impl("select.Dimname", torch::CppFunction::makeFallthrough()); \
+  m.impl("select.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("squeeze", torch::CppFunction::makeFallthrough()); \
+  m.impl("squeeze_", torch::CppFunction::makeFallthrough()); \
+  m.impl("transpose.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("transpose.Dimname", torch::CppFunction::makeFallthrough()); \
+  m.impl("transpose_", torch::CppFunction::makeFallthrough()); \
+  m.impl("t", torch::CppFunction::makeFallthrough()); \
+  m.impl("t_", torch::CppFunction::makeFallthrough()); \
+  m.impl("real", torch::CppFunction::makeFallthrough()); \
+  m.impl("imag", torch::CppFunction::makeFallthrough()); \
+  m.impl("view_as_real", torch::CppFunction::makeFallthrough()); \
+  m.impl("unflatten.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("unflatten.Dimname", torch::CppFunction::makeFallthrough()); \
+  m.impl("unfold", torch::CppFunction::makeFallthrough()); \
+  m.impl("unsqueeze", torch::CppFunction::makeFallthrough()); \
+  m.impl("unsqueeze_", torch::CppFunction::makeFallthrough()); \
+  m.impl("view", torch::CppFunction::makeFallthrough()); \
+  m.impl("view_as", torch::CppFunction::makeFallthrough()); \
+  m.impl("unbind.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("unbind.Dimname", torch::CppFunction::makeFallthrough()); \
+  m.impl("split.Tensor", torch::CppFunction::makeFallthrough()); \
+  m.impl("split_with_sizes", torch::CppFunction::makeFallthrough()); \
+  m.impl("swapaxes", torch::CppFunction::makeFallthrough()); \
+  m.impl("swapdims", torch::CppFunction::makeFallthrough()); \
+  m.impl("chunk", torch::CppFunction::makeFallthrough()); \
+  m.impl("reshape", torch::CppFunction::makeFallthrough()); \
+  m.impl("alias", torch::CppFunction::makeFallthrough()); \
+  m.impl("hsplit.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("hsplit.array", torch::CppFunction::makeFallthrough()); \
+  m.impl("dsplit.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("dsplit.array", torch::CppFunction::makeFallthrough()); \
+  m.impl("vsplit.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("vsplit.array", torch::CppFunction::makeFallthrough()); \
+  m.impl("conj", torch::CppFunction::makeFallthrough()); \
+  m.impl("_conj", torch::CppFunction::makeFallthrough()); \
+  m.impl("_unsafe_view", torch::CppFunction::makeFallthrough());
+
+#define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \
+  m.impl("empty_like", torch::CppFunction::makeFallthrough()); \
+  m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \
+  m.impl("empty.out", torch::CppFunction::makeFallthrough()); \
+  m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \
+  m.impl("full_like", torch::CppFunction::makeFallthrough()); \
+  m.impl("stride.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("stride.Dimname", torch::CppFunction::makeFallthrough()); \
+  m.impl("size.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("size.Dimname", torch::CppFunction::makeFallthrough()); \
+  m.impl("is_complex", torch::CppFunction::makeFallthrough()); \
+  m.impl("is_floating_point", torch::CppFunction::makeFallthrough()); \
+  m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
+}
index d8381f5..8920a2a 100644 (file)
@@ -1,4 +1,5 @@
 #include <ATen/native/MathBitsFallback.h>
+#include <ATen/native/MathBitFallThroughLists.h>
 
 namespace at {
 
@@ -28,34 +29,17 @@ TORCH_LIBRARY_IMPL(_, Negative, m) {
 }
 
 TORCH_LIBRARY_IMPL(aten, Negative, m) {
-  m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
   m.impl("set_.source_Storage_storage_offset", torch::CppFunction::makeFallthrough());
   m.impl("set_.source_Tensor", torch::CppFunction::makeFallthrough());
   m.impl("set_", torch::CppFunction::makeFallthrough());
   m.impl("copy_", torch::CppFunction::makeFallthrough());
   m.impl("clone", torch::CppFunction::makeFallthrough());
-  m.impl("conj", torch::CppFunction::makeFallthrough());
-  m.impl("_conj", torch::CppFunction::makeFallthrough());
   m.impl("neg_", torch::CppFunction::makeFallthrough());
   m.impl("resolve_neg", torch::CppFunction::makeFallthrough());
-  m.impl("empty_like", torch::CppFunction::makeFallthrough());
-  m.impl("empty.memory_format", torch::CppFunction::makeFallthrough());
-  m.impl("empty.out", torch::CppFunction::makeFallthrough());
-  m.impl("empty_strided", torch::CppFunction::makeFallthrough());
-  m.impl("full_like", torch::CppFunction::makeFallthrough());
-  m.impl("stride.int", torch::CppFunction::makeFallthrough());
-  m.impl("stride.Dimname", torch::CppFunction::makeFallthrough());
-  m.impl("size.int", torch::CppFunction::makeFallthrough());
-  m.impl("size.Dimname", torch::CppFunction::makeFallthrough());
-  m.impl("is_complex", torch::CppFunction::makeFallthrough());
-  m.impl("is_floating_point", torch::CppFunction::makeFallthrough());
-  m.impl("view_as_real", torch::CppFunction::makeFallthrough());
-  m.impl("imag", torch::CppFunction::makeFallthrough());
-  m.impl("real", torch::CppFunction::makeFallthrough());
-  m.impl("view", torch::CppFunction::makeFallthrough());
-  m.impl("_unsafe_view", torch::CppFunction::makeFallthrough());
-  m.impl("reshape", torch::CppFunction::makeFallthrough());
-  m.impl("alias", torch::CppFunction::makeFallthrough());
+  m.impl("resolve_conj", torch::CppFunction::makeFallthrough());
+
+  TORCH_VIEW_FNS(m)
+  TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
 }
 
 } // namespace at
index b7e5963..645e7fc 100644 (file)
@@ -442,6 +442,10 @@ Tensor& conj_physical_(Tensor& self) {
 // else returns a new negated tensor with neg bit set to 0
 Tensor resolve_neg(const Tensor& self) {
   if (!self.is_neg()) { return self; }
+  // currently a tensor should never have both the conj and neg bits set
+  // the only way to get an imag bit is complex_tensor.conj().imag, but there's
+  // no intentionally designed mechanism to enter the complex world with this imag bit
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!self.is_conj());
   // negation is materialized in `copy_()` that clone ultimately calls into
   return self.clone();
 }
@@ -450,6 +454,10 @@ Tensor resolve_neg(const Tensor& self) {
 // else returns a new conjugated tensor with conj bit set to 0
 Tensor resolve_conj(const Tensor& self) {
   if (!self.is_conj()) { return self; }
+  // currently a tensor should never have both the conj and neg bits set
+  // the only way to get an imag bit is complex_tensor.conj().imag, but there's
+  // no intentionally designed mechanism to enter the complex world with this imag bit
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!self.is_neg());
   // conjugation is materialized in `copy_()` that clone ultimately calls into
   return self.clone();
 }