build_binary_op(maybe_get_output(), self, other);
}
-void comparison_op_check(const Tensor& self, const Tensor& other) {
+void comparison_op_check(const Tensor& self, const Tensor& other, const Tensor& result) {
// Validate that is possible to convert zero-dim tensor's dtype to other dtype
// without overflow
if (self.scalar_type() != other.scalar_type()) {
native::check_convert(self.item(), other.scalar_type());
}
}
+ // In-place operation To avoid overflow during type promotion we will check that
+ // both dtypes of self and other are same
+ if (result.is_same(self)) {
+ TORCH_CHECK(self.dtype() == other.dtype(),
+ "Expected object of scalar type ", self.dtype(), " but got scalar type ",
+ other.dtype(), " for argument 'other'");
+ }
}
#define CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(func) \
TORCH_META_FUNC2(func, Tensor)(const Tensor& self, const Tensor& other) { \
- comparison_op_check(self, other); \
- build_comparison_op(maybe_get_output(), self, other); \
+ const Tensor& result = maybe_get_output(); \
+ comparison_op_check(self, other, result); \
+ build_comparison_op(result, self, other); \
} \
\
TORCH_META_FUNC2(func, Scalar)(const Tensor& self, const Scalar& other) { \
# TODO: update to work on CUDA, too
@onlyCPU
def test_inplace_comparison_ops_require_inputs_have_same_dtype(self, device):
- with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'):
- for op in ['lt_', 'le_', 'gt_', 'ge_', 'eq_', 'ne_', 'logical_xor_', 'logical_and_', 'logical_or_']:
+ for op in ['lt_', 'le_', 'gt_', 'ge_', 'eq_', 'ne_', 'logical_xor_', 'logical_and_', 'logical_or_']:
+ with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'):
x = torch.tensor([1], dtype=torch.int)
y = torch.tensor([2], dtype=torch.long)
in_place_method = getattr(x, op)