Added numpy conversion (#18505)
authorIurii Zdebskyi <iuriiz@fb.com>
Wed, 3 Apr 2019 14:22:38 +0000 (07:22 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 14:28:24 +0000 (07:28 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18505
ghimport-source-id: f3c9b9251e5793f9e192f587194ddfebb45facc1

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18505 [WIP]Added numpy conversion**
* #18166 Bool Tensor for CUDA

Differential Revision: D14646403

fbshipit-source-id: 79d39d692c778ce1981c1d35b1c33e3d93111041

c10/core/ScalarType.h
test/common_utils.py
test/test_torch.py
tools/autograd/templates/python_variable_methods.cpp
torch/csrc/utils/tensor_numpy.cpp

index 65ce4e8..2d852d1 100644 (file)
@@ -25,7 +25,7 @@ _(double,Double,d) /* 7 */ \
 _(at::ComplexHalf,ComplexHalf,z)        /* 8 */ \
 _(std::complex<float>,ComplexFloat,z)   /* 9 */ \
 _(std::complex<double>,ComplexDouble,z) /* 10 */ \
-_(bool,Bool,i) /* 11 */
+_(bool,Bool,i)     /* 11 */
 
 // If you want to support ComplexHalf for real, replace occurrences
 // of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX.  But
@@ -193,19 +193,25 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
   if (isComplexType(a) || isComplexType(b)) {
     AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
   }
+
+  // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX so that's why we have to add
+  // undefined as we are not sure what is the corrent values for the type promotions in complex type cases.
   static constexpr ScalarType _promoteTypesLookup
       [static_cast<int>(ScalarType::NumOptions)]
       [static_cast<int>(ScalarType::NumOptions)] = {
-            /* u1  i1  i2  i4  i8  f2  f4  f8  b1 */
-    /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, u1 },
-    /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, i1 },
-    /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, i2 },
-    /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, i4 },
-    /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, i8 },
-    /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, f2 },
-    /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, f4 },
-    /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, f8 },
-    /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, b1 },
+            /* u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1 */
+    /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, u1 },
+    /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, i1 },
+    /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, i2 },
+    /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, ud, ud, ud, i4 },
+    /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, ud, ud, ud, i8 },
+    /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, ud, ud, ud, f2 },
+    /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, ud, ud, ud, f4 },
+    /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, ud, ud, ud, f8 },
+    /* c2 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
+    /* c4 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
+    /* c8 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
+    /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, b1 },
   };
   return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
 }
index cca6663..6fb0d00 100644 (file)
@@ -414,6 +414,10 @@ class TestCase(expecttest.TestCase):
             self.assertEqual(x.item(), y, prec, message, allow_inf)
         elif isinstance(y, torch.Tensor) and isinstance(x, Number):
             self.assertEqual(x, y.item(), prec, message, allow_inf)
+        elif isinstance(x, torch.Tensor) and isinstance(y, numpy.bool_):
+            self.assertEqual(x.item(), y, prec, message, allow_inf)
+        elif isinstance(y, torch.Tensor) and isinstance(x, numpy.bool_):
+            self.assertEqual(x, y.item(), prec, message, allow_inf)
         elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
             def assertTensorsEqual(a, b):
                 super(TestCase, self).assertEqual(a.size(), b.size(), message)
index d2e8f38..97169ad 100644 (file)
@@ -10001,6 +10001,23 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
                 y[0][1] = 3
                 self.assertTrue(x[0][1] == 3)
 
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_to_numpy_bool(self):
+        x = torch.tensor([True, False], dtype=torch.bool)
+        self.assertEqual(x.dtype, torch.bool)
+
+        y = x.numpy()
+        self.assertEqual(y.dtype, np.bool)
+        for i in range(len(x)):
+            self.assertEqual(x[i], y[i])
+
+        x = torch.tensor([True], dtype=torch.bool)
+        self.assertEqual(x.dtype, torch.bool)
+
+        y = x.numpy()
+        self.assertEqual(y.dtype, np.bool)
+        self.assertEqual(x[0], y[0])
+
     def test_dlpack_conversion(self):
         x = torch.randn(1, 2, 3, 4).type('torch.FloatTensor')
         z = from_dlpack(to_dlpack(x))
@@ -10024,6 +10041,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             np.int8,
             np.uint8,
             np.longlong,
+            np.bool,
         ]
         for dtype in dtypes:
             array = np.array([1, 2, 3, 4], dtype=dtype)
@@ -10075,6 +10093,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             np.int16,
             np.int8,
             np.uint8,
+            np.bool,
         ]
 
         incorrect_byteorder = '>' if sys.byteorder == 'little' else '<'
@@ -10120,7 +10139,8 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             np.int64,
             np.int32,
             np.int16,
-            np.uint8
+            np.uint8,
+            np.bool,
         ]
         for dtype in dtypes:
             self.assertEqual(dtype(42), torch.tensor(dtype(42)).item())
index 9b2111f..0f043a6 100644 (file)
@@ -219,6 +219,15 @@ static int64_t dispatch_to_CLong(const Tensor & self) {
   return self.item<int64_t>();
 }
 
+static bool dispatch_to_Bool(const Tensor & self) {
+  AutoNoGIL no_gil;
+  OptionalDeviceGuard device_guard(device_of(self));
+  if (self.numel() != 1) {
+    throw ValueError("only one element tensors can be converted to Python scalars");
+  }
+  return self.item<bool>();
+}
+
 static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
   jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
@@ -439,6 +448,8 @@ static PyObject * THPVariable_item(PyObject* self, PyObject* args)
     return wrap(dispatch_to_CDouble(self_));
   } else if (self_.is_complex()) {
     return wrap(dispatch_to_CComplexDouble(self_));
+  } else if (self_.scalar_type() == ScalarType::Bool) {
+    return wrap(dispatch_to_Bool(self_));
   } else {
     return wrap(dispatch_to_CLong(self_));
   }
index fa0cb54..cf41742 100644 (file)
@@ -156,6 +156,7 @@ static int aten_to_dtype(const ScalarType scalar_type) {
     case kShort: return NPY_INT16;
     case kChar: return NPY_INT8;
     case kByte: return NPY_UINT8;
+    case kBool: return NPY_BOOL;
     default:
       throw ValueError("Got unsupported ScalarType ", toString(scalar_type));
   }
@@ -170,6 +171,7 @@ ScalarType numpy_dtype_to_aten(int dtype) {
     case NPY_INT16: return kShort;
     case NPY_INT8: return kChar;
     case NPY_UINT8: return kByte;
+    case NPY_BOOL: return kBool;
     default:
       // Workaround: MSVC does not support two switch cases that have the same value
       if (dtype == NPY_LONGLONG || dtype == NPY_INT64) {