import ctypes
+class C128(ctypes.Structure):
+ """A ctype representation for MLIR's Double Complex."""
+ _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]
+
+
+class C64(ctypes.Structure):
+ """A ctype representation for MLIR's Float Complex."""
+ _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
+
+
+def as_ctype(dtp):
+ """Converts dtype to ctype."""
+ if dtp is np.dtype(np.complex128):
+ return C128
+ if dtp is np.dtype(np.complex64):
+ return C64
+ return np.ctypeslib.as_ctypes_type(dtp)
+
+
def make_nd_memref_descriptor(rank, dtype):
- class MemRefDescriptor(ctypes.Structure):
- """
- Build an empty descriptor for the given rank/dtype, where rank>0.
- """
- _fields_ = [
- ("allocated", ctypes.c_longlong),
- ("aligned", ctypes.POINTER(dtype)),
- ("offset", ctypes.c_longlong),
- ("shape", ctypes.c_longlong * rank),
- ("strides", ctypes.c_longlong * rank),
- ]
+ class MemRefDescriptor(ctypes.Structure):
+ """Builds an empty descriptor for the given rank/dtype, where rank>0."""
- return MemRefDescriptor
+ _fields_ = [
+ ("allocated", ctypes.c_longlong),
+ ("aligned", ctypes.POINTER(dtype)),
+ ("offset", ctypes.c_longlong),
+ ("shape", ctypes.c_longlong * rank),
+ ("strides", ctypes.c_longlong * rank),
+ ]
+
+ return MemRefDescriptor
def make_zero_d_memref_descriptor(dtype):
- class MemRefDescriptor(ctypes.Structure):
- """
- Build an empty descriptor for the given dtype, where rank=0.
- """
- _fields_ = [
- ("allocated", ctypes.c_longlong),
- ("aligned", ctypes.POINTER(dtype)),
- ("offset", ctypes.c_longlong),
- ]
+ class MemRefDescriptor(ctypes.Structure):
+ """Builds an empty descriptor for the given dtype, where rank=0."""
- return MemRefDescriptor
+ _fields_ = [
+ ("allocated", ctypes.c_longlong),
+ ("aligned", ctypes.POINTER(dtype)),
+ ("offset", ctypes.c_longlong),
+ ]
+ return MemRefDescriptor
-class UnrankedMemRefDescriptor(ctypes.Structure):
- """ Creates a ctype struct for memref descriptor"""
- _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
+class UnrankedMemRefDescriptor(ctypes.Structure):
+ """Creates a ctype struct for memref descriptor"""
+ _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
def get_ranked_memref_descriptor(nparray):
- """
- Return a ranked memref descriptor for the given numpy array.
- """
- if nparray.ndim == 0:
- x = make_zero_d_memref_descriptor(np.ctypeslib.as_ctypes_type(nparray.dtype))()
- x.allocated = nparray.ctypes.data
- x.aligned = nparray.ctypes.data_as(
- ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype))
- )
- x.offset = ctypes.c_longlong(0)
- return x
-
- x = make_nd_memref_descriptor(
- nparray.ndim, np.ctypeslib.as_ctypes_type(nparray.dtype)
- )()
+ """Returns a ranked memref descriptor for the given numpy array."""
+ ctp = as_ctype(nparray.dtype)
+ if nparray.ndim == 0:
+ x = make_zero_d_memref_descriptor(ctp)()
x.allocated = nparray.ctypes.data
- x.aligned = nparray.ctypes.data_as(
- ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype))
- )
+ x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
x.offset = ctypes.c_longlong(0)
- x.shape = nparray.ctypes.shape
-
- # Numpy uses byte quantities to express strides, MLIR OTOH uses the
- # torch abstraction which specifies strides in terms of elements.
- strides_ctype_t = ctypes.c_longlong * nparray.ndim
- x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
return x
+ x = make_nd_memref_descriptor(nparray.ndim, ctp)()
+ x.allocated = nparray.ctypes.data
+ x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
+ x.offset = ctypes.c_longlong(0)
+ x.shape = nparray.ctypes.shape
+
+ # Numpy uses byte quantities to express strides, MLIR OTOH uses the
+ # torch abstraction which specifies strides in terms of elements.
+ strides_ctype_t = ctypes.c_longlong * nparray.ndim
+ x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
+ return x
+
def get_unranked_memref_descriptor(nparray):
- """
- Return a generic/unranked memref descriptor for the given numpy array.
- """
- d = UnrankedMemRefDescriptor()
- d.rank = nparray.ndim
- x = get_ranked_memref_descriptor(nparray)
- d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
- return d
+ """Returns a generic/unranked memref descriptor for the given numpy array."""
+ d = UnrankedMemRefDescriptor()
+ d.rank = nparray.ndim
+ x = get_ranked_memref_descriptor(nparray)
+ d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
+ return d
def unranked_memref_to_numpy(unranked_memref, np_dtype):
- """
- Converts unranked memrefs to numpy arrays.
- """
- descriptor = make_nd_memref_descriptor(
- unranked_memref[0].rank, np.ctypeslib.as_ctypes_type(np_dtype)
- )
- val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
- np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
- strided_arr = np.lib.stride_tricks.as_strided(
- np_arr,
- np.ctypeslib.as_array(val[0].shape),
- np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
- )
- return strided_arr
+ """Converts unranked memrefs to numpy arrays."""
+ ctp = as_ctype(np_dtype)
+ descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
+ val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
+ np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
+ strided_arr = np.lib.stride_tricks.as_strided(
+ np_arr,
+ np.ctypeslib.as_array(val[0].shape),
+ np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
+ )
+ if strided_arr.dtype == C128:
+ return strided_arr.view("complex128")
+ if strided_arr.dtype == C64:
+ return strided_arr.view("complex64")
+ return strided_arr
def ranked_memref_to_numpy(ranked_memref):
- """
- Converts ranked memrefs to numpy arrays.
- """
- np_arr = np.ctypeslib.as_array(
- ranked_memref[0].aligned, shape=ranked_memref[0].shape
- )
- strided_arr = np.lib.stride_tricks.as_strided(
- np_arr,
- np.ctypeslib.as_array(ranked_memref[0].shape),
- np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
- )
- return strided_arr
+ """Converts ranked memrefs to numpy arrays."""
+ np_arr = np.ctypeslib.as_array(
+ ranked_memref[0].aligned, shape=ranked_memref[0].shape)
+ strided_arr = np.lib.stride_tricks.as_strided(
+ np_arr,
+ np.ctypeslib.as_array(ranked_memref[0].shape),
+ np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
+ )
+ if strided_arr.dtype == C128:
+ return strided_arr.view("complex128")
+ if strided_arr.dtype == C64:
+ return strided_arr.view("complex64")
+ return strided_arr
def lowerToLLVM(module):
import mlir.conversions
pm = PassManager.parse(
- "convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts")
+ "convert-complex-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts")
pm.run(module)
return module
run(testMemrefAdd)
+# Test addition of two complex memrefs
+# CHECK-LABEL: TEST: testComplexMemrefAdd
+def testComplexMemrefAdd():
+ with Context():
+ module = Module.parse("""
+ module {
+ func.func @main(%arg0: memref<1xcomplex<f64>>,
+ %arg1: memref<1xcomplex<f64>>,
+ %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
+ %0 = arith.constant 0 : index
+ %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
+ %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
+ %3 = complex.add %1, %2 : complex<f64>
+ memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
+ return
+ }
+ } """)
+
+ arg1 = np.array([1.+2.j]).astype(np.complex128)
+ arg2 = np.array([3.+4.j]).astype(np.complex128)
+ arg3 = np.array([0.+0.j]).astype(np.complex128)
+
+ arg1_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg1)))
+ arg2_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg2)))
+ arg3_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_ranked_memref_descriptor(arg3)))
+
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.invoke("main",
+ arg1_memref_ptr,
+ arg2_memref_ptr,
+ arg3_memref_ptr)
+ # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
+ log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+ # test to-numpy utility
+ # CHECK: [4.+6.j]
+ npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
+ log(npout)
+
+
+run(testComplexMemrefAdd)
+
+
+# Test addition of two complex unranked memrefs
+# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
+def testComplexUnrankedMemrefAdd():
+ with Context():
+ module = Module.parse("""
+ module {
+ func.func @main(%arg0: memref<*xcomplex<f32>>,
+ %arg1: memref<*xcomplex<f32>>,
+ %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
+ %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
+ %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
+ %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
+ %0 = arith.constant 0 : index
+ %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
+ %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
+ %3 = complex.add %1, %2 : complex<f32>
+ memref.store %3, %C[%0] : memref<1xcomplex<f32>>
+ return
+ }
+ } """)
+
+ arg1 = np.array([5.+6.j]).astype(np.complex64)
+ arg2 = np.array([7.+8.j]).astype(np.complex64)
+ arg3 = np.array([0.+0.j]).astype(np.complex64)
+
+ arg1_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_unranked_memref_descriptor(arg1)))
+ arg2_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_unranked_memref_descriptor(arg2)))
+ arg3_memref_ptr = ctypes.pointer(
+ ctypes.pointer(get_unranked_memref_descriptor(arg3)))
+
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+ execution_engine.invoke("main",
+ arg1_memref_ptr,
+ arg2_memref_ptr,
+ arg3_memref_ptr)
+ # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
+ log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+ # test to-numpy utility
+ # CHECK: [12.+14.j]
+ npout = unranked_memref_to_numpy(arg3_memref_ptr[0],
+ np.dtype(np.complex64))
+ log(npout)
+
+
+run(testComplexUnrankedMemrefAdd)
+
+
# Test addition of two 2d_memref
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():