[mlir][python][f16] add ctype python binding support for f16
authorAart Bik <ajcbik@google.com>
Thu, 2 Jun 2022 22:11:02 +0000 (15:11 -0700)
committerAart Bik <ajcbik@google.com>
Fri, 3 Jun 2022 00:21:24 +0000 (17:21 -0700)
Similar to complex128/complex64, float16 has no direct support
in the ctypes implementation. This fixes the issue by using a
custom F16 type to change the view in and out of MLIR code

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D126928

mlir/python/mlir/runtime/np_to_memref.py
mlir/test/python/execution_engine.py

index de5b8d6..5b3c3c4 100644 (file)
@@ -18,15 +18,33 @@ class C64(ctypes.Structure):
   _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
 
 
+class F16(ctypes.Structure):
+  """A ctype representation for MLIR's Float16."""
+  _fields_ = [("f16", ctypes.c_int16)]
+
+
 def as_ctype(dtp):
   """Converts dtype to ctype."""
   if dtp is np.dtype(np.complex128):
     return C128
   if dtp is np.dtype(np.complex64):
     return C64
+  if dtp is np.dtype(np.float16):
+    return F16
   return np.ctypeslib.as_ctypes_type(dtp)
 
 
+def to_numpy(array):
+  """Converts ctypes array back to numpy dtype array."""
+  if array.dtype == C128:
+    return array.view("complex128")
+  if array.dtype == C64:
+    return array.view("complex64")
+  if array.dtype == F16:
+    return array.view("float16")
+  return array
+
+
 def make_nd_memref_descriptor(rank, dtype):
 
   class MemRefDescriptor(ctypes.Structure):
@@ -105,11 +123,7 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype):
       np.ctypeslib.as_array(val[0].shape),
       np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
   )
-  if strided_arr.dtype == C128:
-    return strided_arr.view("complex128")
-  if strided_arr.dtype == C64:
-    return strided_arr.view("complex64")
-  return strided_arr
+  return to_numpy(strided_arr)
 
 
 def ranked_memref_to_numpy(ranked_memref):
@@ -121,8 +135,4 @@ def ranked_memref_to_numpy(ranked_memref):
       np.ctypeslib.as_array(ranked_memref[0].shape),
       np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
   )
-  if strided_arr.dtype == C128:
-    return strided_arr.view("complex128")
-  if strided_arr.dtype == C64:
-    return strided_arr.view("complex64")
-  return strided_arr
+  return to_numpy(strided_arr)
index 53cbac3..6eed53c 100644 (file)
@@ -266,6 +266,50 @@ def testMemrefAdd():
 run(testMemrefAdd)
 
 
+# Test addition of two f16 memrefs
+# CHECK-LABEL: TEST: testF16MemrefAdd
+def testF16MemrefAdd():
+  with Context():
+    module = Module.parse("""
+    module  {
+      func.func @main(%arg0: memref<1xf16>,
+                      %arg1: memref<1xf16>,
+                      %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
+        %0 = arith.constant 0 : index
+        %1 = memref.load %arg0[%0] : memref<1xf16>
+        %2 = memref.load %arg1[%0] : memref<1xf16>
+        %3 = arith.addf %1, %2 : f16
+        memref.store %3, %arg2[%0] : memref<1xf16>
+        return
+      }
+    } """)
+
+    arg1 = np.array([11.]).astype(np.float16)
+    arg2 = np.array([22.]).astype(np.float16)
+    arg3 = np.array([0.]).astype(np.float16)
+
+    arg1_memref_ptr = ctypes.pointer(
+        ctypes.pointer(get_ranked_memref_descriptor(arg1)))
+    arg2_memref_ptr = ctypes.pointer(
+        ctypes.pointer(get_ranked_memref_descriptor(arg2)))
+    arg3_memref_ptr = ctypes.pointer(
+        ctypes.pointer(get_ranked_memref_descriptor(arg3)))
+
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr,
+                            arg3_memref_ptr)
+    # CHECK: [11.] + [22.] = [33.]
+    log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+    # test to-numpy utility
+    # CHECK: [33.]
+    npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
+    log(npout)
+
+
+run(testF16MemrefAdd)
+
+
 # Test addition of two complex memrefs
 # CHECK-LABEL: TEST: testComplexMemrefAdd
 def testComplexMemrefAdd():
@@ -442,15 +486,15 @@ def testSharedLibLoad():
         ctypes.pointer(get_ranked_memref_descriptor(arg0)))
 
     if sys.platform == 'win32':
-        shared_libs = [
-            "../../../../bin/mlir_runner_utils.dll",
-            "../../../../bin/mlir_c_runner_utils.dll"
-        ]
+      shared_libs = [
+          "../../../../bin/mlir_runner_utils.dll",
+          "../../../../bin/mlir_c_runner_utils.dll"
+      ]
     else:
-        shared_libs = [
-            "../../../../lib/libmlir_runner_utils.so",
-            "../../../../lib/libmlir_c_runner_utils.so"
-        ]
+      shared_libs = [
+          "../../../../lib/libmlir_runner_utils.so",
+          "../../../../lib/libmlir_c_runner_utils.so"
+      ]
 
     execution_engine = ExecutionEngine(
         lowerToLLVM(module),
@@ -484,15 +528,15 @@ def testNanoTime():
     }""")
 
     if sys.platform == 'win32':
-        shared_libs = [
-            "../../../../bin/mlir_runner_utils.dll",
-            "../../../../bin/mlir_c_runner_utils.dll"
-        ]
+      shared_libs = [
+          "../../../../bin/mlir_runner_utils.dll",
+          "../../../../bin/mlir_c_runner_utils.dll"
+      ]
     else:
-        shared_libs = [
-            "../../../../lib/libmlir_runner_utils.so",
-            "../../../../lib/libmlir_c_runner_utils.so"
-        ]
+      shared_libs = [
+          "../../../../lib/libmlir_runner_utils.so",
+          "../../../../lib/libmlir_c_runner_utils.so"
+      ]
 
     execution_engine = ExecutionEngine(
         lowerToLLVM(module),