From 68ad9ae5bebd9efab127fa99e2bafd6852bbd8ed Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Wed, 12 Dec 2018 09:55:42 -0800 Subject: [PATCH] =?utf8?q?Ensure=20there=20aren't=20variables=20in=20check?= =?utf8?q?ed=5Ftensor=5Funwrap,=20checked=5Ftenso=E2=80=A6=20(#15105)?= MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Summary: …r_list_unwrap. These functions use unsafeGetTensorImpl(), which doesn't work with Variables (in a silent way that may blow up later). So let's do early checking. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15105 Reviewed By: ezyang Differential Revision: D13429149 Pulled By: gchanan fbshipit-source-id: b85f6f5b7cdb9a6dd0c40205b924c840a3920ba0 --- aten/src/ATen/Utils.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index e2d1a6d..67201e1 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -73,6 +73,9 @@ static inline TensorImpl* checked_tensor_unwrap(const Tensor& expr, const char * AT_ERROR("Expected object of scalar type ", scalar_type, " but got scalar type ", expr.scalar_type(), " for argument #", pos, " '", name, "'"); } + if (expr.is_variable()) { + AT_ERROR("Expected Tensor (not Variable) for argument #", pos, " '", name, "'"); + } return expr.unsafeGetTensorImpl(); } @@ -88,7 +91,11 @@ static inline std::vector checked_tensor_list_unwrap(ArrayRef