[mlir][NFC] Update several SPIRV operations to use declarative parsers.
authorRiver Riddle <riddleriver@gmail.com>
Thu, 30 Jan 2020 19:37:35 +0000 (11:37 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 30 Jan 2020 19:43:41 +0000 (11:43 -0800)
Differential Revision: https://reviews.llvm.org/D73504

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/control-flow-ops.mlir

index 79e5158..2dbee40 100644 (file)
@@ -2954,7 +2954,7 @@ def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
 // for the definition of the following types and type categories.
 
 def SPV_Void : TypeAlias<NoneType, "void type">;
-def SPV_Bool : IntOfWidths<[1]>;
+def SPV_Bool : I<1>;
 def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>;
 def SPV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPV_Float16or32 : FloatOfWidths<[16, 32]>;
index a589ff5..2aab10e 100644 (file)
@@ -245,6 +245,11 @@ def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [
   );
 
   let autogenSerialization = 0;
+
+  let assemblyFormat = [{
+    $callee `(` $arguments `)` attr-dict `:`
+    functional-type($arguments, results)
+  }];
 }
 
 // -----
@@ -412,6 +417,8 @@ def SPV_ReturnValueOp : SPV_Op<"ReturnValue", [InFunctionScope, Terminator]> {
   );
 
   let results = (outs);
+
+  let assemblyFormat = "$value attr-dict `:` type($value)";
 }
 
 def SPV_SelectionOp : SPV_Op<"selection", [InFunctionScope]> {
index a8008c2..9564b03 100644 (file)
@@ -65,6 +65,8 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
   );
 
   let verifier = [{ return success(); }];
+
+  let assemblyFormat = "$predicate attr-dict `:` type($result)";
 }
 
 // -----
index 3334995..b4309a8 100644 (file)
@@ -408,6 +408,8 @@ def SPV_UndefOp : SPV_Op<"undef", []> {
 
   let hasOpcode = 0;
   let autogenSerialization = 0;
+
+  let assemblyFormat = "attr-dict `:` type($result)";
 }
 
 // -----
index ed7fe12..995e5d4 100644 (file)
@@ -57,6 +57,8 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
 
   let builders = [OpBuilder<[{Builder *builder, OperationState &state,
                               spirv::GlobalVariableOp var}]>];
+
+  let assemblyFormat = "$variable attr-dict `:` type($pointer)";
 }
 
 def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
@@ -409,6 +411,8 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> {
   let hasOpcode = 0;
 
   let autogenSerialization = 0;
+
+  let assemblyFormat = "$spec_const attr-dict `:` type($reference)";
 }
 
 def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope, Symbol]> {
index 7062dab..fb3e815 100644 (file)
@@ -941,36 +941,6 @@ void spirv::AddressOfOp::build(Builder *builder, OperationState &state,
   build(builder, state, var.type(), builder->getSymbolRefAttr(var));
 }
 
-static ParseResult parseAddressOfOp(OpAsmParser &parser,
-                                    OperationState &state) {
-  FlatSymbolRefAttr varRefAttr;
-  Type type;
-  if (parser.parseAttribute(varRefAttr, Type(), kVariableAttrName,
-                            state.attributes) ||
-      parser.parseColonType(type)) {
-    return failure();
-  }
-  auto ptrType = type.dyn_cast<spirv::PointerType>();
-  if (!ptrType) {
-    return parser.emitError(parser.getCurrentLocation(),
-                            "expected spv.ptr type");
-  }
-  state.addTypes(ptrType);
-  return success();
-}
-
-static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter &printer) {
-  SmallVector<StringRef, 4> elidedAttrs;
-  printer << spirv::AddressOfOp::getOperationName();
-
-  // Print symbol name.
-  printer << ' ';
-  printer.printSymbolName(addressOfOp.variable());
-
-  // Print the type.
-  printer << " : " << addressOfOp.pointer().getType();
-}
-
 static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
   auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
       SymbolTable::lookupNearestSymbolFrom(addressOfOp.getParentOp(),
@@ -1736,45 +1706,6 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
 // spv.FunctionCall
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseFunctionCallOp(OpAsmParser &parser,
-                                       OperationState &state) {
-  FlatSymbolRefAttr calleeAttr;
-  FunctionType type;
-  SmallVector<OpAsmParser::OperandType, 4> operands;
-  auto loc = parser.getNameLoc();
-  if (parser.parseAttribute(calleeAttr, kCallee, state.attributes) ||
-      parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
-      parser.parseColonType(type)) {
-    return failure();
-  }
-
-  auto funcType = type.dyn_cast<FunctionType>();
-  if (!funcType) {
-    return parser.emitError(loc, "expected function type, but provided ")
-           << type;
-  }
-
-  if (funcType.getNumResults() > 1) {
-    return parser.emitError(loc, "expected callee function to have 0 or 1 "
-                                 "result, but provided ")
-           << funcType.getNumResults();
-  }
-
-  return failure(parser.addTypesToList(funcType.getResults(), state.types) ||
-                 parser.resolveOperands(operands, funcType.getInputs(), loc,
-                                        state.operands));
-}
-
-static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
-  SmallVector<Type, 4> argTypes(functionCallOp.getOperandTypes());
-  Type functionType = FunctionType::get(
-      argTypes, functionCallOp.getResultTypes(), functionCallOp.getContext());
-
-  printer << spirv::FunctionCallOp::getOperationName() << ' '
-          << functionCallOp.getAttr(kCallee) << '('
-          << functionCallOp.arguments() << ") : " << functionType;
-}
-
 static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
   auto fnName = functionCallOp.callee();
 
