[mlir][NFC] Update Operation::getResultTypes to use ArrayRef<Type> instead of iterato...
authorRiver Riddle <riddleriver@gmail.com>
Tue, 28 Jan 2020 03:57:14 +0000 (19:57 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 28 Jan 2020 03:57:48 +0000 (19:57 -0800)
Summary: The new internal representation of operation results now allows for accessing the result types to be more efficient. Changing the API to ArrayRef is more efficient and removes the need to explicitly materialize vectors in several places.

Differential Revision: https://reviews.llvm.org/D73429

13 files changed:
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/Analysis/InferTypeOpInterface.cpp
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/lib/Transforms/CSE.cpp
mlir/test/lib/TestDialect/TestPatterns.cpp

index 497706d1f8e93b04c6629b9b53ea6d8e750eb075..79899c8111fa26983a94cb6b810f170b6910a9c5 100644 (file)
@@ -191,6 +191,10 @@ operator<<(OpAsmPrinter &p,
   interleaveComma(types, p);
   return p;
 }
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
+  interleaveComma(types, p);
+  return p;
+}
 
 //===----------------------------------------------------------------------===//
 // OpAsmParser
index da15a7014e253783e495bc8fe66900e39b7b451d..28e726edd874480958202f2d2b5b6fe07ad84bbb 100644 (file)
@@ -260,10 +260,10 @@ public:
 
   /// Support result type iteration.
   using result_type_iterator = result_range::type_iterator;
-  using result_type_range = iterator_range<result_type_iterator>;
-  result_type_iterator result_type_begin() { return result_begin(); }
-  result_type_iterator result_type_end() { return result_end(); }
-  result_type_range getResultTypes() { return getResults().getTypes(); }
+  using result_type_range = ArrayRef<Type>;
+  result_type_iterator result_type_begin() { return getResultTypes().begin(); }
+  result_type_iterator result_type_end() { return getResultTypes().end(); }
+  result_type_range getResultTypes();
 
   //===--------------------------------------------------------------------===//
   // Attributes
index b62e72aa72dd81c8b4ff79b1ec6add615bf185cb..bf19a5af14ff6f9fe240bb2fa5d94a67b5fbc144 100644 (file)
@@ -595,8 +595,8 @@ public:
   ResultRange(Operation *op);
 
   /// Returns the types of the values within this range.
-  using type_iterator = ValueTypeIterator<iterator>;
-  iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
+  using type_iterator = ArrayRef<Type>::iterator;
+  ArrayRef<Type> getTypes() const;
 
 private:
   /// See `indexed_accessor_range` for details.
index 74d76984be3cd0cb1c068c1217318ab654c6c57b..4b4be52d66265afc628de457610ce689b1a3b800 100644 (file)
@@ -53,8 +53,8 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
                                         op->getOperands(), op->getAttrs(),
                                         op->getRegions(), inferedReturnTypes)))
     return failure();
-  SmallVector<Type, 4> resultTypes(op->getResultTypes());
-  if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
+  if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes,
+                                         op->getResultTypes()))
     return op->emitOpError(
         "inferred type incompatible with return type of operation");
   return success();
index ed28dd2853116b46811eff625a032287fdf34ae4..1a2d3b07f4162c68ad7d240c9babf5206b1a494b 100644 (file)
@@ -652,8 +652,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
 
     Type packedType;
     if (numResults != 0) {
-      packedType = this->lowering.packFunctionResults(
-          llvm::to_vector<4>(op->getResultTypes()));
+      packedType = this->lowering.packFunctionResults(op->getResultTypes());
       if (!packedType)
         return this->matchFailure();
     }
index d96a95e0b41e73ce528002902f5b05eac91e3070..a49df6304e28798ea61dbda7c20bf26e1f1158e7 100644 (file)
@@ -292,11 +292,11 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
   p.printOptionalAttrDict(op.getAttrs(), {"callee"});
 
   // Reconstruct the function MLIR function type from operand and result types.
-  SmallVector<Type, 1> resultTypes(op.getResultTypes());
   SmallVector<Type, 8> argTypes(
       llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
 
-  p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
+  p << " : "
+    << FunctionType::get(argTypes, op.getResultTypes(), op.getContext());
 }
 
 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
index 7854328ceeb25cb6a43be2868d36cacad8995027..96c41fca292ddf176e13bdbf65df6dd4f400fdfc 100644 (file)
@@ -1685,9 +1685,8 @@ static ParseResult parseFunctionCallOp(OpAsmParser &parser,
 
 static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
   SmallVector<Type, 4> argTypes(functionCallOp.getOperandTypes());
-  SmallVector<Type, 1> resultTypes(functionCallOp.getResultTypes());
-  Type functionType =
-      FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());
+  Type functionType = FunctionType::get(
+      argTypes, functionCallOp.getResultTypes(), functionCallOp.getContext());
 
   printer << spirv::FunctionCallOp::getOperationName() << ' '
           << functionCallOp.getAttr(kCallee) << '('
index 12c0f1674c36766efc28e16a7ac98f6c0208cc65..fbbe6eb4fccc1aad6d4777a82dabc4f39d19f4c5 100644 (file)
@@ -1764,12 +1764,9 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
   auto funcName = op.callee();
   uint32_t resTypeID = 0;
 
