VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
VectorOfLengthAndType<[4], [F64]>]>;
+// wmma
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>;
+def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
+ VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
def AMDGPU_MFMAOp :
AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
let hasVerifier = 1;
}
+def AMDGPU_WMMAOp :
+ AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
+ AllTypesMatch<["sourceA", "sourceB"]>,
+ Pure]>,
+ Arguments<(ins
+ WMMAInTypes:$sourceA,
+ WMMAInTypes:$sourceB,
+ WMMAOutTypes:$destC,
+ DefaultValuedAttr<ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>, "0">:$subwordOffset,
+ UnitAttr:$unsignedA,
+ UnitAttr:$unsignedB,
+ UnitAttr:$clamp)>,
+ Results<(outs WMMAOutTypes: $destD)> {
+ let summary = "MLIR wrapper for RDNA3 wmma instructions";
+ let description = [{
+ The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
+ for various `wmma` instructions in the RDNA3 architecture, which perform
+ a 16x16 matrix multiplication for different data types.
+
+ When emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16 (or 16xbf16) vector
+ containing only 8 valid values:
+ - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
+ - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
+
+ `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
+
+ The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+ in case of overflow.
+ }];
+ let assemblyFormat = [{
+ $sourceA `*` $sourceB `+` $destC
+ attr-dict
+ `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+ }];
+ let hasVerifier = 1;
+}
+
#endif // AMDGPU
let assemblyFormat = "attr-dict";
}
+
//===---------------------------------------------------------------------===//
// Xdlops intrinsics
def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8">;
//===---------------------------------------------------------------------===//
+// WMMA intrinsics
+class ROCDL_Wmma_IntrOp<string mnemonic, list<Trait> traits = []> :
+ LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
+ "amdgcn_" # !subst(".","_", mnemonic),
+ [0], [], traits, 1>,
+ Arguments<(ins Variadic<LLVM_Type>:$args)> {
+ let assemblyFormat =
+ "$args attr-dict `:` functional-type($args, $res)";
+}
+
+// Available on RDNA3
+def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16">;
+def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16">;
+def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16">;
+def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16">;
+def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8">;
+def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4">;
+
+
+//===---------------------------------------------------------------------===//
// Vector buffer load/store intrinsics
def ROCDL_MubufLoadOp :
return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
}
+static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
+ bool value) {
+ Type llvmI1 = rewriter.getI1Type();
+ return rewriter.createOrFold<LLVM::ConstantOp>(loc, llvmI1, value);
+}
+
namespace {
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
return input;
}
+/// Push an input operand. If it is a float type, nothing to do. If it is
+/// an integer type, then we need to also push its signdness (1 for signed, 0
+/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
+/// vector.
+static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
+ Location loc, TypeConverter *typeConverter,
+ bool isUnsigned, Value llvmInput,
+ SmallVector<Value, 4> &operands) {
+ Type inputType = llvmInput.getType();
+ auto vectorType = inputType.dyn_cast<VectorType>();
+ Type elemType = vectorType.getElementType();
+
+ if (!elemType.isInteger(8)) {
+ operands.push_back(llvmInput);
+ return;
+ }
+
+ int64_t numBytes = vectorType.getNumElements();
+ Type i32 = rewriter.getI32Type();
+ VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
+ auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
+
+ Value result = rewriter.createOrFold<LLVM::BitcastOp>(
+ loc, llvmVectorType32bits, llvmInput);
+
+ // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
+ bool localIsUnsigned = isUnsigned;
+ if (elemType.isUnsignedInteger(8)) {
+ localIsUnsigned = true;
+ } else if (elemType.isSignedInteger(8)) {
+ localIsUnsigned = false;
+ }
+ Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
+ operands.push_back(sign);
+ operands.push_back(result);
+}
+
+/// Push the output operand. For many cases this is only pushing the output in
+/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
+/// since the same numbers of VGPRs is used, we need to decide if to store the
+/// result in the upper 16 bits of the VGPRs or in the lower part. To store the
+/// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
+/// be stored it in the upper part
+static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
+ Location loc, TypeConverter *typeConverter,
+ Value output, int32_t subwordOffset,
+ bool clamp, SmallVector<Value, 4> &operands) {
+ Type inputType = output.getType();
+ auto vectorType = inputType.dyn_cast<VectorType>();
+ Type elemType = vectorType.getElementType();
+ operands.push_back(output);
+ if (elemType.isF16() || elemType.isBF16()) {
+ operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
+ } else if (elemType.isInteger(32)) {
+ operands.push_back(createI1Constant(rewriter, loc, clamp));
+ }
+}
+
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
return std::nullopt;
}
+/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// if one exists. This includes checking to ensure the intrinsic is supported
+/// on the architecture you are compiling for.
+static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
+ Chipset chipset) {
+
+ auto sourceVectorType = wmma.getSourceA().getType().dyn_cast<VectorType>();
+ auto destVectorType = wmma.getDestC().getType().dyn_cast<VectorType>();
+ auto elemSourceType = sourceVectorType.getElementType();
+ auto elemDestType = destVectorType.getElementType();
+
+ if (elemSourceType.isF16() && elemDestType.isF32()) {
+ return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+ } else if (elemSourceType.isBF16() && elemDestType.isF32()) {
+ return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+ } else if (elemSourceType.isF16() && elemDestType.isF16()) {
+ return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+ } else if (elemSourceType.isBF16() && elemDestType.isBF16()) {
+ return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+ } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
+ return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+ }
+ return std::nullopt;
+}
+
namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
MFMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
}
};
+struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
+ WMMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Type outType = typeConverter->convertType(op.getDestD().getType());
+
+ if (chipset.majorVersion != 11)
+ return op->emitOpError("WMMA only supported on gfx11");
+
+ std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
+
+ if (!maybeIntrinsic.has_value())
+ return op.emitOpError("no intrinsic matching WMMA on the given chipset");
+
+ OperationState loweredOp(loc, *maybeIntrinsic);
+ loweredOp.addTypes(outType);
+
+ SmallVector<Value, 4> operands;
+ wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
+ adaptor.getSourceA(), operands);
+ wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
+ adaptor.getSourceB(), operands);
+ wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
+ op.getSubwordOffset(), op.getClamp(), operands);
+
+ loweredOp.addOperands(operands);
+ Operation *lowered = rewriter.create(loweredOp);
+ rewriter.replaceOp(op, lowered->getResults());
+
+ return success();
+ }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
ConvertAMDGPUToROCDLPass() = default;
RawBufferOpLowering<RawBufferAtomicUminOp, ROCDL::RawBufferAtomicUMinOp>,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawBufferAtomicCmpSwap>,
- MFMAOpLowering>(converter, chipset);
+ MFMAOpLowering, WMMAOpLowering>(converter, chipset);
}
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
}
//===----------------------------------------------------------------------===//
+// WMMAOp
+//===----------------------------------------------------------------------===//
+LogicalResult WMMAOp::verify() {
+ Type sourceAType = getSourceA().getType();
+ Type destType = getDestC().getType();
+
+ VectorType sourceVectorAType = sourceAType.dyn_cast<VectorType>();
+ VectorType destVectorType = destType.dyn_cast<VectorType>();
+
+ Type sourceAElemType = sourceVectorAType.getElementType();
+ Type destElemType = destVectorType.getElementType();
+
+ bool isDestFloat =
+ (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
+ bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
+
+ if (isDestFloat && !isSrcFloat) {
+ return emitOpError("Expected float sources with float destination");
+ }
+
+ if (!isDestFloat && isSrcFloat) {
+ return emitOpError("Expected int sources with int destination");
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//
LogicalResult MFMAOp::verify() {
--- /dev/null
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
+ %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
+ %arg6 : vector<16xi8>, %arg7 : vector<4xi32>, %arg8 : vector<8xi32>,
+ %arg9 : vector<16xui8>){
+ // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
+ amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+ // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
+ amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+ // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<8xf32>) -> vector<8xf32>
+ amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+ // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<4xf32>) -> vector<4xf32>
+ amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+ // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+ amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+ // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+ amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+ // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<16xbf16>, i1) -> vector<16xbf16>
+ amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+ // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<8xbf16>, i1) -> vector<8xbf16>
+ amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+ // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+ amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<4xi32>
+ // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+ amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<8xi32>
+ func.return
+}
abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, f32, vector<32xf32>
func.return %d : vector<32xf32>
}
+
+// -----
+
+func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+ // expected-error@+1 {{'amdgpu.wmma' op Expected int sources with int destination}}
+ %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
%0 = amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, f32, vector<32xf32>
func.return %0 : vector<32xf32>
}
+
+// CHECK-LABEL: func @wmma
+func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+ // CHECK: amdgpu.wmma
+ %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+ func.return %0 : vector<8xf16>
+}
llvm.return %r0 : vector<32 x f32>
}
+llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : vector<16 x i16>, %arg3 : vector<8 x i32>,
+ %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>) -> vector<8xf32> {
+ %zero = llvm.mlir.constant(false) : i1
+
+ // ---- Wave32 -----
+
+ // f16 -> f32
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x float> %{{.*}})
+ %r0 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg0 : (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
+
+ // bf16 -> f32
+ // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x float> %{{.*}})
+ %r1 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg0 : (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
+
+ // f16 -> f16 (OPSEL = {0,1})
+ // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x half> %{{.*}}, i1 {{.*}})
+ %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1, %zero : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+
+ // bf16 -> bf16 (OPSEL = {0,1})
+ // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}, i1 {{.*}})
+ %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2, %zero : (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+
+ // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
+ %r5 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg3, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+
+ // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+ // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
+ %r6 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+
+ // ---- Wave64 -----
+
+ // f16 -> f32
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}})
+ %r7 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg6 : (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
+
+ // bf16 -> f32
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <4 x float> %{{.*}})
+ %r8 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg6 : (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
+
+ // f16 -> f16 (OPSEL = {0,1})
+ // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x half> %{{.*}}, i1 {{.*}})
+ %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7, %zero : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+
+ // bf16 -> bf16 (OPSEL = {0,1})
+ // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x i16> %{{.*}}, i1 {{.*}})
+ %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8, %zero : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+
+ // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+ // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
+ %r12 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg5, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+
+ // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+ // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
+ %r13 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg5, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<4xi32>, i1) -> vector<4xi32>
+
+ llvm.return %r0 : vector<8xf32>
+}
+
+
llvm.func @rocdl.mubuf(%rsrc : vector<4xi32>, %vindex : i32,
%offset : i32, %vdata1 : vector<1xf32>,
%vdata2 : vector<2xf32>, %vdata4 : vector<4xf32>) {