NFC: Cleanup the various Op::print methods.
authorRiver Riddle <riverriddle@google.com>
Thu, 12 Dec 2019 23:31:39 +0000 (15:31 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 12 Dec 2019 23:32:21 +0000 (15:32 -0800)
This cleans up the implementation of the various operation print methods. This is done via a combination of code cleanup, adding new streaming methods to the printer(e.g. operand ranges), etc.

PiperOrigin-RevId: 285285181

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Dialect/AffineOps/AffineOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/lib/Dialect/VectorOps/VectorOps.cpp

index c6e89af..bc6887d 100644 (file)
@@ -140,7 +140,7 @@ def NVVM_MmaOp :
   }];
   let parser = [{ return parseNVVMMmaOp(parser, result); }];
   let printer = [{ printNVVMMmaOp(p, *this); }];
-  let verifier = [{ return mlir::NVVM::verify(*this); }];
+  let verifier = [{ return ::verify(*this); }];
 }
 
 #endif // NVVMIR_OPS
index 3052f79..05beaea 100644 (file)
@@ -154,6 +154,18 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value &value) {
   p.printOperand(&value);
   return p;
 }
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value *value) {
+  return p << *value;
+}
+
+template <typename T,
+          typename std::enable_if<std::is_convertible<T &, ValueRange>::value &&
+                                      !std::is_convertible<T &, Value *>::value,
+                                  T>::type * = nullptr>
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
+  p.printOperands(values);
+  return p;
+}
 
 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
   p.printType(type);
@@ -170,14 +182,29 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
 // FunctionType with the Type version above, not have it match this.
 template <typename T, typename std::enable_if<
                           !std::is_convertible<T &, Value &>::value &&
+                              !std::is_convertible<T &, Value *>::value &&
                               !std::is_convertible<T &, Type &>::value &&
-                              !std::is_convertible<T &, Attribute &>::value,
+                              !std::is_convertible<T &, Attribute &>::value &&
+                              !std::is_convertible<T &, ValueRange>::value &&
+                              !llvm::is_one_of<T, bool>::value,
                           T>::type * = nullptr>
 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
   p.getStream() << other;
   return p;
 }
 
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) {
+  return p << (value ? StringRef("true") : "false");
+}
+
+template <typename IteratorT>
+inline OpAsmPrinter &
+operator<<(OpAsmPrinter &p,
+           const iterator_range<ValueTypeIterator<IteratorT>> &types) {
+  interleaveComma(types, p);
+  return p;
+}
+
 //===----------------------------------------------------------------------===//
 // OpAsmParser
 //===----------------------------------------------------------------------===//
index 59e5afe..96a1a68 100644 (file)
@@ -1985,18 +1985,12 @@ static ParseResult parseAffineMinOp(OpAsmParser &parser,
 static void print(OpAsmPrinter &p, AffineMinOp op) {
   p << op.getOperationName() << ' '
     << op.getAttr(AffineMinOp::getMapAttrName());
-  auto begin = op.operand_begin();
-  auto end = op.operand_end();
+  auto operands = op.getOperands();
   unsigned numDims = op.map().getNumDims();
-  p << '(';
-  p.printOperands(begin, begin + numDims);
-  p << ')';
-
-  if (begin + numDims != end) {
-    p << '[';
-    p.printOperands(begin + numDims, end);
-    p << ']';
-  }
+  p << '(' << operands.take_front(numDims) << ')';
+
+  if (operands.size() != numDims)
+    p << '[' << operands.drop_front(numDims) << ']';
   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
 }
 
index b897065..1f48d6d 100644 (file)
@@ -289,8 +289,7 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
 
   // Print the launch configuration.
   p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword();
-  printSizeAssignment(p, op.getGridSize(),
-                      operands.drop_back(operands.size() - 3),
+  printSizeAssignment(p, op.getGridSize(), operands.take_front(3),
                       op.getBlockIds());
   p << ' ' << op.getThreadsKeyword();
   printSizeAssignment(p, op.getBlockSize(), operands.slice(3, 3),
@@ -303,25 +302,17 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
   // Print the data argument remapping.
   if (!op.body().empty() && !operands.empty()) {
     p << ' ' << op.getArgsKeyword() << '(';
-    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
-      if (i != 0)
-        p << ", ";
-      p << *op.body().front().getArgument(LaunchOp::kNumConfigRegionAttributes +
-                                          i)
+    Block *entryBlock = &op.body().front();
+    interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
+      p << *entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i)
         << " = " << *operands[i];
-    }
+    });
     p << ") ";
   }
 
   // Print the types of data arguments.