-  SmallVector<Type, 1> resultTypes(op.getResultTypes());
-  if (failed(processType(op.getLoc(),
-                         (resultTypes.empty() ? getVoidType() : resultTypes[0]),
-                         resTypeID))) {
+  Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
+  if (failed(processType(op.getLoc(), resultTy, resTypeID)))
     return failure();
-  }
 
   auto funcID = getOrCreateFunctionID(funcName);
   auto funcCallID = getNextID();
@@ -1781,9 +1778,8 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
     operands.push_back(valueID);
   }
 
-  if (!resultTypes.empty()) {
+  if (!resultTy.isa<NoneType>())
     valueIDMap[op.getResult(0)] = funcCallID;
-  }
 
   return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall,
                                operands);
index 40d2baa7ca4e400ed9bfab71ff91b0013e3f997e..824a2ea87d96212d7047f76823d1aa2f161eb1c8 100644 (file)
@@ -500,9 +500,8 @@ static LogicalResult verify(CallOp op) {
 }
 
 FunctionType CallOp::getCalleeType() {
-  SmallVector<Type, 4> resultTypes(getResultTypes());
   SmallVector<Type, 8> argTypes(getOperandTypes());
-  return FunctionType::get(argTypes, resultTypes, getContext());
+  return FunctionType::get(argTypes, getResultTypes(), getContext());
 }
 
 //===----------------------------------------------------------------------===//
@@ -522,8 +521,8 @@ struct SimplifyIndirectCallWithKnownCallee
       return matchFailure();
 
     // Replace with a direct call.
-    SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
-    rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, callResults,
+    rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
+                                        indirectCall.getResultTypes(),
                                         indirectCall.getArgOperands());
     return matchSuccess();
   }
index d00e1e9468742b032c3bd2a98832a2cf0407c93d..6b19c2603f5b0ef9c92b4b8cced4f74339517ee8 100644 (file)
@@ -551,6 +551,14 @@ unsigned Operation::getNumResults() {
   return hasSingleResult ? 1 : resultType.cast<TupleType>().size();
 }
 
+auto Operation::getResultTypes() -> result_type_range {
+  if (!resultType)
+    return llvm::None;
+  if (hasSingleResult)
+    return resultType;
+  return resultType.cast<TupleType>().getTypes();
+}
+
 void Operation::setSuccessor(Block *block, unsigned index) {
   assert(index < getNumSuccessors());
   getBlockOperands()[index].set(block);
@@ -666,10 +674,9 @@ Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
     }
   }
 
-  SmallVector<Type, 8> resultTypes(getResultTypes());
   unsigned numRegions = getNumRegions();
   auto *newOp =
-      Operation::create(getLoc(), getName(), resultTypes, operands, attrs,
+      Operation::create(getLoc(), getName(), getResultTypes(), operands, attrs,
                         successors, numRegions, hasResizableOperandsList());
 
   // Remember the mapping of any results.
@@ -919,7 +926,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
 
   auto type = op->getResult(0).getType();
   auto elementType = getElementTypeOrSelf(type);
-  for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
+  for (auto resultType : op->getResultTypes().drop_front(1)) {
     if (getElementTypeOrSelf(resultType) != elementType ||
         failed(verifyCompatibleShape(resultType, type)))
       return op->emitOpError()
index 5962f0162b2dd4861f404c515cd34fde74752db9..609a6dca6b27765a9b9d923d550f0bee29f312d4 100644 (file)
@@ -152,6 +152,10 @@ OperandRange::OperandRange(Operation *op)
 ResultRange::ResultRange(Operation *op)
     : ResultRange(op, /*startIndex=*/0, op->getNumResults()) {}
 
+ArrayRef<Type> ResultRange::getTypes() const {
+  return getBase()->getResultTypes();
+}
+
 /// See `indexed_accessor_range` for details.
 OpResult ResultRange::dereference(Operation *op, ptrdiff_t index) {
   return op->getResult(index);
index 30da7c086fce603f31db85dd07b8b206560bd192..0e1bab56a2123634dcc13d6c757a29a84814d520 100644 (file)
@@ -37,10 +37,9 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
     //   - Attributes
     //   - Result Types
     //   - Operands
-    return hash_combine(
-        op->getName(), op->getAttrList().getDictionary(),
-        hash_combine_range(op->result_type_begin(), op->result_type_end()),
-        hash_combine_range(op->operand_begin(), op->operand_end()));
+    return llvm::hash_combine(
+        op->getName(), op->getAttrList().getDictionary(), op->getResultTypes(),
+        llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
   }
   static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
     auto *lhs = const_cast<Operation *>(lhsC);
index d34181c45b187de24075ba9273cd215aa0e23a31..0f976a20acf936472b2ccefa6ec0b92d16f16f8a 100644 (file)
@@ -241,7 +241,7 @@ struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     // If the type is I32, change the type to F32.
-    if (!(*op->result_type_begin()).isInteger(32))
+    if (!Type(*op->result_type_begin()).isInteger(32))
       return matchFailure();
     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
     return matchSuccess();
@@ -254,7 +254,7 @@ struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     // If the type is F32, change the type to F64.
-    if (!(*op->result_type_begin()).isF32())
+    if (!Type(*op->result_type_begin()).isF32())
       return matchFailure();
     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
     return matchSuccess();
@@ -477,8 +477,7 @@ struct OneVResOneVOperandOp1Converter
     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
 
-    SmallVector<Type, 1> resultTypes(op.getResultTypes());
-    rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, resultTypes,
+    rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
                                                        remappedOperands);
     return matchSuccess();
   }