@@ -2533,24 +2464,6 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
 // spv._reference_of
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseReferenceOfOp(OpAsmParser &parser,
-                                      OperationState &state) {
-  FlatSymbolRefAttr constRefAttr;
-  Type type;
-  if (parser.parseAttribute(constRefAttr, Type(), kSpecConstAttrName,
-                            state.attributes) ||
-      parser.parseColonType(type)) {
-    return failure();
-  }
-  return parser.addTypeToList(type, state.types);
-}
-
-static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter &printer) {
-  printer << spirv::ReferenceOfOp::getOperationName() << ' ';
-  printer.printSymbolName(referenceOfOp.spec_const());
-  printer << " : " << referenceOfOp.reference().getType();
-}
-
 static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
   auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(
       SymbolTable::lookupNearestSymbolFrom(referenceOfOp.getParentOp(),
@@ -2584,20 +2497,6 @@ static LogicalResult verify(spirv::ReturnOp returnOp) {
 // spv.ReturnValue
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseReturnValueOp(OpAsmParser &parser,
-                                      OperationState &state) {
-  OpAsmParser::OperandType retValInfo;
-  Type retValType;
-  return failure(parser.parseOperand(retValInfo) ||
-                 parser.parseColonType(retValType) ||
-                 parser.resolveOperand(retValInfo, retValType, state.operands));
-}
-
-static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) {
-  printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value()
-          << " : " << retValOp.value().getType();
-}
-
 static LogicalResult verify(spirv::ReturnValueOp retValOp) {
   auto funcOp = retValOp.getParentOfType<FuncOp>();
   auto numFnResults = funcOp.getType().getNumResults();
@@ -3033,44 +2932,6 @@ static LogicalResult verify(spirv::StoreOp storeOp) {
 }
 
 //===----------------------------------------------------------------------===//
-// spv.SubgroupBallotKHROp
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser,
-                                            OperationState &state) {
-  OpAsmParser::OperandType operandInfo;
-  Type resultType;
-  IntegerType i1Type = parser.getBuilder().getI1Type();
-  if (parser.parseOperand(operandInfo) || parser.parseColonType(resultType) ||
-      parser.resolveOperand(operandInfo, i1Type, state.operands))
-    return failure();
-
-  return parser.addTypeToList(resultType, state.types);
-}
-
-static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) {
-  printer << spirv::SubgroupBallotKHROp::getOperationName() << ' '
-          << ballotOp.predicate() << " : " << ballotOp.getType();
-}
-
-//===----------------------------------------------------------------------===//
-// spv.Undef
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseUndefOp(OpAsmParser &parser, OperationState &state) {
-  Type type;
-  if (parser.parseColonType(type)) {
-    return failure();
-  }
-  state.addTypes(type);
-  return success();
-}
-
-static void print(spirv::UndefOp undefOp, OpAsmPrinter &printer) {
-  printer << spirv::UndefOp::getOperationName() << " : " << undefOp.getType();
-}
-
-//===----------------------------------------------------------------------===//
 // spv.Unreachable
 //===----------------------------------------------------------------------===//
 
index 1abeafe..2dd838c 100644 (file)
@@ -202,7 +202,7 @@ func @caller() {
 spv.module "Logical" "GLSL450" {
   func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () {
     // expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}}
-    %0 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32)
+    %0:2 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32)
     spv.Return
   }
 }