Type llvmPointerPointerType =
this->getTypeConverter()->getPointerType(llvmPointerType);
Type llvmInt8Type = IntegerType::get(context, 8);
+ Type llvmInt16Type = IntegerType::get(context, 16);
Type llvmInt32Type = IntegerType::get(context, 32);
Type llvmInt64Type = IntegerType::get(context, 64);
Type llvmInt8PointerType =
{llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
llvmIntPtrType /* intptr_t sizeBytes */,
llvmPointerType /* void *stream */}};
- FunctionCallBuilder memsetCallBuilder = {
+ FunctionCallBuilder memset16CallBuilder = {
+ "mgpuMemset16",
+ llvmVoidType,
+ {llvmPointerType /* void *dst */,
+ llvmInt16Type /* unsigned short value */,
+ llvmIntPtrType /* intptr_t sizeBytes */,
+ llvmPointerType /* void *stream */}};
+ FunctionCallBuilder memset32CallBuilder = {
"mgpuMemset32",
llvmVoidType,
{llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
auto loc = memsetOp.getLoc();
Type valueType = adaptor.getValue().getType();
- if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
- return rewriter.notifyMatchFailure(memsetOp,
- "value must be a 32 bit scalar");
+ unsigned bitWidth = valueType.getIntOrFloatBitWidth();
+ // Ints and floats of 16 or 32 bit width are allowed.
+ if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
+ return rewriter.notifyMatchFailure(
+ memsetOp, "value must be a 16 or 32 bit int or float");
}
+ unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
+ Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
+
MemRefDescriptor dstDesc(adaptor.getDst());
Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
auto value =
- rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.getValue());
+ rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
dstDesc.alignedPtr(rewriter, loc),
*getTypeConverter());
auto stream = adaptor.getAsyncDependencies().front();
- memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});
+ FunctionCallBuilder builder =
+ valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
+ builder.create(loc, rewriter, {dst, value, numElements, stream});
rewriter.replaceOp(memsetOp, {stream});
return success();
module attributes {gpu.container_module} {
- // CHECK: func @foo
- func.func @foo(%dst : memref<7xf32, 1>, %value : f32) {
+ // CHECK: func @memset_f32
+ func.func @memset_f32(%dst : memref<7xf32, 1>, %value : f32) {
// CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
%t0 = gpu.wait async
// CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
}
}
+// -----
+
+module attributes {gpu.container_module} {
+
+ // CHECK: func @memset_f16
+ func.func @memset_f16(%dst : memref<7xf16, 1>, %value : f16) {
+ // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
+ %t0 = gpu.wait async
+ // CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
+ // CHECK: %[[value:.*]] = llvm.bitcast
+ // CHECK: %[[addr_cast:.*]] = llvm.addrspacecast
+ // CHECK: %[[dst:.*]] = llvm.bitcast %[[addr_cast]]
+ // CHECK: llvm.call @mgpuMemset16(%[[dst]], %[[value]], %[[size_bytes]], %[[t0]])
+ %t1 = gpu.memset async [%t0] %dst, %value : memref<7xf16, 1>, f16
+ // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
+ // CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
+ gpu.wait [%t1]
+ return
+ }
+}