_fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
+class F16(ctypes.Structure):
+ """A ctype representation for MLIR's Float16."""
+ _fields_ = [("f16", ctypes.c_int16)]
+
+
def as_ctype(dtp):
"""Converts dtype to ctype."""
if dtp is np.dtype(np.complex128):
return C128
if dtp is np.dtype(np.complex64):
return C64
+ if dtp is np.dtype(np.float16):
+ return F16
return np.ctypeslib.as_ctypes_type(dtp)
+def to_numpy(array):
+ """Converts ctypes array back to numpy dtype array."""
+ if array.dtype == C128:
+ return array.view("complex128")
+ if array.dtype == C64:
+ return array.view("complex64")
+ if array.dtype == F16:
+ return array.view("float16")
+ return array
+
+
def make_nd_memref_descriptor(rank, dtype):
class MemRefDescriptor(ctypes.Structure):
np.ctypeslib.as_array(val[0].shape),
np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
)
- if strided_arr.dtype == C128:
- return strided_arr.view("complex128")
- if strided_arr.dtype == C64:
- return strided_arr.view("complex64")
- return strided_arr
+ return to_numpy(strided_arr)
def ranked_memref_to_numpy(ranked_memref):
np.ctypeslib.as_array(ranked_memref[0].shape),
np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
)
- if strided_arr.dtype == C128:
- return strided_arr.view("complex128")
- if strided_arr.dtype == C64:
- return strided_arr.view("complex64")
- return strided_arr
+ return to_numpy(strided_arr)
run(testMemrefAdd)
+# Test addition of two f16 memrefs
+# CHECK-LABEL: TEST: testF16MemrefAdd
+def testF16MemrefAdd():
+ with Context():
+ module = Module.parse("""
+ module {
+ func.func @main(%arg0: memref<1xf16>,
+ %arg1: memref<1xf16>,
+ %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
+ %0 = arith.constant 0 : index
+ %1 = memref.load %arg0[%0] : memref<1xf16>
+ %2 = memref.load %arg1[%0] : memref<1xf16>
+ %3 = arith.addf %1, %2 : f16
+ memref.store %3, %arg2[%0] : memref<1xf16>
+ return
+ }
+ } """)
+
+ arg1 = np.array([11.]).astype(np.float16)
+ arg2 = np.array([22.]).astype(np.float16)
+ arg3 = np.array([0.]).astype(np.float16)
+
+ arg1_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg1)))
+ arg2_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg2)))
+ arg3_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg3)))
+
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr,
+ arg3_memref_ptr)
+ # CHECK: [11.] + [22.] = [33.]
+ log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+ # test to-numpy utility
+ # CHECK: [33.]
+ npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
+ log(npout)
+
+
+run(testF16MemrefAdd)
+
+
# Test addition of two complex memrefs
# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
ctypes.pointer(get_ranked_memref_descriptor(arg0)))
if sys.platform == 'win32':
- shared_libs = [
- "../../../../bin/mlir_runner_utils.dll",
- "../../../../bin/mlir_c_runner_utils.dll"
- ]
+ shared_libs = [
+ "../../../../bin/mlir_runner_utils.dll",
+ "../../../../bin/mlir_c_runner_utils.dll"
+ ]
else:
- shared_libs = [
- "../../../../lib/libmlir_runner_utils.so",
- "../../../../lib/libmlir_c_runner_utils.so"
- ]
+ shared_libs = [
+ "../../../../lib/libmlir_runner_utils.so",
+ "../../../../lib/libmlir_c_runner_utils.so"
+ ]
execution_engine = ExecutionEngine(
lowerToLLVM(module),
}""")
if sys.platform == 'win32':
- shared_libs = [
- "../../../../bin/mlir_runner_utils.dll",
- "../../../../bin/mlir_c_runner_utils.dll"
- ]
+ shared_libs = [
+ "../../../../bin/mlir_runner_utils.dll",
+ "../../../../bin/mlir_c_runner_utils.dll"
+ ]
else:
- shared_libs = [
- "../../../../lib/libmlir_runner_utils.so",
- "../../../../lib/libmlir_c_runner_utils.so"
- ]
+ shared_libs = [
+ "../../../../lib/libmlir_runner_utils.so",
+ "../../../../lib/libmlir_c_runner_utils.so"
+ ]
execution_engine = ExecutionEngine(
lowerToLLVM(module),