return compositeConstructOp.emitError(
"has incorrect number of operands: expected ")
<< "1, but provided " << constituents.size();
- } else {
- if (constituents.size() != cType.getNumElements())
- return compositeConstructOp.emitError(
- "has incorrect number of operands: expected ")
- << cType.getNumElements() << ", but provided "
- << constituents.size();
+ } else if (constituents.size() != cType.getNumElements()) {
+ return compositeConstructOp.emitError(
+ "has incorrect number of operands: expected ")
+ << cType.getNumElements() << ", but provided "
+ << constituents.size();
}
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
}
//===----------------------------------------------------------------------===//
-// spv.CooperativeMatrixLengthNV
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCooperativeMatrixLengthNVOp(OpAsmParser &parser,
- OperationState &state) {
- OpAsmParser::OperandType operandInfo;
- Type dstType = parser.getBuilder().getIntegerType(32);
- Type type;
- if (parser.parseColonType(type)) {
- return failure();
- }
- state.addAttribute(kTypeAttrName, TypeAttr::get(type));
- state.addTypes(dstType);
- return success();
-}
-
-static void print(spirv::CooperativeMatrixLengthNVOp coopMatrix,
- OpAsmPrinter &printer) {
- printer << coopMatrix.getOperationName() << " : " << coopMatrix.type();
-}
-
-//===----------------------------------------------------------------------===//
// spv.CooperativeMatrixMulAddNV
//===----------------------------------------------------------------------===//
-static ParseResult parseCooperativeMatrixMulAddNVOp(OpAsmParser &parser,
- OperationState &state) {
- SmallVector<OpAsmParser::OperandType, 3> ops;
- SmallVector<Type, 3> types(3);
- if (parser.parseOperandList(ops, 3) || parser.parseColon() ||
- parser.parseType(types[0]) || parser.parseComma() ||
- parser.parseType(types[1]) || parser.parseArrow() ||
- parser.parseType(types[2]) ||
- parser.resolveOperands(ops, types, parser.getNameLoc(), state.operands)) {
- return failure();
- }
- state.addTypes(types[2]);
- return success();
-}
-
-static void print(spirv::CooperativeMatrixMulAddNVOp coopMatrix,
- OpAsmPrinter &printer) {
- printer << coopMatrix.getOperationName() << ' ' << coopMatrix.getOperand(0)
- << ", " << coopMatrix.getOperand(1) << ", "
- << coopMatrix.getOperand(2) << ", "
- << " : " << coopMatrix.getOperand(0).getType() << ", "
- << coopMatrix.getOperand(1).getType() << " -> "
- << coopMatrix.getOperand(2).getType();
-}
-
static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
if (op.c().getType() != op.result().getType())
// CHECK-LABEL: @cooperative_matrix_muladd
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}, : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+ // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_muladd
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}, : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+ // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return
}