From e7aa47ff111c53127587d8aea71b088db3a671aa Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 12 Dec 2019 15:31:39 -0800 Subject: [PATCH] NFC: Cleanup the various Op::print methods. This cleans up the implementation of the various operation print methods. This is done via a combination of code cleanup, adding new streaming methods to the printer(e.g. operand ranges), etc. PiperOrigin-RevId: 285285181 --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 2 +- mlir/include/mlir/IR/OpImplementation.h | 29 ++++++- mlir/lib/Dialect/AffineOps/AffineOps.cpp | 16 ++-- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 25 ++---- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 16 ++-- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 19 +++-- mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 16 ++-- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 41 ++++------ mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 116 +++++++++++----------------- mlir/lib/Dialect/StandardOps/Ops.cpp | 112 +++++++-------------------- mlir/lib/Dialect/VectorOps/VectorOps.cpp | 64 ++++++--------- 11 files changed, 175 insertions(+), 281 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index c6e89af..bc6887d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -140,7 +140,7 @@ def NVVM_MmaOp : }]; let parser = [{ return parseNVVMMmaOp(parser, result); }]; let printer = [{ printNVVMMmaOp(p, *this); }]; - let verifier = [{ return mlir::NVVM::verify(*this); }]; + let verifier = [{ return ::verify(*this); }]; } #endif // NVVMIR_OPS diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 3052f79..05beaea 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -154,6 +154,18 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value &value) { p.printOperand(&value); return p; } +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value *value) { + return p << *value; +} + +template ::value && + !std::is_convertible::value, + T>::type * = nullptr> +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) { + p.printOperands(values); + return p; +} inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) { p.printType(type); @@ -170,14 +182,29 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) { // FunctionType with the Type version above, not have it match this. template ::value && + !std::is_convertible::value && !std::is_convertible::value && - !std::is_convertible::value, + !std::is_convertible::value && + !std::is_convertible::value && + !llvm::is_one_of::value, T>::type * = nullptr> inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) { p.getStream() << other; return p; } +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) { + return p << (value ? StringRef("true") : "false"); +} + +template +inline OpAsmPrinter & +operator<<(OpAsmPrinter &p, + const iterator_range> &types) { + interleaveComma(types, p); + return p; +} + //===----------------------------------------------------------------------===// // OpAsmParser //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 59e5afec..96a1a68 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -1985,18 +1985,12 @@ static ParseResult parseAffineMinOp(OpAsmParser &parser, static void print(OpAsmPrinter &p, AffineMinOp op) { p << op.getOperationName() << ' ' << op.getAttr(AffineMinOp::getMapAttrName()); - auto begin = op.operand_begin(); - auto end = op.operand_end(); + auto operands = op.getOperands(); unsigned numDims = op.map().getNumDims(); - p << '('; - p.printOperands(begin, begin + numDims); - p << ')'; - - if (begin + numDims != end) { - p << '['; - p.printOperands(begin + numDims, end); - p << ']'; - } + p << '(' << operands.take_front(numDims) << ')'; + + if (operands.size() != numDims) + p << '[' << operands.drop_front(numDims) << ']'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index b8970650..1f48d6d 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -289,8 +289,7 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) { // Print the launch configuration. p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword(); - printSizeAssignment(p, op.getGridSize(), - operands.drop_back(operands.size() - 3), + printSizeAssignment(p, op.getGridSize(), operands.take_front(3), op.getBlockIds()); p << ' ' << op.getThreadsKeyword(); printSizeAssignment(p, op.getBlockSize(), operands.slice(3, 3), @@ -303,25 +302,17 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) { // Print the data argument remapping. if (!op.body().empty() && !operands.empty()) { p << ' ' << op.getArgsKeyword() << '('; - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - if (i != 0) - p << ", "; - p << *op.body().front().getArgument(LaunchOp::kNumConfigRegionAttributes + - i) + Block *entryBlock = &op.body().front(); + interleaveComma(llvm::seq(0, operands.size()), p, [&](int i) { + p << *entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i) << " = " << *operands[i]; - } + }); p << ") "; } // Print the types of data arguments. - if (!operands.empty()) { - p << ": "; - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - if (i != 0) - p << ", "; - p << operands[i]->getType(); - } - } + if (!operands.empty()) + p << ": " << operands.getTypes(); p.printRegion(op.body(), /*printEntryBlockArgs=*/false); p.printOptionalAttrDict(op.getAttrs()); @@ -701,7 +692,7 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword, return; p << ' ' << keyword << '('; - interleaveComma(values, p.getStream(), + interleaveComma(values, p, [&p](BlockArgument *v) { p << *v << " : " << v->getType(); }); p << ')'; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 78da999..d037d2e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -177,9 +177,8 @@ static void printGEPOp(OpAsmPrinter &p, GEPOp &op) { SmallVector types(op.getOperandTypes()); auto funcTy = FunctionType::get(types, op.getType(), op.getContext()); - p << op.getOperationName() << ' ' << *op.base() << '['; - p.printOperands(std::next(op.operand_begin()), op.operand_end()); - p << ']'; + p << op.getOperationName() << ' ' << *op.base() << '[' + << op.getOperands().drop_front() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << funcTy; } @@ -312,10 +311,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) { else p << *op.getOperand(0); - p << '('; - p.printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1)); - p << ')'; - + p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; p.printOptionalAttrDict(op.getAttrs(), {"callee"}); // Reconstruct the function MLIR function type from operand and result types. @@ -938,8 +934,7 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { // Print the trailing type unless it's a string global. if (op.getValueOrNull().dyn_cast_or_null()) return; - p << " : "; - p.printType(op.type()); + p << " : " << op.type(); Region &initializer = op.getInitializerRegion(); if (!initializer.empty()) @@ -1346,8 +1341,7 @@ static LogicalResult verify(LLVMFuncOp op) { static void printNullOp(OpAsmPrinter &p, LLVM::NullOp op) { p << NullOp::getOperationName(); p.printOptionalAttrDict(op.getAttrs()); - p << " : "; - p.printType(op.getType()); + p << " : " << op.getType(); } // = `llvm.mlir.null` : type diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 0b10391..e4708fb 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -37,18 +37,17 @@ #include "llvm/IR/Type.h" #include "llvm/Support/SourceMgr.h" -namespace mlir { -namespace NVVM { +using namespace mlir; +using namespace NVVM; //===----------------------------------------------------------------------===// // Printing/parsing for NVVM ops //===----------------------------------------------------------------------===// static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) { - p << op->getName() << " "; - p.printOperands(op->getOperands()); + p << op->getName() << " " << op->getOperands(); if (op->getNumResults() > 0) - interleaveComma(op->getResultTypes(), p << " : "); + p << " : " << op->getResultTypes(); } // ::= `llvm.nvvm.XYZ` : type @@ -141,8 +140,7 @@ static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) { } static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) { - p << op.getOperationName() << " "; - p.printOperands(op.getOperands()); + p << op.getOperationName() << " " << op.getOperands(); p.printOptionalAttrDict(op.getAttrs()); p << " : " << FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()), @@ -210,10 +208,11 @@ NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) { allowUnknownOperations(); } +namespace mlir { +namespace NVVM { #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" - -static DialectRegistration nvvmDialect; - } // namespace NVVM } // namespace mlir + +static DialectRegistration nvvmDialect; diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp index 487382b..30c55b5 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -36,18 +36,17 @@ #include "llvm/IR/Type.h" #include "llvm/Support/SourceMgr.h" -namespace mlir { -namespace ROCDL { +using namespace mlir; +using namespace ROCDL; //===----------------------------------------------------------------------===// // Printing/parsing for ROCDL ops //===----------------------------------------------------------------------===// static void printROCDLOp(OpAsmPrinter &p, Operation *op) { - p << op->getName() << " "; - p.printOperands(op->getOperands()); + p << op->getName() << " " << op->getOperands(); if (op->getNumResults() > 0) - interleaveComma(op->getResultTypes(), p << " : "); + p << " : " << op->getResultTypes(); } // ::= `rocdl.XYZ` : type @@ -73,10 +72,11 @@ ROCDLDialect::ROCDLDialect(MLIRContext *context) : Dialect("rocdl", context) { allowUnknownOperations(); } +namespace mlir { +namespace ROCDL { #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc" - -static DialectRegistration rocdlDialect; - } // namespace ROCDL } // namespace mlir + +static DialectRegistration rocdlDialect; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 2efd26a..6adfeb5 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -60,18 +60,16 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) { llvm::StringSet<> linalgTraitAttrsSet; linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end()); SmallVector attrs; - for (auto attr : op.getAttrs()) { + for (auto attr : op.getAttrs()) if (linalgTraitAttrsSet.count(attr.first.strref()) > 0) attrs.push_back(attr); - } + auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); - p << op.getOperationName() << " " << dictAttr << " "; - p.printOperands(op.getOperands()); + p << op.getOperationName() << " " << dictAttr << " " << op.getOperands(); if (!op.region().empty()) p.printRegion(op.region()); p.printOptionalAttrDict(op.getAttrs(), attrNames); - p << ": "; - interleaveComma(op.getOperandTypes(), p); + p << ": " << op.getOperandTypes(); } static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); } @@ -342,14 +340,13 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result, } static void print(OpAsmPrinter &p, SliceOp op) { - p << SliceOp::getOperationName() << " " << *op.view() << "["; - p.printOperands(op.indexings()); - p << "] "; + auto indexings = op.indexings(); + p << SliceOp::getOperationName() << " " << *op.view() << "[" << indexings + << "] "; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getBaseViewType(); - for (auto indexing : op.indexings()) { - p << ", " << indexing->getType(); - } + if (!indexings.empty()) + p << ", " << op.indexings().getTypes(); p << ", " << op.getType(); } @@ -455,16 +452,11 @@ static ParseResult parseTransposeOp(OpAsmParser &parser, static void print(OpAsmPrinter &p, YieldOp op) { p << op.getOperationName(); - if (op.getNumOperands() > 0) { - p << ' '; - p.printOperands(op.operand_begin(), op.operand_end()); - } + if (op.getNumOperands() > 0) + p << ' ' << op.getOperands(); p.printOptionalAttrDict(op.getAttrs()); - if (op.getNumOperands() > 0) { - p << " : "; - interleaveComma(op.getOperands(), p, - [&](Value *e) { p.printType(e->getType()); }); - } + if (op.getNumOperands() > 0) + p << " : " << op.getOperandTypes(); } static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { @@ -536,12 +528,9 @@ static LogicalResult verify(YieldOp op) { // Where %0, %1 and %2 are ssa-values of type MemRefType with strides. static void printLinalgLibraryOp(OpAsmPrinter &p, Operation *op) { assert(op->getAbstractOperation() && "unregistered operation"); - p << op->getName().getStringRef() << "("; - interleaveComma(op->getOperands(), p, [&](Value *v) { p << *v; }); - p << ")"; + p << op->getName().getStringRef() << "(" << op->getOperands() << ")"; p.printOptionalAttrDict(op->getAttrs()); - p << " : "; - interleaveComma(op->getOperands(), p, [&](Value *v) { p << v->getType(); }); + p << " : " << op->getOperandTypes(); } static ParseResult parseLinalgLibraryOp(OpAsmParser &parser, diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 99ab1cd..839f134 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -500,11 +500,8 @@ static ParseResult parseBitFieldExtractOp(OpAsmParser &parser, } static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) { - printer << op->getName() << ' '; - printer.printOperands(op->getOperands()); - printer << " : " << op->getOperand(0)->getType() << ", " - << op->getOperand(1)->getType() << ", " - << op->getOperand(2)->getType(); + printer << op->getName() << ' ' << op->getOperands() << " : " + << op->getOperandTypes(); } static LogicalResult verifyBitFieldExtractOp(Operation *op) { @@ -580,9 +577,8 @@ static ParseResult parseLogicalBinaryOp(OpAsmParser &parser, } static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) { - printer << logicalOp->getName() << ' '; - printer.printOperands(logicalOp->getOperands()); - printer << " : " << logicalOp->getOperand(0)->getType(); + printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : " + << logicalOp->getOperand(0)->getType(); } static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) { @@ -717,9 +713,7 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser, static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) { printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr() - << '['; - printer.printOperands(op.indices()); - printer << "] : " << op.base_ptr()->getType(); + << '[' << op.indices() << "] : " << op.base_ptr()->getType(); } static LogicalResult verify(spirv::AccessChainOp accessChainOp) { @@ -875,9 +869,8 @@ static void print(spirv::AtomicCompareExchangeWeakOp atomOp, printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \"" << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \"" - << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "; - printer.printOperands(atomOp.getOperands()); - printer << " : " << atomOp.pointer()->getType(); + << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" " + << atomOp.getOperands() << " : " << atomOp.pointer()->getType(); } static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) { @@ -975,9 +968,9 @@ static ParseResult parseBitFieldInsertOp(OpAsmParser &parser, static void print(spirv::BitFieldInsertOp bitFieldInsertOp, OpAsmPrinter &printer) { - printer << spirv::BitFieldInsertOp::getOperationName() << ' '; - printer.printOperands(bitFieldInsertOp.getOperands()); - printer << " : " << bitFieldInsertOp.base()->getType() << ", " + printer << spirv::BitFieldInsertOp::getOperationName() << ' ' + << bitFieldInsertOp.getOperands() << " : " + << bitFieldInsertOp.base()->getType() << ", " << bitFieldInsertOp.offset()->getType() << ", " << bitFieldInsertOp.count()->getType(); } @@ -1072,8 +1065,8 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser, } static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) { - printer << spirv::BranchConditionalOp::getOperationName() << ' '; - printer.printOperand(branchOp.condition()); + printer << spirv::BranchConditionalOp::getOperationName() << ' ' + << branchOp.condition(); if (auto weights = branchOp.branch_weights()) { printer << " ["; @@ -1148,9 +1141,9 @@ static ParseResult parseCompositeConstructOp(OpAsmParser &parser, static void print(spirv::CompositeConstructOp compositeConstructOp, OpAsmPrinter &printer) { - printer << spirv::CompositeConstructOp::getOperationName() << " "; - printer.printOperands(compositeConstructOp.constituents()); - printer << " : " << compositeConstructOp.getResult()->getType(); + printer << spirv::CompositeConstructOp::getOperationName() << " " + << compositeConstructOp.constituents() << " : " + << compositeConstructOp.getResult()->getType(); } static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { @@ -1322,9 +1315,8 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) { static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) { printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value(); - if (constOp.getType().isa()) { + if (constOp.getType().isa()) printer << " : " << constOp.getType(); - } } static LogicalResult verify(spirv::ConstantOp constOp) { @@ -1577,9 +1569,8 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) { << execModeOp.fn() << " \"" << stringifyExecutionMode(execModeOp.execution_mode()) << "\""; auto values = execModeOp.values(); - if (!values.size()) { + if (!values.size()) return; - } printer << ", "; interleaveComma(values, printer, [&](Attribute a) { printer << a.cast().getInt(); @@ -1626,9 +1617,8 @@ static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) { FunctionType::get(argTypes, resultTypes, functionCallOp.getContext()); printer << spirv::FunctionCallOp::getOperationName() << ' ' - << functionCallOp.getAttr(kCallee) << '('; - printer.printOperands(functionCallOp.arguments()); - printer << ") : " << functionType; + << functionCallOp.getAttr(kCallee) << '(' + << functionCallOp.arguments() << ") : " << functionType; } static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { @@ -1829,9 +1819,8 @@ static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser, static void print(spirv::GroupNonUniformBallotOp ballotOp, OpAsmPrinter &printer) { printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \"" - << stringifyScope(ballotOp.execution_scope()) << "\" "; - printer.printOperand(ballotOp.predicate()); - printer << " : " << ballotOp.getType(); + << stringifyScope(ballotOp.execution_scope()) << "\" " + << ballotOp.predicate() << " : " << ballotOp.getType(); } static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) { @@ -1943,9 +1932,8 @@ static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( loadOp.ptr()->getType().cast().getStorageClass()); - printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "; - // Print the pointer operand. - printer.printOperand(loadOp.ptr()); + printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" " + << loadOp.ptr(); printMemoryAccessAttribute(loadOp, printer, elidedAttrs); @@ -2238,26 +2226,26 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) { } static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) { - auto *op = moduleOp.getOperation(); + printer << spirv::ModuleOp::getOperationName(); // Only print out addressing model and memory model in a nicer way if both - // presents. Otherwise, print them in the general form. This helps debugging - // ill-formed ModuleOp. + // presents. Otherwise, print them in the general form. This helps + // debugging ill-formed ModuleOp. SmallVector elidedAttrs; auto addressingModelAttrName = spirv::attributeName(); auto memoryModelAttrName = spirv::attributeName(); - if (op->getAttr(addressingModelAttrName) && - op->getAttr(memoryModelAttrName)) { - printer << spirv::ModuleOp::getOperationName() << " \"" + if (moduleOp.getAttr(addressingModelAttrName) && + moduleOp.getAttr(memoryModelAttrName)) { + printer << " \"" << spirv::stringifyAddressingModel(moduleOp.addressing_model()) << "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model()) << '"'; elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName}); } - printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, + printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); - printer.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs); + printer.printOptionalAttrDictWithKeyword(moduleOp.getAttrs(), elidedAttrs); } static LogicalResult verify(spirv::ModuleOp moduleOp) { @@ -2417,9 +2405,8 @@ static ParseResult parseReturnValueOp(OpAsmParser &parser, } static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) { - printer << spirv::ReturnValueOp::getOperationName() << ' '; - printer.printOperand(retValOp.value()); - printer << " : " << retValOp.value()->getType(); + printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value() + << " : " << retValOp.value()->getType(); } static LogicalResult verify(spirv::ReturnValueOp retValOp) { @@ -2471,13 +2458,8 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) { } static void print(spirv::SelectOp op, OpAsmPrinter &printer) { - printer << spirv::SelectOp::getOperationName() << " "; - - // Print the operands. - printer.printOperands(op.getOperands()); - - // Print colon and types. - printer << " : " << op.condition()->getType() << ", " + printer << spirv::SelectOp::getOperationName() << " " << op.getOperands() + << " : " << op.condition()->getType() << ", " << op.result()->getType(); } @@ -2788,8 +2770,7 @@ static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) { printer.printSymbolName(constOp.sym_name()); if (auto specID = constOp.getAttrOfType(kSpecIdAttrName)) printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')'; - printer << " = "; - printer.printAttribute(constOp.default_value()); + printer << " = " << constOp.default_value(); } static LogicalResult verify(spirv::SpecConstantOp constOp) { @@ -2844,17 +2825,12 @@ static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( storeOp.ptr()->getType().cast().getStorageClass()); - printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "; - // Print the pointer operand - printer.printOperand(storeOp.ptr()); - printer << ", "; - // Print the value operand - printer.printOperand(storeOp.value()); + printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" " + << storeOp.ptr() << ", " << storeOp.value(); printMemoryAccessAttribute(storeOp, printer, elidedAttrs); printer << " : " << storeOp.value()->getType(); - printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } @@ -2885,9 +2861,8 @@ static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser, } static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) { - printer << spirv::SubgroupBallotKHROp::getOperationName() << ' '; - printer.printOperand(ballotOp.predicate()); - printer << " : " << ballotOp.getType(); + printer << spirv::SubgroupBallotKHROp::getOperationName() << ' ' + << ballotOp.predicate() << " : " << ballotOp.getType(); } //===----------------------------------------------------------------------===// @@ -2973,20 +2948,15 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) { } static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) { - auto *op = varOp.getOperation(); SmallVector elidedAttrs{ spirv::attributeName()}; printer << spirv::VariableOp::getOperationName(); // Print optional initializer - if (op->getNumOperands() > 0) { - printer << " init("; - printer.printOperands(varOp.initializer()); - printer << ")"; - } - - printVariableDecorations(op, printer, elidedAttrs); + if (varOp.getNumOperands() != 0) + printer << " init(" << varOp.initializer() << ")"; + printVariableDecorations(varOp, printer, elidedAttrs); printer << " : " << varOp.getType(); } diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 7726c04..531be29 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -166,15 +166,10 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context) void mlir::printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &p) { - p << '('; - p.printOperands(begin, begin + numDims); - p << ')'; - - if (begin + numDims != end) { - p << '['; - p.printOperands(begin + numDims, end); - p << ']'; - } + Operation::operand_range operands(begin, end); + p << '(' << operands.take_front(numDims) << ')'; + if (operands.size() != numDims) + p << '[' << operands.drop_front(numDims) << ']'; } // Parses dimension and symbol list, and sets 'numDims' to the number of @@ -485,12 +480,9 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { } static void print(OpAsmPrinter &p, CallOp op) { - p << "call " << op.getAttr("callee") << '('; - p.printOperands(op.getOperands()); - p << ')'; + p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : "; - p.printType(op.getCalleeType()); + p << " : " << op.getCalleeType(); } static LogicalResult verify(CallOp op) { @@ -572,11 +564,7 @@ static ParseResult parseCallIndirectOp(OpAsmParser &parser, } static void print(OpAsmPrinter &p, CallIndirectOp op) { - p << "call_indirect "; - p.printOperand(op.getCallee()); - p << '('; - p.printOperands(op.getArgOperands()); - p << ')'; + p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); p << " : " << op.getCallee()->getType(); } @@ -690,12 +678,7 @@ static void print(OpAsmPrinter &p, CmpIOp op) { auto predicateValue = op.getAttrOfType(CmpIOp::getPredicateAttrName()).getInt(); p << '"' << stringifyCmpIPredicate(static_cast(predicateValue)) - << '"'; - - p << ", "; - p.printOperand(op.lhs()); - p << ", "; - p.printOperand(op.rhs()); + << '"' << ", " << op.lhs() << ", " << op.rhs(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); p << " : " << op.lhs()->getType(); @@ -851,15 +834,8 @@ static void print(OpAsmPrinter &p, CmpFOp op) { assert(predicateValue >= static_cast(CmpFPredicate::FirstValidValue) && predicateValue < static_cast(CmpFPredicate::NumPredicates) && "unknown predicate index"); - Builder b(op.getContext()); - auto predicateStringAttr = - b.getStringAttr(getCmpFPredicateNames()[predicateValue]); - p.printAttribute(predicateStringAttr); - - p << ", "; - p.printOperand(op.lhs()); - p << ", "; - p.printOperand(op.rhs()); + p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs() + << ", " << op.rhs(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); p << " : " << op.lhs()->getType(); @@ -1002,9 +978,7 @@ static ParseResult parseCondBranchOp(OpAsmParser &parser, } static void print(OpAsmPrinter &p, CondBranchOp op) { - p << "cond_br "; - p.printOperand(op.getCondition()); - p << ", "; + p << "cond_br " << op.getCondition() << ", "; p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); p << ", "; p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); @@ -1025,7 +999,7 @@ static void print(OpAsmPrinter &p, ConstantOp &op) { if (op.getAttrs().size() > 1) p << ' '; - p.printAttribute(op.getValue()); + p << op.getValue(); // If the value is a symbol reference, print a trailing type. if (op.getValue().isa()) @@ -1407,18 +1381,12 @@ void DmaStartOp::build(Builder *builder, OperationState &result, } void DmaStartOp::print(OpAsmPrinter &p) { - p << "dma_start " << *getSrcMemRef() << '['; - p.printOperands(getSrcIndices()); - p << "], " << *getDstMemRef() << '['; - p.printOperands(getDstIndices()); - p << "], " << *getNumElements(); - p << ", " << *getTagMemRef() << '['; - p.printOperands(getTagIndices()); - p << ']'; - if (isStrided()) { - p << ", " << *getStride(); - p << ", " << *getNumElementsPerStride(); - } + p << "dma_start " << *getSrcMemRef() << '[' << getSrcIndices() << "], " + << *getDstMemRef() << '[' << getDstIndices() << "], " << *getNumElements() + << ", " << *getTagMemRef() << '[' << getTagIndices() << ']'; + if (isStrided()) + p << ", " << *getStride() << ", " << *getNumElementsPerStride(); + p.printOptionalAttrDict(getAttrs()); p << " : " << getSrcMemRef()->getType(); p << ", " << getDstMemRef()->getType(); @@ -1550,12 +1518,8 @@ void DmaWaitOp::build(Builder *builder, OperationState &result, } void DmaWaitOp::print(OpAsmPrinter &p) { - p << "dma_wait "; - p.printOperand(getTagMemRef()); - p << '['; - p.printOperands(getTagIndices()); - p << "], "; - p.printOperand(getNumElements()); + p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], " + << getNumElements(); p.printOptionalAttrDict(getAttrs()); p << " : " << getTagMemRef()->getType(); } @@ -1604,8 +1568,7 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ExtractElementOp op) { - p << "extract_element " << *op.getAggregate() << '['; - p.printOperands(op.getIndices()); + p << "extract_element " << *op.getAggregate() << '[' << op.getIndices(); p << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getAggregate()->getType(); @@ -1686,9 +1649,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, LoadOp op) { - p << "load " << *op.getMemRef() << '['; - p.printOperands(op.getIndices()); - p << ']'; + p << "load " << *op.getMemRef() << '[' << op.getIndices() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getMemRefType(); } @@ -1922,12 +1883,8 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { static void print(OpAsmPrinter &p, ReturnOp op) { p << "return"; - if (op.getNumOperands() != 0) { - p << ' '; - p.printOperands(op.getOperands()); - p << " : "; - interleaveComma(op.getOperandTypes(), p); - } + if (op.getNumOperands() != 0) + p << ' ' << op.getOperands() << " : " << op.getOperandTypes(); } static LogicalResult verify(ReturnOp op) { @@ -1984,9 +1941,7 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { } static void print(OpAsmPrinter &p, SelectOp op) { - p << "select "; - p.printOperands(op.getOperands()); - p << " : " << op.getTrueValue()->getType(); + p << "select " << op.getOperands() << " : " << op.getTrueValue()->getType(); p.printOptionalAttrDict(op.getAttrs()); } @@ -2093,9 +2048,7 @@ OpFoldResult SplatOp::fold(ArrayRef operands) { static void print(OpAsmPrinter &p, StoreOp op) { p << "store " << *op.getValueToStore(); - p << ", " << *op.getMemRef() << '['; - p.printOperands(op.getIndices()); - p << ']'; + p << ", " << *op.getMemRef() << '[' << op.getIndices() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getMemRefType(); } @@ -2339,9 +2292,7 @@ static void print(OpAsmPrinter &p, ViewOp op) { auto *dynamicOffset = op.getDynamicOffset(); if (dynamicOffset != nullptr) p.printOperand(dynamicOffset); - p << "]["; - p.printOperands(op.getDynamicSizes()); - p << ']'; + p << "][" << op.getDynamicSizes() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); } @@ -2609,13 +2560,8 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { } static void print(OpAsmPrinter &p, SubViewOp op) { - p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; - p.printOperands(op.offsets()); - p << "]["; - p.printOperands(op.sizes()); - p << "]["; - p.printOperands(op.strides()); - p << ']'; + p << op.getOperationName() << ' ' << *op.getOperand(0) << '[' << op.offsets() + << "][" << op.sizes() << "][" << op.strides() << ']'; SmallVector elidedAttrs = { SubViewOp::getOperandSegmentSizeAttr()}; diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 28a0322..a2345fe 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -110,17 +110,16 @@ static void print(OpAsmPrinter &p, ContractionOp op) { llvm::StringSet<> traitAttrsSet; traitAttrsSet.insert(attrNames.begin(), attrNames.end()); SmallVector attrs; - for (auto attr : op.getAttrs()) { + for (auto attr : op.getAttrs()) if (traitAttrsSet.count(attr.first.strref()) > 0) attrs.push_back(attr); - } + auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", "; p << *op.rhs() << ", " << *op.acc(); - if (llvm::size(op.masks()) == 2) { - p << ", " << **op.masks().begin(); - p << ", " << **(op.masks().begin() + 1); - } + if (op.masks().size() == 2) + p << ", " << op.masks(); + p.printOptionalAttrDict(op.getAttrs(), attrNames); p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into " << op.getResultType(); @@ -417,9 +416,8 @@ static LogicalResult verify(vector::ExtractOp op) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, BroadcastOp op) { - p << op.getOperationName() << " " << *op.source(); - p << " : " << op.getSourceType(); - p << " to " << op.getVectorType(); + p << op.getOperationName() << " " << *op.source() << " : " + << op.getSourceType() << " to " << op.getVectorType(); } static LogicalResult verify(BroadcastOp op) { @@ -560,8 +558,7 @@ static void print(OpAsmPrinter &p, InsertOp op) { p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() << op.position(); p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()}); - p << " : " << op.getSourceType(); - p << " into " << op.getDestVectorType(); + p << " : " << op.getSourceType() << " into " << op.getDestVectorType(); } static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) { @@ -789,8 +786,8 @@ static LogicalResult verify(InsertStridedSliceOp op) { static void print(OpAsmPrinter &p, OuterProductOp op) { p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs(); - if (llvm::size(op.acc()) > 0) - p << ", " << **op.acc().begin(); + if (!op.acc().empty()) + p << ", " << op.acc(); p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType(); } @@ -1034,16 +1031,10 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap, } static void print(OpAsmPrinter &p, TransferReadOp op) { - p << op.getOperationName() << " "; - p.printOperand(op.memref()); - p << "["; - p.printOperands(op.indices()); - p << "], "; - p.printOperand(op.padding()); - p << " "; + p << op.getOperationName() << " " << op.memref() << "[" << op.indices() + << "], " << op.padding() << " "; p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getMemRefType(); - p << ", " << op.getVectorType(); + p << " : " << op.getMemRefType() << ", " << op.getVectorType(); } ParseResult parseTransferReadOp(OpAsmParser &parser, OperationState &result) { @@ -1106,15 +1097,10 @@ static LogicalResult verify(TransferReadOp op) { // TransferWriteOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, TransferWriteOp op) { - p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref(); - p << "["; - p.printOperands(op.indices()); - p << "]"; + p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref() + << "[" << op.indices() << "]"; p.printOptionalAttrDict(op.getAttrs()); - p << " : "; - p.printType(op.getVectorType()); - p << ", "; - p.printType(op.getMemRefType()); + p << " : " << op.getVectorType() << ", " << op.getMemRefType(); } ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) { @@ -1180,13 +1166,13 @@ void TypeCastOp::build(Builder *builder, OperationState &result, inferVectorTypeCastResultType(source->getType().cast())); } -static void print(OpAsmPrinter &p, TypeCastOp &op) { +static void print(OpAsmPrinter &p, TypeCastOp op) { auto type = op.getOperand()->getType().cast(); p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to " << inferVectorTypeCastResultType(type); } -static LogicalResult verify(TypeCastOp &op) { +static LogicalResult verify(TypeCastOp op) { auto resultType = inferVectorTypeCastResultType(op.getMemRefType()); if (op.getResultMemRefType() != resultType) return op.emitOpError("expects result type to be: ") << resultType; @@ -1208,9 +1194,9 @@ ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) { parser.addTypeToList(resultType, result.types)); } -static void print(OpAsmPrinter &p, ConstantMaskOp &op) { - p << op.getOperationName() << ' ' << op.mask_dim_sizes(); - p << " : " << op.getResult()->getType(); +static void print(OpAsmPrinter &p, ConstantMaskOp op) { + p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : " + << op.getResult()->getType(); } static LogicalResult verify(ConstantMaskOp &op) { @@ -1256,13 +1242,11 @@ ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) { parser.addTypeToList(resultType, result.types)); } -static void print(OpAsmPrinter &p, CreateMaskOp &op) { - p << op.getOperationName() << ' '; - p.printOperands(op.operands()); - p << " : " << op.getResult()->getType(); +static void print(OpAsmPrinter &p, CreateMaskOp op) { + p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType(); } -static LogicalResult verify(CreateMaskOp &op) { +static LogicalResult verify(CreateMaskOp op) { // Verify that an operand was specified for each result vector each dimension. if (op.getNumOperands() != op.getResult()->getType().cast().getRank()) -- 2.7.4