From: River Riddle Date: Fri, 24 May 2019 20:28:55 +0000 (-0700) Subject: Add operand type iterators to Operation and cleanup usages of operand->getType... X-Git-Tag: llvmorg-11-init~1466^2~1610 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=06734badbcd68ebc8648c7d660d013b600954bbc;p=platform%2Fupstream%2Fllvm.git Add operand type iterators to Operation and cleanup usages of operand->getType. This also simplifies some lingering usages of result->getType. -- PiperOrigin-RevId: 249889174 --- diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index adfa1f7..c9d52de 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -97,7 +97,7 @@ class LoadOpConversion : public LoadStoreOpConversion { void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { edsc::ScopedContext edscContext(rewriter, op->getLoc()); - auto elementType = linalg::convertLinalgType(*op->getResultTypes().begin()); + auto elementType = linalg::convertLinalgType(*op->result_type_begin()); Value *viewDescriptor = operands[0]; ArrayRef indices = operands.drop_front(); Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 907e3f1..cc0d7f9 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -226,8 +226,8 @@ public: // Find the next operation ready for inference, that is an operation // with all operands already resolved (non-generic). auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) { - return llvm::all_of(op->getOperands(), [](mlir::Value *v) { - return !v->getType().cast().isGeneric(); + return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) { + return !ty.cast().isGeneric(); }); }); if (nextop == opWorklist.end()) @@ -308,9 +308,8 @@ public: if (!mangledCallee) { // Can't find the target, this is where we queue the request for the // callee and stop the inference for the current function now. - std::vector funcArgs; - for (auto operand : op->getOperands()) - funcArgs.push_back(operand->getType()); + std::vector funcArgs(op->operand_type_begin(), + op->operand_type_end()); funcWorklist.push_back( {callee, std::move(mangledName), std::move(funcArgs)}); return mlir::success(); diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index 5267586..3bd5d26 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -226,8 +226,8 @@ public: // Find the next operation ready for inference, that is an operation // with all operands already resolved (non-generic). auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) { - return llvm::all_of(op->getOperands(), [](mlir::Value *v) { - return !v->getType().cast().isGeneric(); + return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) { + return !ty.cast().isGeneric(); }); }); if (nextop == opWorklist.end()) @@ -312,18 +312,15 @@ public: if (!mangledCallee) { // Can't find the target, this is where we queue the request for the // callee and stop the inference for the current function now. - std::vector funcArgs; - for (auto operand : op->getOperands()) - funcArgs.push_back(operand->getType()); + std::vector funcArgs(op->operand_type_begin(), + op->operand_type_end()); funcWorklist.push_back( {callee, std::move(mangledName), std::move(funcArgs)}); return mlir::success(); } // Found a specialized callee! Let's turn this into a normal call // operation. - SmallVector operands; - for (mlir::Value *v : op->getOperands()) - operands.push_back(v); + SmallVector operands(op->getOperands()); mlir::FuncBuilder builder(f); builder.setInsertionPoint(op); auto newCall = diff --git a/mlir/include/mlir/GPU/GPUDialect.h b/mlir/include/mlir/GPU/GPUDialect.h index 8ace6ff..003336b 100644 --- a/mlir/include/mlir/GPU/GPUDialect.h +++ b/mlir/include/mlir/GPU/GPUDialect.h @@ -78,9 +78,9 @@ public: /// Get the SSA values corresponding to kernel block size. KernelDim3 getBlockSize(); /// Get the operand values passed as kernel arguments. - Operation::operand_range getKernelOperandValues(); - /// Append the operand types passed as kernel arguments to `out`. - void getKernelOperandTypes(SmallVectorImpl &out); + operand_range getKernelOperandValues(); + /// Get the operand types passed as kernel arguments. + operand_type_range getKernelOperandTypes(); /// Get the SSA values passed as operands to specify the grid size. KernelDim3 getGridSizeOperandValues(); diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 472441e..58d19d7 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -328,6 +328,8 @@ template class TraitType> struct MultiOperandTraitBase : public TraitBase { using operand_iterator = Operation::operand_iterator; using operand_range = Operation::operand_range; + using operand_type_iterator = Operation::operand_type_iterator; + using operand_type_range = Operation::operand_type_range; /// Return the number of operands. unsigned getNumOperands() { return this->getOperation()->getNumOperands(); } @@ -346,6 +348,17 @@ struct MultiOperandTraitBase : public TraitBase { } operand_iterator operand_end() { return this->getOperation()->operand_end(); } operand_range getOperands() { return this->getOperation()->getOperands(); } + + /// Operand type access. + operand_type_iterator operand_type_begin() { + return this->getOperation()->operand_type_begin(); + } + operand_type_iterator operand_type_end() { + return this->getOperation()->operand_type_end(); + } + operand_type_range getOperandTypes() { + return this->getOperation()->getOperandTypes(); + } }; } // end namespace detail @@ -447,6 +460,8 @@ template class TraitType> struct MultiResultTraitBase : public TraitBase { using result_iterator = Operation::result_iterator; using result_range = Operation::result_range; + using result_type_iterator = Operation::result_type_iterator; + using result_type_range = Operation::result_type_range; /// Return the number of results. unsigned getNumResults() { return this->getOperation()->getNumResults(); } @@ -468,6 +483,17 @@ struct MultiResultTraitBase : public TraitBase { } result_iterator result_end() { return this->getOperation()->result_end(); } result_range getResults() { return this->getOperation()->getResults(); } + + /// Result type access. + result_type_iterator result_type_begin() { + return this->getOperation()->result_type_begin(); + } + result_type_iterator result_type_end() { + return this->getOperation()->result_type_end(); + } + result_type_range getResultTypes() { + return this->getOperation()->getResultTypes(); + } }; } // end namespace detail @@ -477,7 +503,6 @@ template class OneResult : public TraitBase { public: Value *getResult() { return this->getOperation()->getResult(0); } - Type getType() { return getResult()->getType(); } /// Replace all uses of 'this' value with the new value, updating anything in diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 107bfb8..0a7a2aa 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -34,6 +34,7 @@ class BlockAndValueMapping; class Location; class MLIRContext; class OperandIterator; +class OperandTypeIterator; struct OperationState; class ResultIterator; class ResultTypeIterator; @@ -198,6 +199,13 @@ public: OpOperand &getOpOperand(unsigned idx) { return getOpOperands()[idx]; } + // Support operand type iteration. + using operand_type_iterator = OperandTypeIterator; + using operand_type_range = llvm::iterator_range; + operand_type_iterator operand_type_begin(); + operand_type_iterator operand_type_end(); + operand_type_range getOperandTypes(); + //===--------------------------------------------------------------------===// // Results //===--------------------------------------------------------------------===// @@ -226,9 +234,10 @@ public: // Support result type iteration. using result_type_iterator = ResultTypeIterator; + using result_type_range = llvm::iterator_range; result_type_iterator result_type_begin(); result_type_iterator result_type_end(); - llvm::iterator_range getResultTypes(); + result_type_range getResultTypes(); //===--------------------------------------------------------------------===// // Attributes @@ -500,6 +509,19 @@ public: Value *operator*() const { return this->object->getOperand(this->index); } }; +/// This class implements the operand type iterators for the Operation +/// class in terms of operand_iterator->getType(). +class OperandTypeIterator final + : public llvm::mapped_iterator { + static Type unwrap(Value *value) { return value->getType(); } + +public: + /// Initializes the operand type iterator to the specified operand iterator. + OperandTypeIterator(OperandIterator it) + : llvm::mapped_iterator(it, &unwrap) { + } +}; + // Implement the inline operand iterator methods. inline auto Operation::operand_begin() -> operand_iterator { return operand_iterator(this, 0); @@ -513,6 +535,18 @@ inline auto Operation::getOperands() -> operand_range { return {operand_begin(), operand_end()}; } +inline auto Operation::operand_type_begin() -> operand_type_iterator { + return operand_type_iterator(operand_begin()); +} + +inline auto Operation::operand_type_end() -> operand_type_iterator { + return operand_type_iterator(operand_end()); +} + +inline auto Operation::getOperandTypes() -> operand_type_range { + return {operand_type_begin(), operand_type_end()}; +} + /// This class implements the result iterators for the Operation class /// in terms of getResult(idx). class ResultIterator final @@ -559,8 +593,7 @@ inline auto Operation::result_type_end() -> result_type_iterator { return result_type_iterator(result_end()); } -inline auto Operation::getResultTypes() - -> llvm::iterator_range { +inline auto Operation::getResultTypes() -> result_type_range { return {result_type_begin(), result_type_end()}; } diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp index 755a2c2..5c0539a 100644 --- a/mlir/lib/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/GPU/IR/GPUDialect.cpp @@ -99,16 +99,12 @@ KernelDim3 LaunchOp::getBlockSize() { return KernelDim3{args[9], args[10], args[11]}; } -Operation::operand_range LaunchOp::getKernelOperandValues() { - return {getOperation()->operand_begin() + kNumConfigOperands, - getOperation()->operand_end()}; +LaunchOp::operand_range LaunchOp::getKernelOperandValues() { + return llvm::drop_begin(getOperands(), kNumConfigOperands); } -void LaunchOp::getKernelOperandTypes(SmallVectorImpl &out) { - out.reserve(getNumOperands() - kNumConfigOperands + out.size()); - for (unsigned i = kNumConfigOperands; i < getNumOperands(); ++i) { - out.push_back(getOperand(i)->getType()); - } +LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() { + return llvm::drop_begin(getOperandTypes(), kNumConfigOperands); } KernelDim3 LaunchOp::getGridSizeOperandValues() { diff --git a/mlir/lib/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/GPU/Transforms/KernelOutlining.cpp index 006ba4f..163a7cf 100644 --- a/mlir/lib/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/GPU/Transforms/KernelOutlining.cpp @@ -60,8 +60,7 @@ void injectGpuIndexOperations(Location loc, Function &kernelFunc) { // Outline the `gpu.launch` operation body into a kernel function. Function *outlineKernelFunc(Module &module, gpu::LaunchOp &launchOp) { Location loc = launchOp.getLoc(); - SmallVector kernelOperandTypes; - launchOp.getKernelOperandTypes(kernelOperandTypes); + SmallVector kernelOperandTypes(launchOp.getKernelOperandTypes()); FunctionType type = FunctionType::get(kernelOperandTypes, {}, module.getContext()); std::string kernelFuncName = diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index bc46b45..71744cf 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -220,10 +220,10 @@ void ModuleState::visitAttribute(Attribute attr) { void ModuleState::visitOperation(Operation *op) { // Visit all the types used in the operation. - for (auto *operand : op->getOperands()) - visitType(operand->getType()); - for (auto *result : op->getResults()) - visitType(result->getType()); + for (auto type : op->getOperandTypes()) + visitType(type); + for (auto type : op->getResultTypes()) + visitType(type); // Visit each of the attributes. for (auto elt : op->getAttrs()) diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 22463f1..582fb39 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -593,11 +593,7 @@ Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper, } } - SmallVector resultTypes; - resultTypes.reserve(getNumResults()); - for (auto *result : getResults()) - resultTypes.push_back(result->getType()); - + SmallVector resultTypes(getResultTypes()); unsigned numRegions = getNumRegions(); auto *newOp = Operation::create(getLoc(), getName(), operands, resultTypes, attrs, successors, numRegions, @@ -718,8 +714,8 @@ static Type getTensorOrVectorElementType(Type type) { } LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) { - for (auto *operand : op->getOperands()) { - auto type = getTensorOrVectorElementType(operand->getType()); + for (auto opType : op->getOperandTypes()) { + auto type = getTensorOrVectorElementType(opType); if (!type.isIntOrIndex()) return op->emitOpError() << "requires an integer or index type"; } @@ -727,8 +723,8 @@ LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) { } LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { - for (auto *operand : op->getOperands()) { - auto type = getTensorOrVectorElementType(operand->getType()); + for (auto opType : op->getOperandTypes()) { + auto type = getTensorOrVectorElementType(opType); if (!type.isa()) return op->emitOpError("requires a float type"); } @@ -742,8 +738,8 @@ LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { return success(); auto type = op->getOperand(0)->getType(); - for (unsigned i = 1; i < nOperands; ++i) - if (op->getOperand(i)->getType() != type) + for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) + if (opType != type) return op->emitOpError() << "requires all operands to have the same type"; return success(); } @@ -798,13 +794,13 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { return failure(); auto type = op->getOperand(0)->getType(); - for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) { - if (failed(verifyShapeMatch(op->getResult(i)->getType(), type))) + for (auto resultType : op->getResultTypes()) { + if (failed(verifyShapeMatch(resultType, type))) return op->emitOpError() << "requires the same shape for all operands and results"; } - for (unsigned i = 1, e = op->getNumOperands(); i < e; ++i) { - if (failed(verifyShapeMatch(op->getOperand(i)->getType(), type))) + for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { + if (failed(verifyShapeMatch(opType, type))) return op->emitOpError() << "requires the same shape for all operands and results"; } @@ -849,13 +845,13 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { return failure(); auto type = op->getResult(0)->getType(); - for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) { - if (op->getResult(i)->getType() != type) + for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) { + if (resultType != type) return op->emitOpError() << "requires the same type for all operands and results"; } - for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) { - if (op->getOperand(i)->getType() != type) + for (auto opType : op->getOperandTypes()) { + if (opType != type) return op->emitOpError() << "requires the same type for all operands and results"; } @@ -905,8 +901,8 @@ LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { } LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { - for (auto *result : op->getResults()) { - auto elementType = getTensorOrVectorElementType(result->getType()); + for (auto resultType : op->getResultTypes()) { + auto elementType = getTensorOrVectorElementType(resultType); bool isBoolType = elementType.isInteger(1); if (!isBoolType) return op->emitOpError() << "requires a bool result type"; @@ -916,19 +912,17 @@ LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { } LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { - for (auto *result : op->getResults()) - if (!getTensorOrVectorElementType(result->getType()).isa()) + for (auto resultType : op->getResultTypes()) + if (!getTensorOrVectorElementType(resultType).isa()) return op->emitOpError() << "requires a floating point type"; return success(); } LogicalResult OpTrait::impl::verifyResultsAreIntegerLike(Operation *op) { - for (auto *result : op->getResults()) { - auto type = getTensorOrVectorElementType(result->getType()); - if (!type.isIntOrIndex()) + for (auto resultType : op->getResultTypes()) + if (!getTensorOrVectorElementType(resultType).isIntOrIndex()) return op->emitOpError() << "requires an integer or index type"; - } return success(); } diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 8b673f37..8490fda 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -177,11 +177,8 @@ static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) { //===----------------------------------------------------------------------===// static void printGEPOp(OpAsmPrinter *p, GEPOp &op) { - SmallVector types; - for (auto *operand : op.getOperands()) - types.push_back(operand->getType()); - auto funcTy = - FunctionType::get(types, op.getResult()->getType(), op.getContext()); + SmallVector types(op.getOperandTypes()); + auto funcTy = FunctionType::get(types, op.getType(), op.getContext()); *p << op.getOperationName() << ' ' << *op.base() << '['; p->printOperands(std::next(op.operand_begin()), op.operand_end()); @@ -326,11 +323,9 @@ static void printCallOp(OpAsmPrinter *p, CallOp &op) { p->printOptionalAttrDict(op.getAttrs(), {"callee"}); // Reconstruct the function MLIR function type from operand and result types. - SmallVector resultTypes(op.getOperation()->getResultTypes()); - SmallVector argTypes; - argTypes.reserve(op.getNumOperands()); - for (auto *operand : llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1)) - argTypes.push_back(operand->getType()); + SmallVector resultTypes(op.getResultTypes()); + SmallVector argTypes( + llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); *p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); } diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index acf0be3..4354f82 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -262,17 +262,6 @@ protected: LLVM::LLVMDialect &dialect; }; -// Given a range of MLIR typed objects, return a list of their types. -template -SmallVector getTypes(llvm::iterator_range range) { - SmallVector types; - types.reserve(llvm::size(range)); - for (auto operand : range) { - types.push_back(operand->getType()); - } - return types; -} - // Basic lowering implementation for one-to-one rewriting from Standard Ops to // LLVM Dialect Ops. template @@ -288,8 +277,8 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { Type packedType; if (numResults != 0) { - packedType = - this->lowering.packFunctionResults(getTypes(op->getResults())); + packedType = this->lowering.packFunctionResults( + llvm::to_vector<4>(op->getResultTypes())); assert(packedType && "type conversion failed, such operation should not " "have been matched"); } @@ -832,7 +821,8 @@ struct ReturnOpLowering : public LLVMLegalizationPattern { // Otherwise, we need to pack the arguments into an LLVM struct type before // returning. - auto packedType = lowering.packFunctionResults(getTypes(op->getOperands())); + auto packedType = + lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes())); Value *packed = rewriter.create(op->getLoc(), packedType); for (unsigned i = 0; i < numArguments; ++i) { diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index cf8a2cc..f94868d 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -316,7 +316,7 @@ class LoadOpConversion : public LoadStoreOpConversion { void rewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { edsc::ScopedContext edscContext(rewriter, op->getLoc()); - auto elementTy = lowering.convertType(*op->getResultTypes().begin()); + auto elementTy = lowering.convertType(*op->result_type_begin()); Value *viewDescriptor = operands[0]; ArrayRef indices = operands.drop_front(); auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index f05b0cf..fd9d57f 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -284,8 +284,8 @@ static LogicalResult verify(AllocOp op) { "operand count does not equal dimension plus symbol operand count"); // Verify that all operands are of type Index. - for (auto *operand : op.getOperands()) - if (!operand->getType().isIndex()) + for (auto operandType : op.getOperandTypes()) + if (!operandType.isIndex()) return op.emitOpError("requires operands to be of type Index"); return success(); } @@ -475,11 +475,8 @@ static LogicalResult verify(CallOp op) { } FunctionType CallOp::getCalleeType() { - SmallVector resultTypes(getOperation()->getResultTypes()); - SmallVector argTypes; - argTypes.reserve(getNumOperands()); - for (auto *operand : getArgOperands()) - argTypes.push_back(operand->getType()); + SmallVector resultTypes(getResultTypes()); + SmallVector argTypes(getOperandTypes()); return FunctionType::get(argTypes, resultTypes, getContext()); }