-  if (!operands.empty()) {
-    p << ": ";
-    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
-      if (i != 0)
-        p << ", ";
-      p << operands[i]->getType();
-    }
-  }
+  if (!operands.empty())
+    p << ": " << operands.getTypes();
 
   p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
   p.printOptionalAttrDict(op.getAttrs());
@@ -701,7 +692,7 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
     return;
 
   p << ' ' << keyword << '(';
-  interleaveComma(values, p.getStream(),
+  interleaveComma(values, p,
                   [&p](BlockArgument *v) { p << *v << " : " << v->getType(); });
   p << ')';
 }
index 78da999..d037d2e 100644 (file)
@@ -177,9 +177,8 @@ static void printGEPOp(OpAsmPrinter &p, GEPOp &op) {
   SmallVector<Type, 8> types(op.getOperandTypes());
   auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
 
-  p << op.getOperationName() << ' ' << *op.base() << '[';
-  p.printOperands(std::next(op.operand_begin()), op.operand_end());
-  p << ']';
+  p << op.getOperationName() << ' ' << *op.base() << '['
+    << op.getOperands().drop_front() << ']';
   p.printOptionalAttrDict(op.getAttrs());
   p << " : " << funcTy;
 }
@@ -312,10 +311,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
   else
     p << *op.getOperand(0);
 
-  p << '(';
-  p.printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1));
-  p << ')';
-
+  p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
   p.printOptionalAttrDict(op.getAttrs(), {"callee"});
 
   // Reconstruct the function MLIR function type from operand and result types.
@@ -938,8 +934,7 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
   // Print the trailing type unless it's a string global.
   if (op.getValueOrNull().dyn_cast_or_null<StringAttr>())
     return;
-  p << " : ";
-  p.printType(op.type());
+  p << " : " << op.type();
 
   Region &initializer = op.getInitializerRegion();
   if (!initializer.empty())
@@ -1346,8 +1341,7 @@ static LogicalResult verify(LLVMFuncOp op) {
 static void printNullOp(OpAsmPrinter &p, LLVM::NullOp op) {
   p << NullOp::getOperationName();
   p.printOptionalAttrDict(op.getAttrs());
-  p << " : ";
-  p.printType(op.getType());
+  p << " : " << op.getType();
 }
 
 // <operation> = `llvm.mlir.null` : type
index 0b10391..e4708fb 100644 (file)
 #include "llvm/IR/Type.h"
 #include "llvm/Support/SourceMgr.h"
 
-namespace mlir {
-namespace NVVM {
+using namespace mlir;
+using namespace NVVM;
 
 //===----------------------------------------------------------------------===//
 // Printing/parsing for NVVM ops
 //===----------------------------------------------------------------------===//
 
 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
-  p << op->getName() << " ";
-  p.printOperands(op->getOperands());
+  p << op->getName() << " " << op->getOperands();
   if (op->getNumResults() > 0)
-    interleaveComma(op->getResultTypes(), p << " : ");
+    p << " : " << op->getResultTypes();
 }
 
 // <operation> ::= `llvm.nvvm.XYZ` : type
@@ -141,8 +140,7 @@ static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) {
 }
 
 static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) {
-  p << op.getOperationName() << " ";
-  p.printOperands(op.getOperands());
+  p << op.getOperationName() << " " << op.getOperands();
   p.printOptionalAttrDict(op.getAttrs());
   p << " : "
     << FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()),
@@ -210,10 +208,11 @@ NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) {
   allowUnknownOperations();
 }
 
+namespace mlir {
+namespace NVVM {
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
-
-static DialectRegistration<NVVMDialect> nvvmDialect;
-
 } // namespace NVVM
 } // namespace mlir
+
+static DialectRegistration<NVVMDialect> nvvmDialect;
index 487382b..30c55b5 100644 (file)
 #include "llvm/IR/Type.h"
 #include "llvm/Support/SourceMgr.h"
 
