From 5744d5213d0bcd56b9029a4abaa5148a33ed66b3 Mon Sep 17 00:00:00 2001 From: Dmytro Dzhulgakov Date: Thu, 21 Feb 2019 09:22:12 -0800 Subject: [PATCH] Enforce non-negativity of tensor construction (#17077) Summary: Apparently, before the only way we enforced it was size>=0 in alloc_cpu. So empty((5,-5)) would fail but empty((-5,-5)) would hang :) Please suggest better place to enforce it if any. Pull Request resolved: https://github.com/pytorch/pytorch/pull/17077 Differential Revision: D14077930 Pulled By: dzhulgakov fbshipit-source-id: 1120513300fd5448e06fa15c2d72f9b0ee5734e4 --- aten/src/ATen/native/TensorFactories.cpp | 3 +++ aten/src/ATen/native/TensorFactories.h | 8 ++++++++ aten/src/ATen/native/cuda/TensorFactories.cu | 1 + aten/src/ATen/test/basic.cpp | 8 ++++++++ 4 files changed, 20 insertions(+) diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 18672c0..1e712ad 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -91,6 +91,7 @@ Tensor _dim_arange(const Tensor& like, int64_t dim) { Tensor empty_cpu(IntArrayRef size, const TensorOptions& options) { AT_ASSERT(options.backend() == Backend::CPU); AT_ASSERT(!options.is_variable()); // is_variable should have been 'unpacked' // TODO: remove this when Variable and Tensor are merged + check_size_nonnegative(size); auto* allocator = at::getCPUAllocator(); int64_t nelements = prod_intlist(size); @@ -111,12 +112,14 @@ Tensor empty_cpu(IntArrayRef size, const TensorOptions& options) { } Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, const TensorOptions& options) { + check_size_nonnegative(size); auto t = at::native::empty_cpu({0}, options); at::native::resize_impl_cpu_(t.unsafeGetTensorImpl(), size, stride); return t; } Tensor& empty_out(Tensor& result, IntArrayRef size) { + check_size_nonnegative(size); if (result.is_sparse()) { result.sparse_resize_and_clear_(size, size.size(), 0); } else { diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h index 8f68ab3..c17b4c1 100644 --- a/aten/src/ATen/native/TensorFactories.h +++ b/aten/src/ATen/native/TensorFactories.h @@ -1,5 +1,7 @@ #pragma once +#include + namespace at { namespace native { // Different combinations of row, col, and offset can lead to two cases: // @@ -56,5 +58,11 @@ inline void check_args( options.layout()) } } + +inline void check_size_nonnegative(IntArrayRef size) { + for (auto x: size) { + AT_CHECK(x >= 0, "Trying to create tensor with negative dimension ", x, ": ", size); + } +} } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index 531d58b..58b90bc 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -46,6 +46,7 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) { Tensor empty_cuda(IntArrayRef size, const TensorOptions& options) { AT_ASSERT(options.backend() == at::Backend::CUDA); AT_ASSERT(!options.is_variable()); // is_variable should have been 'unpacked' // TODO: remove this when Variable and Tensor are merged + check_size_nonnegative(size); auto* allocator = at::cuda::getCUDADeviceAllocator(); int64_t nelements = prod_intlist(size); diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp index 83a2340..ebd569a 100644 --- a/aten/src/ATen/test/basic.cpp +++ b/aten/src/ATen/test/basic.cpp @@ -276,6 +276,13 @@ void TestDispatch() { ASSERT_TRUE(result.allclose(mse_loss(relu(tensor), other))); } +void TestNegativeDim(Type& type) { + ASSERT_ANY_THROW(empty({5, -5, 5}, type.options())); + ASSERT_ANY_THROW(empty({5, -5, -5}, type.options())); + Tensor tensor = empty({5, 5}, type.options()); + ASSERT_ANY_THROW(tensor.reshape({-5, -5})); +} + void test(Type& type) { TestResize(type); TestOnesAndDot(type); @@ -302,6 +309,7 @@ void test(Type& type) { TestIndexingByZerodimTensor(); TestIndexingMixedDevice(type); TestDispatch(); + TestNegativeDim(type); } TEST(BasicTest, BasicTestCPU) { -- 2.7.4