Fix reshape for the Lazy key (#62846)
authorAlex Suhan <asuhan@fb.com>
Fri, 6 Aug 2021 22:28:38 +0000 (15:28 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 6 Aug 2021 22:29:56 +0000 (15:29 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62846

Test Plan: CI

Reviewed By: zou3519

Differential Revision: D30162185

Pulled By: asuhan

fbshipit-source-id: d582dcef35ce7e8bebf161a5c93e470339891e29

aten/src/ATen/native/TensorShape.cpp
aten/src/ATen/templates/TensorBody.h
c10/core/TensorImpl.h

index 7532859..e915078 100644 (file)
@@ -1058,7 +1058,7 @@ Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
     //
     // We need to do the checks here instead of in `native_functions.yaml`
     // to preserve backwards compatibility.
-    if (! self.is_xla()) {
+    if (!self.is_xla() && !self.is_lazy()) {
       return self._reshape_alias(shape, stride.value());
     } else {
       return self.view(shape);
index 800f17b..be14980 100644 (file)
@@ -451,6 +451,11 @@ class TORCH_API Tensor {
     return impl_->is_xla();
   }
 
+  /// Returns if a `Tensor` has Lazy backend.
+  bool is_lazy() const {
+    return impl_->is_lazy();
+  }
+
   /// Returns if a `Tensor` has HIP backend.
   bool is_hip() const {
     // NB: this is not a native function to avoid dispatching overhead.
index 5b5fa3a..65d7af3 100644 (file)
@@ -840,6 +840,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return key_set_.has(DispatchKey::XLA);
   }
 
+  bool is_lazy() const {
+    return key_set_.has(DispatchKey::Lazy);
+  }
+
   bool is_hip() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.