-namespace mlir {
-namespace ROCDL {
+using namespace mlir;
+using namespace ROCDL;
 
 //===----------------------------------------------------------------------===//
 // Printing/parsing for ROCDL ops
 //===----------------------------------------------------------------------===//
 
 static void printROCDLOp(OpAsmPrinter &p, Operation *op) {
-  p << op->getName() << " ";
-  p.printOperands(op->getOperands());
+  p << op->getName() << " " << op->getOperands();
   if (op->getNumResults() > 0)
-    interleaveComma(op->getResultTypes(), p << " : ");
+    p << " : " << op->getResultTypes();
 }
 
 // <operation> ::= `rocdl.XYZ` : type
@@ -73,10 +72,11 @@ ROCDLDialect::ROCDLDialect(MLIRContext *context) : Dialect("rocdl", context) {
   allowUnknownOperations();
 }
 
+namespace mlir {
+namespace ROCDL {
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
-
-static DialectRegistration<ROCDLDialect> rocdlDialect;
-
 } // namespace ROCDL
 } // namespace mlir
+
+static DialectRegistration<ROCDLDialect> rocdlDialect;
index 2efd26a..6adfeb5 100644 (file)
@@ -60,18 +60,16 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
   llvm::StringSet<> linalgTraitAttrsSet;
   linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end());
   SmallVector<NamedAttribute, 8> attrs;
-  for (auto attr : op.getAttrs()) {
+  for (auto attr : op.getAttrs())
     if (linalgTraitAttrsSet.count(attr.first.strref()) > 0)
       attrs.push_back(attr);
-  }
+
   auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
-  p << op.getOperationName() << " " << dictAttr << " ";
-  p.printOperands(op.getOperands());
+  p << op.getOperationName() << " " << dictAttr << " " << op.getOperands();
   if (!op.region().empty())
     p.printRegion(op.region());
   p.printOptionalAttrDict(op.getAttrs(), attrNames);
-  p << ": ";
-  interleaveComma(op.getOperandTypes(), p);
+  p << ": " << op.getOperandTypes();
 }
 
 static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); }
@@ -342,14 +340,13 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
 }
 
 static void print(OpAsmPrinter &p, SliceOp op) {
-  p << SliceOp::getOperationName() << " " << *op.view() << "[";
-  p.printOperands(op.indexings());
-  p << "] ";
+  auto indexings = op.indexings();
+  p << SliceOp::getOperationName() << " " << *op.view() << "[" << indexings
+    << "] ";
   p.printOptionalAttrDict(op.getAttrs());
   p << " : " << op.getBaseViewType();
-  for (auto indexing : op.indexings()) {
-    p << ", " << indexing->getType();
-  }
+  if (!indexings.empty())
+    p << ", " << op.indexings().getTypes();
   p << ", " << op.getType();
 }
 
@@ -455,16 +452,11 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
 
 static void print(OpAsmPrinter &p, YieldOp op) {
   p << op.getOperationName();
-  if (op.getNumOperands() > 0) {
-    p << ' ';
-    p.printOperands(op.operand_begin(), op.operand_end());
-  }
+  if (op.getNumOperands() > 0)
+    p << ' ' << op.getOperands();
   p.printOptionalAttrDict(op.getAttrs());
-  if (op.getNumOperands() > 0) {
-    p << " : ";
-    interleaveComma(op.getOperands(), p,
-                    [&](Value *e) { p.printType(e->getType()); });
-  }
+  if (op.getNumOperands() > 0)
+    p << " : " << op.getOperandTypes();
 }
 
 static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
@@ -536,12 +528,9 @@ static LogicalResult verify(YieldOp op) {
 // Where %0, %1 and %2 are ssa-values of type MemRefType with strides.
 static void printLinalgLibraryOp(OpAsmPrinter &p, Operation *op) {
   assert(op->getAbstractOperation() && "unregistered operation");
-  p << op->getName().getStringRef() << "(";
-  interleaveComma(op->getOperands(), p, [&](Value *v) { p << *v; });
-  p << ")";
+  p << op->getName().getStringRef() << "(" << op->getOperands() << ")";
   p.printOptionalAttrDict(op->getAttrs());
-  p << " : ";
-  interleaveComma(op->getOperands(), p, [&](Value *v) { p << v->getType(); });
+  p << " : " << op->getOperandTypes();
 }
 
 static ParseResult parseLinalgLibraryOp(OpAsmParser &parser,
index 99ab1cd..839f134 100644 (file)
@@ -500,11 +500,8 @@ static ParseResult parseBitFieldExtractOp(OpAsmParser &parser,
 }
 
 static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) {
-  printer << op->getName() << ' ';
-  printer.printOperands(op->getOperands());
-  printer << " : " << op->getOperand(0)->getType() << ", "
-          << op->getOperand(1)->getType() << ", "
-          << op->getOperand(2)->getType();
+  printer << op->getName() << ' ' << op->getOperands() << " : "
+          << op->getOperandTypes();
 }
 
 static LogicalResult verifyBitFieldExtractOp(Operation *op) {
@@ -580,9 +577,8 @@ static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
 }
 
 static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
-  printer << logicalOp->getName() << ' ';
-  printer.printOperands(logicalOp->getOperands());
-  printer << " : " << logicalOp->getOperand(0)->getType();
+  printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : "
+          << logicalOp->getOperand(0)->getType();
 }
 
 static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
