[mlir][spirv] Clean up coop matrix assembly declaration.
authorThomas Raoux <thomasraoux@google.com>
Fri, 29 May 2020 23:34:56 +0000 (16:34 -0700)
committerThomas Raoux <thomasraoux@google.com>
Fri, 29 May 2020 23:37:35 +0000 (16:37 -0700)
Address code review feedback and use declarative assembly format.

Differential Revision: https://reviews.llvm.org/D80687

mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
mlir/test/Dialect/SPIRV/cooperative-matrix.mlir

index 4645765..9c3462a 100644 (file)
@@ -39,6 +39,8 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
     ```
   }];
 
+  let assemblyFormat = "attr-dict `:` $type";
+
   let availability = [
     MinVersion<SPV_V_1_0>,
     MaxVersion<SPV_V_1_5>,
@@ -139,7 +141,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
 // -----
 
 def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
-  [NoSideEffect]> {
+  [NoSideEffect, AllTypesMatch<["c", "result"]>]> {
   let summary = "See extension SPV_NV_cooperative_matrix";
 
   let description = [{
@@ -188,6 +190,10 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
     ```
   }];
 
+  let assemblyFormat = [{
+    operands attr-dict`:` type($a) `,` type($b) `->` type($c)
+  }];
+
   let availability = [
     MinVersion<SPV_V_1_0>,
     MaxVersion<SPV_V_1_5>,
index 4f48ef9..ac8fee8 100644 (file)
@@ -1134,12 +1134,11 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
       return compositeConstructOp.emitError(
                  "has incorrect number of operands: expected ")
              << "1, but provided " << constituents.size();
-  } else {
-    if (constituents.size() != cType.getNumElements())
-      return compositeConstructOp.emitError(
-                 "has incorrect number of operands: expected ")
-             << cType.getNumElements() << ", but provided "
-             << constituents.size();
+  } else if (constituents.size() != cType.getNumElements()) {
+    return compositeConstructOp.emitError(
+               "has incorrect number of operands: expected ")
+           << cType.getNumElements() << ", but provided "
+           << constituents.size();
   }
 
   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
@@ -2736,56 +2735,9 @@ static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
 }
 
 //===----------------------------------------------------------------------===//
-// spv.CooperativeMatrixLengthNV
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCooperativeMatrixLengthNVOp(OpAsmParser &parser,
-                                                    OperationState &state) {
-  OpAsmParser::OperandType operandInfo;
-  Type dstType = parser.getBuilder().getIntegerType(32);
-  Type type;
-  if (parser.parseColonType(type)) {
-    return failure();
-  }
-  state.addAttribute(kTypeAttrName, TypeAttr::get(type));
-  state.addTypes(dstType);
-  return success();
-}
-
-static void print(spirv::CooperativeMatrixLengthNVOp coopMatrix,
-                  OpAsmPrinter &printer) {
-  printer << coopMatrix.getOperationName() << " : " << coopMatrix.type();
-}
-
-//===----------------------------------------------------------------------===//
 // spv.CooperativeMatrixMulAddNV
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseCooperativeMatrixMulAddNVOp(OpAsmParser &parser,
-                                                    OperationState &state) {
-  SmallVector<OpAsmParser::OperandType, 3> ops;
-  SmallVector<Type, 3> types(3);
-  if (parser.parseOperandList(ops, 3) || parser.parseColon() ||
-      parser.parseType(types[0]) || parser.parseComma() ||
-      parser.parseType(types[1]) || parser.parseArrow() ||
-      parser.parseType(types[2]) ||
-      parser.resolveOperands(ops, types, parser.getNameLoc(), state.operands)) {
-    return failure();
-  }
-  state.addTypes(types[2]);
-  return success();
-}
-
-static void print(spirv::CooperativeMatrixMulAddNVOp coopMatrix,
-                  OpAsmPrinter &printer) {
-  printer << coopMatrix.getOperationName() << ' ' << coopMatrix.getOperand(0)
-          << ", " << coopMatrix.getOperand(1) << ", "
-          << coopMatrix.getOperand(2) << ", "
-          << " : " << coopMatrix.getOperand(0).getType() << ", "
-          << coopMatrix.getOperand(1).getType() << " -> "
-          << coopMatrix.getOperand(2).getType();
-}
-
 static LogicalResult
 verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
   if (op.c().getType() != op.result().getType())
index 12f710e..0d58fea 100644 (file)
@@ -38,7 +38,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_N
 
   // CHECK-LABEL: @cooperative_matrix_muladd
   spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}},  : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+    // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}  : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
     %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
     spv.Return
   }
index e303526..51c7090 100644 (file)
@@ -38,7 +38,7 @@ spv.func @cooperative_matrix_length() -> i32 "None" {
 
 // CHECK-LABEL: @cooperative_matrix_muladd
 spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}},  : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+  // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}  : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
   %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
   spv.Return
 }