From 58c8b253cdd5ccd4e61d5854a8a614b35498276f Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Thu, 1 Dec 2022 16:27:33 +0000 Subject: [PATCH] Replacing `is` with `==` for the dtype check. >>> a = np.ndarray([1,1]).astype(np.half) >>> a array([[0.007812]], dtype=float16) >>> a.dtype dtype('float16') >>> a.dtype == np.half True >>> a.dtype == np.float16 True >>> a.dtype is np.float16 False Checking with `is` leads to inconsistency in checking. Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D139121 --- mlir/python/mlir/runtime/np_to_memref.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index 5b3c3c4a..d709679 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -23,13 +23,14 @@ class F16(ctypes.Structure): _fields_ = [("f16", ctypes.c_int16)] +# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): """Converts dtype to ctype.""" - if dtp is np.dtype(np.complex128): + if dtp == np.dtype(np.complex128): return C128 - if dtp is np.dtype(np.complex64): + if dtp == np.dtype(np.complex64): return C64 - if dtp is np.dtype(np.float16): + if dtp == np.dtype(np.float16): return F16 return np.ctypeslib.as_ctypes_type(dtp) -- 2.7.4