@@ -717,9 +713,7 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
 
 static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
   printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
-          << '[';
-  printer.printOperands(op.indices());
-  printer << "] : " << op.base_ptr()->getType();
+          << '[' << op.indices() << "] : " << op.base_ptr()->getType();
 }
 
 static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
@@ -875,9 +869,8 @@ static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
   printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \""
           << stringifyScope(atomOp.memory_scope()) << "\" \""
           << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
-          << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" ";
-  printer.printOperands(atomOp.getOperands());
-  printer << " : " << atomOp.pointer()->getType();
+          << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
+          << atomOp.getOperands() << " : " << atomOp.pointer()->getType();
 }
 
 static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
@@ -975,9 +968,9 @@ static ParseResult parseBitFieldInsertOp(OpAsmParser &parser,
 
 static void print(spirv::BitFieldInsertOp bitFieldInsertOp,
                   OpAsmPrinter &printer) {
-  printer << spirv::BitFieldInsertOp::getOperationName() << ' ';
-  printer.printOperands(bitFieldInsertOp.getOperands());
-  printer << " : " << bitFieldInsertOp.base()->getType() << ", "
+  printer << spirv::BitFieldInsertOp::getOperationName() << ' '
+          << bitFieldInsertOp.getOperands() << " : "
+          << bitFieldInsertOp.base()->getType() << ", "
           << bitFieldInsertOp.offset()->getType() << ", "
           << bitFieldInsertOp.count()->getType();
 }
@@ -1072,8 +1065,8 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
 }
 
 static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
-  printer << spirv::BranchConditionalOp::getOperationName() << ' ';
-  printer.printOperand(branchOp.condition());
+  printer << spirv::BranchConditionalOp::getOperationName() << ' '
+          << branchOp.condition();
 
   if (auto weights = branchOp.branch_weights()) {
     printer << " [";
@@ -1148,9 +1141,9 @@ static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
 
 static void print(spirv::CompositeConstructOp compositeConstructOp,
                   OpAsmPrinter &printer) {
-  printer << spirv::CompositeConstructOp::getOperationName() << " ";
-  printer.printOperands(compositeConstructOp.constituents());
-  printer << " : " << compositeConstructOp.getResult()->getType();
+  printer << spirv::CompositeConstructOp::getOperationName() << " "
+          << compositeConstructOp.constituents() << " : "
+          << compositeConstructOp.getResult()->getType();
 }
 
 static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
@@ -1322,9 +1315,8 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
 
 static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
   printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
-  if (constOp.getType().isa<spirv::ArrayType>()) {
+  if (constOp.getType().isa<spirv::ArrayType>())
     printer << " : " << constOp.getType();
-  }
 }
 
 static LogicalResult verify(spirv::ConstantOp constOp) {
@@ -1577,9 +1569,8 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
           << execModeOp.fn() << " \""
           << stringifyExecutionMode(execModeOp.execution_mode()) << "\"";
   auto values = execModeOp.values();
-  if (!values.size()) {
+  if (!values.size())
     return;
-  }
   printer << ", ";
   interleaveComma(values, printer, [&](Attribute a) {
     printer << a.cast<IntegerAttr>().getInt();
@@ -1626,9 +1617,8 @@ static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
       FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());
 
   printer << spirv::FunctionCallOp::getOperationName() << ' '
-          << functionCallOp.getAttr(kCallee) << '(';
-  printer.printOperands(functionCallOp.arguments());
-  printer << ") : " << functionType;
+          << functionCallOp.getAttr(kCallee) << '('
+          << functionCallOp.arguments() << ") : " << functionType;
 }
 
 static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
@@ -1829,9 +1819,8 @@ static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
 static void print(spirv::GroupNonUniformBallotOp ballotOp,
                   OpAsmPrinter &printer) {
   printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \""
-          << stringifyScope(ballotOp.execution_scope()) << "\" ";
-  printer.printOperand(ballotOp.predicate());
-  printer << " : " << ballotOp.getType();
+          << stringifyScope(ballotOp.execution_scope()) << "\" "
+          << ballotOp.predicate() << " : " << ballotOp.getType();
 }
 
 static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
@@ -1943,9 +1932,8 @@ static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
   SmallVector<StringRef, 4> elidedAttrs;
   StringRef sc = stringifyStorageClass(
       loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
-  printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" ";
-  // Print the pointer operand.
-  printer.printOperand(loadOp.ptr());
+  printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "
+          << loadOp.ptr();
 
   printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
 
@@ -2238,26 +2226,26 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
 }
 
 static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
-  auto *op = moduleOp.getOperation();
+  printer << spirv::ModuleOp::getOperationName();
 
   // Only print out addressing model and memory model in a nicer way if both
-  // presents. Otherwise, print them in the general form. This helps debugging
-  // ill-formed ModuleOp.
+  // presents. Otherwise, print them in the general form. This helps
+  // debugging ill-formed ModuleOp.
   SmallVector<StringRef, 2> elidedAttrs;
   auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
   auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
-  if (op->getAttr(addressingModelAttrName) &&
-      op->getAttr(memoryModelAttrName)) {
-    printer << spirv::ModuleOp::getOperationName() << " \""
+  if (moduleOp.getAttr(addressingModelAttrName) &&
+      moduleOp.getAttr(memoryModelAttrName)) {
+    printer << " \""
             << spirv::stringifyAddressingModel(moduleOp.addressing_model())
             << "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model())
             << '"';
     elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
   }
 
-  printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
+  printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
                       /*printBlockTerminators=*/false);
-  printer.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs);
+  printer.printOptionalAttrDictWithKeyword(moduleOp.getAttrs(), elidedAttrs);
 }
 
 static LogicalResult verify(spirv::ModuleOp moduleOp) {
@@ -2417,9 +2405,8 @@ static ParseResult parseReturnValueOp(OpAsmParser &parser,
 }
 
 static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) {
-  printer << spirv::ReturnValueOp::getOperationName() << ' ';
-  printer.printOperand(retValOp.value());
-  printer << " : " << retValOp.value()->getType();
+  printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value()
+          << " : " << retValOp.value()->getType();
 }
 
 static LogicalResult verify(spirv::ReturnValueOp retValOp) {
@@ -2471,13 +2458,8 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
 }
 
 static void print(spirv::SelectOp op, OpAsmPrinter &printer) {
-  printer << spirv::SelectOp::getOperationName() << " ";
-
-  // Print the operands.
-  printer.printOperands(op.getOperands());
-
-  // Print colon and types.
-  printer << " : " << op.condition()->getType() << ", "
+  printer << spirv::SelectOp::getOperationName() << " " << op.getOperands()
+          << " : " << op.condition()->getType() << ", "
           << op.result()->getType();
 }
 
