%c5 = constant 5 : index
%c6 = constant 6 : index
- %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
- call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
- %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
- call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
gpu.terminator
}
- %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+ call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
// CHECK: [0, 2]
return
}
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
%c5 = constant 5 : index
%c6 = constant 6 : index
- %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
- call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
- %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
- call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
gpu.terminator
}
- %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+ call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
// CHECK: [16, 11]
return
}
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
%c5 = constant 5 : index
%c6 = constant 6 : index
- %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
- call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
- %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
- call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
gpu.terminator
}
- %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+ call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
// CHECK: [0, 2]
return
}
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
%sx = dim %dst, 2 : memref<?x?x?xf32>
%sy = dim %dst, 1 : memref<?x?x?xf32>
%sz = dim %dst, 0 : memref<?x?x?xf32>
- call @mcuMemHostRegisterMemRef3dFloat(%dst) : (memref<?x?x?xf32>) -> ()
+ %cast_dst = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz) {
%t0 = muli %tz, %block_y : index
store %sum, %dst[%tz, %ty, %tx] : memref<?x?x?xf32>
gpu.terminator
}
- %U = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
- call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> ()
return
}
-func @mcuMemHostRegisterMemRef3dFloat(%ptr : memref<?x?x?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)
%c5 = constant 5 : index
%c6 = constant 6 : index
- %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
- call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
- %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
- call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
gpu.terminator
}
- %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+ call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
// CHECK: [31, 15]
return
}
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
%dst = memref_cast %arg : memref<35xf32> to memref<?xf32>
%one = constant 1 : index
%sx = dim %dst, 0 : memref<?xf32>
- call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref<?xf32>) -> ()
+ %cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%val = index_cast %tx : index to i32
store %res, %dst[%tx] : memref<?xf32>
gpu.terminator
}
- %U = memref_cast %dst : memref<?xf32> to memref<*xf32>
- call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> ()
return
}
-func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(memref<*xf32>)
%c5 = constant 5 : index
%c6 = constant 6 : index
- %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
- call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
- %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
- call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
gpu.terminator
}
- %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+ call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
// CHECK: [31, 1]
return
}
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
%arg0 = alloc() : memref<5xf32>
%21 = constant 5 : i32
%22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
- call @mcuMemHostRegisterMemRef1dFloat(%22) : (memref<?xf32>) -> ()
%23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%23) : (memref<*xf32>) -> ()
call @print_memref_f32(%23) : (memref<*xf32>) -> ()
%24 = constant 1.0 : f32
call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
return
}
-func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)
%c5 = constant 5 : index
%c6 = constant 6 : index
- %cast_data = memref_cast %data : memref<2x6xf32> to memref<?x?xf32>
- call @mcuMemHostRegisterMemRef2dFloat(%cast_data) : (memref<?x?xf32>) -> ()
- %cast_sum = memref_cast %sum : memref<2xf32> to memref<?xf32>
- call @mcuMemHostRegisterMemRef1dFloat(%cast_sum) : (memref<?xf32>) -> ()
- %cast_mul = memref_cast %mul : memref<2xf32> to memref<?xf32>
- call @mcuMemHostRegisterMemRef1dFloat(%cast_mul) : (memref<?xf32>) -> ()
+ %cast_data = memref_cast %data : memref<2x6xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%cast_data) : (memref<*xf32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%cast_sum) : (memref<*xf32>) -> ()
+ %cast_mul = memref_cast %mul : memref<2xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%cast_mul) : (memref<*xf32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xf32>
store %cst1, %data[%c0, %c1] : memref<2x6xf32>
gpu.terminator
}
- %ptr_sum = memref_cast %sum : memref<2xf32> to memref<*xf32>
- call @print_memref_f32(%ptr_sum) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%cast_sum) : (memref<*xf32>) -> ()
// CHECK: [31, 39]
- %ptr_mul = memref_cast %mul : memref<2xf32> to memref<*xf32>
- call @print_memref_f32(%ptr_mul) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%cast_mul) : (memref<*xf32>) -> ()
// CHECK: [0, 27720]
return
}
-func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
-func @mcuMemHostRegisterMemRef2dFloat(%ptr : memref<?x?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(memref<*xf32>)
%dst = memref_cast %arg : memref<13xf32> to memref<?xf32>
%one = constant 1 : index
%sx = dim %dst, 0 : memref<?xf32>
- call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref<?xf32>) -> ()
+ %cast_dest = memref_cast %dst : memref<?xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%cast_dest) : (memref<*xf32>) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%t0 = index_cast %tx : index to i32
store %value, %dst[%tx] : memref<?xf32>
gpu.terminator
}
- %U = memref_cast %dst : memref<?xf32> to memref<*xf32>
- call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%cast_dest) : (memref<*xf32>) -> ()
return
}
-func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)
%dst = memref_cast %arg : memref<13xi32> to memref<?xi32>
%one = constant 1 : index
%sx = dim %dst, 0 : memref<?xi32>
- call @mcuMemHostRegisterMemRef1dInt32(%dst) : (memref<?xi32>) -> ()
+ %cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%t0 = index_cast %tx : index to i32
store %t0, %dst[%tx] : memref<?xi32>
gpu.terminator
}
- %U = memref_cast %dst : memref<?xi32> to memref<*xi32>
- call @print_memref_i32(%U) : (memref<*xi32>) -> ()
+ call @print_memref_i32(%cast_dst) : (memref<*xi32>) -> ()
return
}
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @print_memref_i32(%ptr : memref<*xi32>)
+func @mcuMemHostRegisterInt32(%memref : memref<*xi32>)
+func @print_memref_i32(%memref : memref<*xi32>)
#include <cassert>
#include <numeric>
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
"MemHostRegister");
}
-// A struct that corresponds to how MLIR represents memrefs.
-template <typename T, int N> struct MemRefType {
- T *basePtr;
- T *data;
- int64_t offset;
- int64_t sizes[N];
- int64_t strides[N];
-};
-
// Allows to register a MemRef with the CUDA runtime. Initializes array with
// value. Helpful until we have transfer functions implemented.
template <typename T>
mcuMemHostRegister(pointer, count * sizeof(T));
}
-extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated,
- float *aligned, int64_t offset,
- int64_t size, int64_t stride) {
- mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 1.23f);
-}
-
-extern "C" void mcuMemHostRegisterMemRef2dFloat(float *allocated,
- float *aligned, int64_t offset,
- int64_t size0, int64_t size1,
- int64_t stride0,
- int64_t stride1) {
- mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
- 1.23f);
-}
-
-extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
- float *aligned, int64_t offset,
- int64_t size0, int64_t size1,
- int64_t size2, int64_t stride0,
- int64_t stride1,
- int64_t stride2) {
- mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
- {stride0, stride1, stride2}, 1.23f);
-}
-
-extern "C" void mcuMemHostRegisterMemRef1dInt32(int32_t *allocated,
- int32_t *aligned,
- int64_t offset, int64_t size,
- int64_t stride) {
- mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 123);
-}
-
-extern "C" void mcuMemHostRegisterMemRef2dInt32(int32_t *allocated,
- int32_t *aligned,
- int64_t offset, int64_t size0,
- int64_t size1, int64_t stride0,
- int64_t stride1) {
- mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
- 123);
+// C entry point that hands a float memref to mcuMemHostRegisterMemRef.
+//   rank: number of dimensions; used as the element count of both the sizes
+//         and the strides views built below.
+//   ptr:  opaque pointer to the memref descriptor; it is reinterpreted as a
+//         StridedMemRefType<float, 1> only to reach its fields.
+// Forwards `data + offset`, the sizes view, the strides view and the constant
+// 1.23f (presumably the initialization value -- confirm against the
+// mcuMemHostRegisterMemRef template).
+// NOTE(review): the (rank, ptr) parameter pair presumably mirrors how a
+// memref<*xf32> argument is passed after lowering -- confirm.
+extern "C" void mcuMemHostRegisterFloat(int64_t rank, void *ptr) {
+ auto *desc = static_cast<StridedMemRefType<float, 1> *>(ptr);
+ auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
+ // Strides are read starting `rank` entries past the beginning of sizes.
+ // NOTE(review): assumes the rank-N descriptor stores its strides directly
+ // after its N sizes -- confirm against StridedMemRefType in CRunnerUtils.h.
+ auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
+ mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 1.23f);
}
-extern "C" void
-mcuMemHostRegisterMemRef3dInt32(int32_t *allocated, int32_t *aligned,
- int64_t offset, int64_t size0, int64_t size1,
- int64_t size2, int64_t stride0, int64_t stride1,
- int64_t stride2) {
- mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
- {stride0, stride1, stride2}, 123);
+// C entry point that hands an int32 memref to mcuMemHostRegisterMemRef.
+//   rank: number of dimensions; used as the element count of both the sizes
+//         and the strides views built below.
+//   ptr:  opaque pointer to the memref descriptor; it is reinterpreted as a
+//         StridedMemRefType<int32_t, 1> only to reach its fields.
+// Forwards `data + offset`, the sizes view, the strides view and the constant
+// 123 (presumably the initialization value -- confirm against the
+// mcuMemHostRegisterMemRef template).
+// NOTE(review): the (rank, ptr) parameter pair presumably mirrors how a
+// memref<*xi32> argument is passed after lowering -- confirm.
+extern "C" void mcuMemHostRegisterInt32(int64_t rank, void *ptr) {
+ auto *desc = static_cast<StridedMemRefType<int32_t, 1> *>(ptr);
+ auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
+ // Strides are read starting `rank` entries past the beginning of sizes.
+ // NOTE(review): assumes the rank-N descriptor stores its strides directly
+ // after its N sizes -- confirm against StridedMemRefType in CRunnerUtils.h.
+ auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
+ mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 123);
}