From 2223737da91454336ccf880b75718db7369cdc8f Mon Sep 17 00:00:00 2001 From: Freey0 Date: Wed, 8 Sep 2021 06:40:54 -0700 Subject: [PATCH] restore test_inplace_comparison_ops_require_inputs_have_same_dtype Expected behavior (#64267) MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64267 This test expects every operation to throw a runtime error. And Reinsert in-place operation test,Fix bug for comparison operation fix: #64018 Test Plan: Imported from OSS Reviewed By: gchanan Differential Revision: D30720915 Pulled By: ezyang fbshipit-source-id: 215a6556d20770f70f4ced1c1f9a9753933f1d37 --- aten/src/ATen/native/BinaryOps.cpp | 14 +++++++++++--- test/test_binary_ufuncs.py | 4 ++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index 4eeca76..5ffaba5 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -194,7 +194,7 @@ TORCH_META_FUNC(fmin) (const Tensor& self, const Tensor& other) { build_binary_op(maybe_get_output(), self, other); } -void comparison_op_check(const Tensor& self, const Tensor& other) { +void comparison_op_check(const Tensor& self, const Tensor& other, const Tensor& result) { // Validate that is possible to convert zero-dim tensor's dtype to other dtype // without overflow if (self.scalar_type() != other.scalar_type()) { @@ -204,12 +204,20 @@ void comparison_op_check(const Tensor& self, const Tensor& other) { native::check_convert(self.item(), other.scalar_type()); } } + // In-place operation To avoid overflow during type promotion we will check that + // both dtypes of self and other are same + if (result.is_same(self)) { + TORCH_CHECK(self.dtype() == other.dtype(), + "Expected object of scalar type ", self.dtype(), " but got scalar type ", + other.dtype(), " for argument 'other'"); + } } #define CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(func) \ TORCH_META_FUNC2(func, Tensor)(const Tensor& self, const Tensor& other) { \ - comparison_op_check(self, other); \ - build_comparison_op(maybe_get_output(), self, other); \ + const Tensor& result = maybe_get_output(); \ + comparison_op_check(self, other, result); \ + build_comparison_op(result, self, other); \ } \ \ TORCH_META_FUNC2(func, Scalar)(const Tensor& self, const Scalar& other) { \ diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 7153902..4610efb 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -228,8 +228,8 @@ class TestBinaryUfuncs(TestCase): # TODO: update to work on CUDA, too @onlyCPU def test_inplace_comparison_ops_require_inputs_have_same_dtype(self, device): - with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'): - for op in ['lt_', 'le_', 'gt_', 'ge_', 'eq_', 'ne_', 'logical_xor_', 'logical_and_', 'logical_or_']: + for op in ['lt_', 'le_', 'gt_', 'ge_', 'eq_', 'ne_', 'logical_xor_', 'logical_and_', 'logical_or_']: + with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'): x = torch.tensor([1], dtype=torch.int) y = torch.tensor([2], dtype=torch.long) in_place_method = getattr(x, op) -- 2.7.4