@@ -2788,8 +2770,7 @@ static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
   printer.printSymbolName(constOp.sym_name());
   if (auto specID = constOp.getAttrOfType<IntegerAttr>(kSpecIdAttrName))
     printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
-  printer << " = ";
-  printer.printAttribute(constOp.default_value());
+  printer << " = " << constOp.default_value();
 }
 
 static LogicalResult verify(spirv::SpecConstantOp constOp) {
@@ -2844,17 +2825,12 @@ static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
   SmallVector<StringRef, 4> elidedAttrs;
   StringRef sc = stringifyStorageClass(
       storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
-  printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" ";
-  // Print the pointer operand
-  printer.printOperand(storeOp.ptr());
-  printer << ", ";
-  // Print the value operand
-  printer.printOperand(storeOp.value());
+  printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "
+          << storeOp.ptr() << ", " << storeOp.value();
 
   printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
 
   printer << " : " << storeOp.value()->getType();
-
   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 }
 
@@ -2885,9 +2861,8 @@ static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser,
 }
 
 static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) {
-  printer << spirv::SubgroupBallotKHROp::getOperationName() << ' ';
-  printer.printOperand(ballotOp.predicate());
-  printer << " : " << ballotOp.getType();
+  printer << spirv::SubgroupBallotKHROp::getOperationName() << ' '
+          << ballotOp.predicate() << " : " << ballotOp.getType();
 }
 
 //===----------------------------------------------------------------------===//
@@ -2973,20 +2948,15 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
 }
 
 static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
-  auto *op = varOp.getOperation();
   SmallVector<StringRef, 4> elidedAttrs{
       spirv::attributeName<spirv::StorageClass>()};
   printer << spirv::VariableOp::getOperationName();
 
   // Print optional initializer
-  if (op->getNumOperands() > 0) {
-    printer << " init(";
-    printer.printOperands(varOp.initializer());
-    printer << ")";
-  }
-
-  printVariableDecorations(op, printer, elidedAttrs);
+  if (varOp.getNumOperands() != 0)
+    printer << " init(" << varOp.initializer() << ")";
 
+  printVariableDecorations(varOp, printer, elidedAttrs);
   printer << " : " << varOp.getType();
 }
 
