Fixed bool Tensor value change bug (#19096)
authoriurii zdebskyi <47012416+izdeby@users.noreply.github.com>
Wed, 10 Apr 2019 18:05:54 +0000 (11:05 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 10 Apr 2019 18:09:07 +0000 (11:09 -0700)
Summary:
Fixes #19077
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19096

Differential Revision: D14871044

Pulled By: izdeby

fbshipit-source-id: 61b12559c8c5b9613e00ba5933f478321ea80469

test/test_torch.py
torch/csrc/autograd/python_variable_indexing.cpp

index 5adce86..ebfd34e 100644 (file)
@@ -3014,6 +3014,13 @@ class _TestTorchMixin(object):
         self.assertTrue(x.is_cuda)
         torch.set_default_tensor_type(saved_type)
 
+    def test_bool_tensor_value_change(self):
+        for device in torch.testing.get_all_device_types():
+            x = torch.tensor([True, False], dtype=torch.bool)
+            x[0] = False
+            x[1] = True
+            self.assertEqual(x, torch.tensor([False, True], dtype=torch.bool))
+
     def test_unfold_all_devices_and_dtypes(self):
         for device in torch.testing.get_all_device_types():
             for dt in torch.testing.get_all_dtypes():
index 80461c4..eff82a9 100644 (file)
@@ -114,7 +114,7 @@ static Variable valueToTensor(const at::Type & type, PyObject* value) {
   if (THPVariable_Check(value)) {
     return reinterpret_cast<THPVariable*>(value)->cdata;
   }
-  if (THPUtils_checkLong(value)) {
+  if (THPUtils_checkLong(value) || PyBool_Check(value)) {
     return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), type.options());
   }
   if (PyFloat_Check(value)) {