Ensure there aren't variables in checked_tensor_unwrap, checked_tenso… (#15105)
authorGregory Chanan <gchanan@fb.com>
Wed, 12 Dec 2018 17:55:42 +0000 (09:55 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 12 Dec 2018 17:58:03 +0000 (09:58 -0800)
Summary:
…r_list_unwrap.

These functions use unsafeGetTensorImpl(), which doesn't work with Variables (in a silent way that may blow up later).
So let's do early checking.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15105

Reviewed By: ezyang

Differential Revision: D13429149

Pulled By: gchanan

fbshipit-source-id: b85f6f5b7cdb9a6dd0c40205b924c840a3920ba0

aten/src/ATen/Utils.h

index e2d1a6d..67201e1 100644 (file)
@@ -73,6 +73,9 @@ static inline TensorImpl* checked_tensor_unwrap(const Tensor& expr, const char *
     AT_ERROR("Expected object of scalar type ", scalar_type, " but got scalar type ", expr.scalar_type(),
              " for argument #", pos, " '", name, "'");
   }
+  if (expr.is_variable()) {
+    AT_ERROR("Expected Tensor (not Variable) for argument #", pos, " '", name, "'");
+  }
   return expr.unsafeGetTensorImpl();
 }
 
@@ -88,7 +91,11 @@ static inline std::vector<TensorImpl*> checked_tensor_list_unwrap(ArrayRef<Tenso
     }
     if (expr.scalar_type() != scalar_type) {
       AT_ERROR("Expected object of scalar type ", scalar_type, " but got scalar type ", expr.scalar_type(),
-               " for sequence elment ", i , " in sequence argument at position #", pos, " '", name, "'");
+               " for sequence element ", i , " in sequence argument at position #", pos, " '", name, "'");
+    }
+    if (expr.is_variable()) {
+      AT_ERROR("Expected Tensor (not Variable) for sequence element ",
+               i , " in sequence argument at position #", pos, " '", name, "'");
     }
     unwrapped.emplace_back(expr.unsafeGetTensorImpl());
   }