index 7726c04..531be29 100644 (file)
@@ -166,15 +166,10 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
 void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
                                  Operation::operand_iterator end,
                                  unsigned numDims, OpAsmPrinter &p) {
-  p << '(';
-  p.printOperands(begin, begin + numDims);
-  p << ')';
-
-  if (begin + numDims != end) {
-    p << '[';
-    p.printOperands(begin + numDims, end);
-    p << ']';
-  }
+  Operation::operand_range operands(begin, end);
+  p << '(' << operands.take_front(numDims) << ')';
+  if (operands.size() != numDims)
+    p << '[' << operands.drop_front(numDims) << ']';
 }
 
 // Parses dimension and symbol list, and sets 'numDims' to the number of
@@ -485,12 +480,9 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
 }
 
 static void print(OpAsmPrinter &p, CallOp op) {
-  p << "call " << op.getAttr("callee") << '(';
-  p.printOperands(op.getOperands());
-  p << ')';
+  p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')';
   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
-  p << " : ";
-  p.printType(op.getCalleeType());
+  p << " : " << op.getCalleeType();
 }
 
 static LogicalResult verify(CallOp op) {
@@ -572,11 +564,7 @@ static ParseResult parseCallIndirectOp(OpAsmParser &parser,
 }
 
 static void print(OpAsmPrinter &p, CallIndirectOp op) {
-  p << "call_indirect ";
-  p.printOperand(op.getCallee());
-  p << '(';
-  p.printOperands(op.getArgOperands());
-  p << ')';
+  p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')';
   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
   p << " : " << op.getCallee()->getType();
 }
@@ -690,12 +678,7 @@ static void print(OpAsmPrinter &p, CmpIOp op) {
   auto predicateValue =
       op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
   p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
-    << '"';
-
-  p << ", ";
-  p.printOperand(op.lhs());
-  p << ", ";
-  p.printOperand(op.rhs());
+    << '"' << ", " << op.lhs() << ", " << op.rhs();
   p.printOptionalAttrDict(op.getAttrs(),
                           /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
   p << " : " << op.lhs()->getType();
@@ -851,15 +834,8 @@ static void print(OpAsmPrinter &p, CmpFOp op) {
   assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
          predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
          "unknown predicate index");
-  Builder b(op.getContext());
-  auto predicateStringAttr =
-      b.getStringAttr(getCmpFPredicateNames()[predicateValue]);
-  p.printAttribute(predicateStringAttr);
-
-  p << ", ";
-  p.printOperand(op.lhs());
-  p << ", ";
-  p.printOperand(op.rhs());
+  p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs()
+    << ", " << op.rhs();
   p.printOptionalAttrDict(op.getAttrs(),
                           /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
   p << " : " << op.lhs()->getType();
@@ -1002,9 +978,7 @@ static ParseResult parseCondBranchOp(OpAsmParser &parser,
 }
 
 static void print(OpAsmPrinter &p, CondBranchOp op) {
-  p << "cond_br ";
-  p.printOperand(op.getCondition());
-  p << ", ";
+  p << "cond_br " << op.getCondition() << ", ";
   p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
   p << ", ";
   p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
@@ -1025,7 +999,7 @@ static void print(OpAsmPrinter &p, ConstantOp &op) {
 
   if (op.getAttrs().size() > 1)
     p << ' ';
-  p.printAttribute(op.getValue());
+  p << op.getValue();
 
   // If the value is a symbol reference, print a trailing type.
   if (op.getValue().isa<SymbolRefAttr>())
@@ -1407,18 +1381,12 @@ void DmaStartOp::build(Builder *builder, OperationState &result,
 }
 
 void DmaStartOp::print(OpAsmPrinter &p) {
-  p << "dma_start " << *getSrcMemRef() << '[';
-  p.printOperands(getSrcIndices());
-  p << "], " << *getDstMemRef() << '[';
-  p.printOperands(getDstIndices());
-  p << "], " << *getNumElements();
-  p << ", " << *getTagMemRef() << '[';
-  p.printOperands(getTagIndices());
-  p << ']';
-  if (isStrided()) {
-    p << ", " << *getStride();
-    p << ", " << *getNumElementsPerStride();
-  }
+  p << "dma_start " << *getSrcMemRef() << '[' << getSrcIndices() << "], "
+    << *getDstMemRef() << '[' << getDstIndices() << "], " << *getNumElements()
+    << ", " << *getTagMemRef() << '[' << getTagIndices() << ']';
+  if (isStrided())
+    p << ", " << *getStride() << ", " << *getNumElementsPerStride();
+
   p.printOptionalAttrDict(getAttrs());
   p << " : " << getSrcMemRef()->getType();
   p << ", " << getDstMemRef()->getType();
@@ -1550,12 +1518,8 @@ void DmaWaitOp::build(Builder *builder, OperationState &result,
 }
 
 void DmaWaitOp::print(OpAsmPrinter &p) {
-  p << "dma_wait ";
-  p.printOperand(getTagMemRef());
-  p << '[';
-  p.printOperands(getTagIndices());
-  p << "], ";
-  p.printOperand(getNumElements());
+  p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], "
+    << getNumElements();
   p.printOptionalAttrDict(getAttrs());
   p << " : " << getTagMemRef()->getType();
 }
@@ -1604,8 +1568,7 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 //===----------------------------------------------------------------------===//
 
 static void print(OpAsmPrinter &p, ExtractElementOp op) {
-  p << "extract_element " << *op.getAggregate() << '[';
-  p.printOperands(op.getIndices());
+  p << "extract_element " << *op.getAggregate() << '[' << op.getIndices();
   p << ']';
   p.printOptionalAttrDict(op.getAttrs());
   p << " : " << op.getAggregate()->getType();
@@ -1686,9 +1649,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) {
 //===----------------------------------------------------------------------===//
 
 static void print(OpAsmPrinter &p, LoadOp op) {
-  p << "load " << *op.getMemRef() << '[';
-  p.printOperands(op.getIndices());
-  p << ']';
+  p << "load " << *op.getMemRef() << '[' << op.getIndices() << ']';
   p.printOptionalAttrDict(op.getAttrs());
   p << " : " << op.getMemRefType();
 }
@@ -1922,12 +1883,8 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
 
 static void print(OpAsmPrinter &p, ReturnOp op) {
   p << "return";
-  if (op.getNumOperands() != 0) {
-    p << ' ';
-    p.printOperands(op.getOperands());
-    p << " : ";
-    interleaveComma(op.getOperandTypes(), p);
-  }
+  if (op.getNumOperands() != 0)
+    p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
 }
 
 static LogicalResult verify(ReturnOp op) {
@@ -1984,9 +1941,7 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
 }
 
 static void print(OpAsmPrinter &p, SelectOp op) {
-  p << "select ";
-  p.printOperands(op.getOperands());
-  p << " : " << op.getTrueValue()->getType();
+  p << "select " << op.getOperands() << " : " << op.getTrueValue()->getType();
   p.printOptionalAttrDict(op.getAttrs());
 }
 
@@ -2093,9 +2048,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
 
 static void print(OpAsmPrinter &p, StoreOp op) {
   p << "store " << *op.getValueToStore();
-  p << ", " << *op.getMemRef() << '[';
-  p.printOperands(op.getIndices());
-  p << ']';
+  p << ", " << *op.getMemRef() << '[' << op.getIndices() << ']';
   p.printOptionalAttrDict(op.getAttrs());
   p << " : " << op.getMemRefType();
 }
@@ -2339,9 +2292,7 @@ static void print(OpAsmPrinter &p, ViewOp op) {
   auto *dynamicOffset = op.getDynamicOffset();
   if (dynamicOffset != nullptr)
     p.printOperand(dynamicOffset);
-  p << "][";
-  p.printOperands(op.getDynamicSizes());
-  p << ']';
+  p << "][" << op.getDynamicSizes() << ']';
   p.printOptionalAttrDict(op.getAttrs());
   p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
 }
@@ -2609,13 +2560,8 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
 }
 
 static void print(OpAsmPrinter &p, SubViewOp op) {
-  p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
-  p.printOperands(op.offsets());
-  p << "][";
-  p.printOperands(op.sizes());
-  p << "][";
-  p.printOperands(op.strides());
-  p << ']';
+  p << op.getOperationName() << ' ' << *op.getOperand(0) << '[' << op.offsets()
+    << "][" << op.sizes() << "][" << op.strides() << ']';
 
   SmallVector<StringRef, 1> elidedAttrs = {
       SubViewOp::getOperandSegmentSizeAttr()};
index 28a0322..a2345fe 100644 (file)
@@ -110,17 +110,16 @@ static void print(OpAsmPrinter &p, ContractionOp op) {
   llvm::StringSet<> traitAttrsSet;
   traitAttrsSet.insert(attrNames.begin(), attrNames.end());
   SmallVector<NamedAttribute, 8> attrs;
-  for (auto attr : op.getAttrs()) {
+  for (auto attr : op.getAttrs())
     if (traitAttrsSet.count(attr.first.strref()) > 0)
       attrs.push_back(attr);
-  }
+
   auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
   p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", ";
   p << *op.rhs() << ", " << *op.acc();
-  if (llvm::size(op.masks()) == 2) {
-    p << ", " << **op.masks().begin();
-    p << ", " << **(op.masks().begin() + 1);
-  }
+  if (op.masks().size() == 2)
+    p << ", " << op.masks();
+
   p.printOptionalAttrDict(op.getAttrs(), attrNames);
   p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
     << op.getResultType();
@@ -417,9 +416,8 @@ static LogicalResult verify(vector::ExtractOp op) {
 //===----------------------------------------------------------------------===//
 
 static void print(OpAsmPrinter &p, BroadcastOp op) {
-  p << op.getOperationName() << " " << *op.source();
-  p << " : " << op.getSourceType();
-  p << " to " << op.getVectorType();
+  p << op.getOperationName() << " " << *op.source() << " : "
+    << op.getSourceType() << " to " << op.getVectorType();
 }
 
 static LogicalResult verify(BroadcastOp op) {
@@ -560,8 +558,7 @@ static void print(OpAsmPrinter &p, InsertOp op) {
   p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
     << op.position();
   p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()});
-  p << " : " << op.getSourceType();
-  p << " into " << op.getDestVectorType();
+  p << " : " << op.getSourceType() << " into " << op.getDestVectorType();
 }
 
 static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) {
@@ -789,8 +786,8 @@ static LogicalResult verify(InsertStridedSliceOp op) {
 
 static void print(OpAsmPrinter &p, OuterProductOp op) {
   p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
-  if (llvm::size(op.acc()) > 0)
-    p << ", " << **op.acc().begin();
+  if (!op.acc().empty())
+    p << ", " << op.acc();
   p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
 }
 
@@ -1034,16 +1031,10 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
 }
 
 static void print(OpAsmPrinter &p, TransferReadOp op) {
-  p << op.getOperationName() << " ";
-  p.printOperand(op.memref());
-  p << "[";
-  p.printOperands(op.indices());
-  p << "], ";
-  p.printOperand(op.padding());
-  p << " ";
+  p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
+    << "], " << op.padding() << " ";
   p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.getMemRefType();
-  p << ", " << op.getVectorType();
+  p << " : " << op.getMemRefType() << ", " << op.getVectorType();
 }
 
 ParseResult parseTransferReadOp(OpAsmParser &parser, OperationState &result) {
@@ -1106,15 +1097,10 @@ static LogicalResult verify(TransferReadOp op) {
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
 static void print(OpAsmPrinter &p, TransferWriteOp op) {
-  p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref();
-  p << "[";
-  p.printOperands(op.indices());
-  p << "]";
+  p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref()
+    << "[" << op.indices() << "]";
   p.printOptionalAttrDict(op.getAttrs());
-  p << " : ";
-  p.printType(op.getVectorType());
-  p << ", ";
-  p.printType(op.getMemRefType());
+  p << " : " << op.getVectorType() << ", " << op.getMemRefType();
 }
 
 ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) {
@@ -1180,13 +1166,13 @@ void TypeCastOp::build(Builder *builder, OperationState &result,
       inferVectorTypeCastResultType(source->getType().cast<MemRefType>()));
 }
 
-static void print(OpAsmPrinter &p, TypeCastOp &op) {
+static void print(OpAsmPrinter &p, TypeCastOp op) {
   auto type = op.getOperand()->getType().cast<MemRefType>();
   p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to "
     << inferVectorTypeCastResultType(type);
 }
 
-static LogicalResult verify(TypeCastOp &op) {
+static LogicalResult verify(TypeCastOp op) {
   auto resultType = inferVectorTypeCastResultType(op.getMemRefType());
   if (op.getResultMemRefType() != resultType)
     return op.emitOpError("expects result type to be: ") << resultType;
@@ -1208,9 +1194,9 @@ ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) {
       parser.addTypeToList(resultType, result.types));
 }
 
-static void print(OpAsmPrinter &p, ConstantMaskOp &op) {
-  p << op.getOperationName() << ' ' << op.mask_dim_sizes();
-  p << " : " << op.getResult()->getType();
+static void print(OpAsmPrinter &p, ConstantMaskOp op) {
+  p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : "
+    << op.getResult()->getType();
 }
 
 static LogicalResult verify(ConstantMaskOp &op) {
@@ -1256,13 +1242,11 @@ ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) {
       parser.addTypeToList(resultType, result.types));
 }
 
-static void print(OpAsmPrinter &p, CreateMaskOp &op) {
-  p << op.getOperationName() << ' ';
-  p.printOperands(op.operands());
-  p << " : " << op.getResult()->getType();
+static void print(OpAsmPrinter &p, CreateMaskOp op) {
+  p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType();
 }
 
-static LogicalResult verify(CreateMaskOp &op) {
+static LogicalResult verify(CreateMaskOp op) {
   // Verify that an operand was specified for each result vector each dimension.
   if (op.getNumOperands() !=
       op.getResult()->getType().cast<VectorType>().getRank())