restore test_inplace_comparison_ops_require_inputs_have_same_dtype Expected behavior...
authorFreey0 <freey7955@gmail.com>
Wed, 8 Sep 2021 13:40:54 +0000 (06:40 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 13:42:12 +0000 (06:42 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64267

This test expects every operation to throw a runtime error.

And Reinsert in-place operation test,Fix bug for comparison operation

fix: #64018

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D30720915

Pulled By: ezyang

fbshipit-source-id: 215a6556d20770f70f4ced1c1f9a9753933f1d37

aten/src/ATen/native/BinaryOps.cpp
test/test_binary_ufuncs.py

index 4eeca76..5ffaba5 100644 (file)
@@ -194,7 +194,7 @@ TORCH_META_FUNC(fmin) (const Tensor& self, const Tensor& other) {
     build_binary_op(maybe_get_output(), self, other);
 }
 
-void comparison_op_check(const Tensor& self, const Tensor& other) {
+void comparison_op_check(const Tensor& self, const Tensor& other, const Tensor& result) {
   // Validate that is possible to convert zero-dim tensor's dtype to other dtype
   // without overflow
   if (self.scalar_type() != other.scalar_type()) {
@@ -204,12 +204,20 @@ void comparison_op_check(const Tensor& self, const Tensor& other) {
       native::check_convert(self.item(), other.scalar_type());
     }
   }
+  // In-place operation To avoid overflow during type promotion we will check that
+  // both dtypes of self and other are same
+  if (result.is_same(self)) {
+    TORCH_CHECK(self.dtype() == other.dtype(),
+                "Expected object of scalar type ", self.dtype(), " but got scalar type ",
+                other.dtype(), " for argument 'other'");
+  }
 }
 
 #define CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(func)                     \
   TORCH_META_FUNC2(func, Tensor)(const Tensor& self, const Tensor& other) { \
-    comparison_op_check(self, other);                                       \
-    build_comparison_op(maybe_get_output(), self, other);                   \
+    const Tensor& result = maybe_get_output();                              \
+    comparison_op_check(self, other, result);                               \
+    build_comparison_op(result, self, other);                               \
   }                                                                         \
                                                                             \
   TORCH_META_FUNC2(func, Scalar)(const Tensor& self, const Scalar& other) { \
index 7153902..4610efb 100644 (file)
@@ -228,8 +228,8 @@ class TestBinaryUfuncs(TestCase):
     # TODO: update to work on CUDA, too
     @onlyCPU
     def test_inplace_comparison_ops_require_inputs_have_same_dtype(self, device):
-        with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'):
-            for op in ['lt_', 'le_', 'gt_', 'ge_', 'eq_', 'ne_', 'logical_xor_', 'logical_and_', 'logical_or_']:
+        for op in ['lt_', 'le_', 'gt_', 'ge_', 'eq_', 'ne_', 'logical_xor_', 'logical_and_', 'logical_or_']:
+            with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'):
                 x = torch.tensor([1], dtype=torch.int)
                 y = torch.tensor([2], dtype=torch.long)
                 in_place_method = getattr(x, op)