void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
- auto elementType = linalg::convertLinalgType(*op->getResultTypes().begin());
+ auto elementType = linalg::convertLinalgType(*op->result_type_begin());
Value *viewDescriptor = operands[0];
ArrayRef<Value *> indices = operands.drop_front();
Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
- return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
- return !v->getType().cast<ToyArrayType>().isGeneric();
+ return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
+ return !ty.cast<ToyArrayType>().isGeneric();
});
});
if (nextop == opWorklist.end())
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.
- std::vector<mlir::Type> funcArgs;
- for (auto operand : op->getOperands())
- funcArgs.push_back(operand->getType());
+ std::vector<mlir::Type> funcArgs(op->operand_type_begin(),
+ op->operand_type_end());
funcWorklist.push_back(
{callee, std::move(mangledName), std::move(funcArgs)});
return mlir::success();
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
- return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
- return !v->getType().cast<ToyArrayType>().isGeneric();
+ return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
+ return !ty.cast<ToyArrayType>().isGeneric();
});
});
if (nextop == opWorklist.end())
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.
- std::vector<mlir::Type> funcArgs;
- for (auto operand : op->getOperands())
- funcArgs.push_back(operand->getType());
+ std::vector<mlir::Type> funcArgs(op->operand_type_begin(),
+ op->operand_type_end());
funcWorklist.push_back(
{callee, std::move(mangledName), std::move(funcArgs)});
return mlir::success();
}
// Found a specialized callee! Let's turn this into a normal call
// operation.
- SmallVector<mlir::Value *, 8> operands;
- for (mlir::Value *v : op->getOperands())
- operands.push_back(v);
+ SmallVector<mlir::Value *, 8> operands(op->getOperands());
mlir::FuncBuilder builder(f);
builder.setInsertionPoint(op);
auto newCall =
/// Get the SSA values corresponding to kernel block size.
KernelDim3 getBlockSize();
/// Get the operand values passed as kernel arguments.
- Operation::operand_range getKernelOperandValues();
- /// Append the operand types passed as kernel arguments to `out`.
- void getKernelOperandTypes(SmallVectorImpl<Type> &out);
+ operand_range getKernelOperandValues();
+ /// Get the operand types passed as kernel arguments.
+ operand_type_range getKernelOperandTypes();
/// Get the SSA values passed as operands to specify the grid size.
KernelDim3 getGridSizeOperandValues();
struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
using operand_iterator = Operation::operand_iterator;
using operand_range = Operation::operand_range;
+ using operand_type_iterator = Operation::operand_type_iterator;
+ using operand_type_range = Operation::operand_type_range;
/// Return the number of operands.
unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
}
operand_iterator operand_end() { return this->getOperation()->operand_end(); }
operand_range getOperands() { return this->getOperation()->getOperands(); }
+
+ /// Operand type access.
+ operand_type_iterator operand_type_begin() {
+ return this->getOperation()->operand_type_begin();
+ }
+ operand_type_iterator operand_type_end() {
+ return this->getOperation()->operand_type_end();
+ }
+ operand_type_range getOperandTypes() {
+ return this->getOperation()->getOperandTypes();
+ }
};
} // end namespace detail
struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
using result_iterator = Operation::result_iterator;
using result_range = Operation::result_range;
+ using result_type_iterator = Operation::result_type_iterator;
+ using result_type_range = Operation::result_type_range;
/// Return the number of results.
unsigned getNumResults() { return this->getOperation()->getNumResults(); }
}
result_iterator result_end() { return this->getOperation()->result_end(); }
result_range getResults() { return this->getOperation()->getResults(); }
+
+ /// Result type access.
+ result_type_iterator result_type_begin() {
+ return this->getOperation()->result_type_begin();
+ }
+ result_type_iterator result_type_end() {
+ return this->getOperation()->result_type_end();
+ }
+ result_type_range getResultTypes() {
+ return this->getOperation()->getResultTypes();
+ }
};
} // end namespace detail
class OneResult : public TraitBase<ConcreteType, OneResult> {
public:
Value *getResult() { return this->getOperation()->getResult(0); }
-
Type getType() { return getResult()->getType(); }
/// Replace all uses of 'this' value with the new value, updating anything in
class Location;
class MLIRContext;
class OperandIterator;
+class OperandTypeIterator;
struct OperationState;
class ResultIterator;
class ResultTypeIterator;
OpOperand &getOpOperand(unsigned idx) { return getOpOperands()[idx]; }
+ // Support operand type iteration.
+ using operand_type_iterator = OperandTypeIterator;
+ using operand_type_range = llvm::iterator_range<operand_type_iterator>;
+ operand_type_iterator operand_type_begin();
+ operand_type_iterator operand_type_end();
+ operand_type_range getOperandTypes();
+
//===--------------------------------------------------------------------===//
// Results
//===--------------------------------------------------------------------===//
// Support result type iteration.
using result_type_iterator = ResultTypeIterator;
+ using result_type_range = llvm::iterator_range<result_type_iterator>;
result_type_iterator result_type_begin();
result_type_iterator result_type_end();
- llvm::iterator_range<result_type_iterator> getResultTypes();
+ result_type_range getResultTypes();
//===--------------------------------------------------------------------===//
// Attributes
Value *operator*() const { return this->object->getOperand(this->index); }
};
+/// This class implements the operand type iterators for the Operation
+/// class in terms of operand_iterator->getType().
+class OperandTypeIterator final
+ : public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
+ static Type unwrap(Value *value) { return value->getType(); }
+
+public:
+ /// Initializes the operand type iterator to the specified operand iterator.
+ OperandTypeIterator(OperandIterator it)
+ : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {
+ }
+};
+
// Implement the inline operand iterator methods.
inline auto Operation::operand_begin() -> operand_iterator {
return operand_iterator(this, 0);
return {operand_begin(), operand_end()};
}
+inline auto Operation::operand_type_begin() -> operand_type_iterator {
+ return operand_type_iterator(operand_begin());
+}
+
+inline auto Operation::operand_type_end() -> operand_type_iterator {
+ return operand_type_iterator(operand_end());
+}
+
+inline auto Operation::getOperandTypes() -> operand_type_range {
+ return {operand_type_begin(), operand_type_end()};
+}
+
/// This class implements the result iterators for the Operation class
/// in terms of getResult(idx).
class ResultIterator final
return result_type_iterator(result_end());
}
-inline auto Operation::getResultTypes()
- -> llvm::iterator_range<result_type_iterator> {
+inline auto Operation::getResultTypes() -> result_type_range {
return {result_type_begin(), result_type_end()};
}
return KernelDim3{args[9], args[10], args[11]};
}
-Operation::operand_range LaunchOp::getKernelOperandValues() {
- return {getOperation()->operand_begin() + kNumConfigOperands,
- getOperation()->operand_end()};
+LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
+ return llvm::drop_begin(getOperands(), kNumConfigOperands);
}
-void LaunchOp::getKernelOperandTypes(SmallVectorImpl<Type> &out) {
- out.reserve(getNumOperands() - kNumConfigOperands + out.size());
- for (unsigned i = kNumConfigOperands; i < getNumOperands(); ++i) {
- out.push_back(getOperand(i)->getType());
- }
+LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
+ return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
}
KernelDim3 LaunchOp::getGridSizeOperandValues() {
// Outline the `gpu.launch` operation body into a kernel function.
Function *outlineKernelFunc(Module &module, gpu::LaunchOp &launchOp) {
Location loc = launchOp.getLoc();
- SmallVector<Type, 4> kernelOperandTypes;
- launchOp.getKernelOperandTypes(kernelOperandTypes);
+ SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
FunctionType type =
FunctionType::get(kernelOperandTypes, {}, module.getContext());
std::string kernelFuncName =
void ModuleState::visitOperation(Operation *op) {
// Visit all the types used in the operation.
- for (auto *operand : op->getOperands())
- visitType(operand->getType());
- for (auto *result : op->getResults())
- visitType(result->getType());
+ for (auto type : op->getOperandTypes())
+ visitType(type);
+ for (auto type : op->getResultTypes())
+ visitType(type);
// Visit each of the attributes.
for (auto elt : op->getAttrs())
}
}
- SmallVector<Type, 8> resultTypes;
- resultTypes.reserve(getNumResults());
- for (auto *result : getResults())
- resultTypes.push_back(result->getType());
-
+ SmallVector<Type, 8> resultTypes(getResultTypes());
unsigned numRegions = getNumRegions();
auto *newOp = Operation::create(getLoc(), getName(), operands, resultTypes,
attrs, successors, numRegions,
}
LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) {
- for (auto *operand : op->getOperands()) {
- auto type = getTensorOrVectorElementType(operand->getType());
+ for (auto opType : op->getOperandTypes()) {
+ auto type = getTensorOrVectorElementType(opType);
if (!type.isIntOrIndex())
return op->emitOpError() << "requires an integer or index type";
}
}
LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) {
- for (auto *operand : op->getOperands()) {
- auto type = getTensorOrVectorElementType(operand->getType());
+ for (auto opType : op->getOperandTypes()) {
+ auto type = getTensorOrVectorElementType(opType);
if (!type.isa<FloatType>())
return op->emitOpError("requires a float type");
}
return success();
auto type = op->getOperand(0)->getType();
- for (unsigned i = 1; i < nOperands; ++i)
- if (op->getOperand(i)->getType() != type)
+ for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
+ if (opType != type)
return op->emitOpError() << "requires all operands to have the same type";
return success();
}
return failure();
auto type = op->getOperand(0)->getType();
- for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) {
- if (failed(verifyShapeMatch(op->getResult(i)->getType(), type)))
+ for (auto resultType : op->getResultTypes()) {
+ if (failed(verifyShapeMatch(resultType, type)))
return op->emitOpError()
<< "requires the same shape for all operands and results";
}
- for (unsigned i = 1, e = op->getNumOperands(); i < e; ++i) {
- if (failed(verifyShapeMatch(op->getOperand(i)->getType(), type)))
+ for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
+ if (failed(verifyShapeMatch(opType, type)))
return op->emitOpError()
<< "requires the same shape for all operands and results";
}
return failure();
auto type = op->getResult(0)->getType();
- for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
- if (op->getResult(i)->getType() != type)
+ for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
+ if (resultType != type)
return op->emitOpError()
<< "requires the same type for all operands and results";
}
- for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) {
- if (op->getOperand(i)->getType() != type)
+ for (auto opType : op->getOperandTypes()) {
+ if (opType != type)
return op->emitOpError()
<< "requires the same type for all operands and results";
}
}
LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
- for (auto *result : op->getResults()) {
- auto elementType = getTensorOrVectorElementType(result->getType());
+ for (auto resultType : op->getResultTypes()) {
+ auto elementType = getTensorOrVectorElementType(resultType);
bool isBoolType = elementType.isInteger(1);
if (!isBoolType)
return op->emitOpError() << "requires a bool result type";
}
LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) {
- for (auto *result : op->getResults())
- if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
+ for (auto resultType : op->getResultTypes())
+ if (!getTensorOrVectorElementType(resultType).isa<FloatType>())
return op->emitOpError() << "requires a floating point type";
return success();
}
LogicalResult OpTrait::impl::verifyResultsAreIntegerLike(Operation *op) {
- for (auto *result : op->getResults()) {
- auto type = getTensorOrVectorElementType(result->getType());
- if (!type.isIntOrIndex())
+ for (auto resultType : op->getResultTypes())
+ if (!getTensorOrVectorElementType(resultType).isIntOrIndex())
return op->emitOpError() << "requires an integer or index type";
- }
return success();
}
//===----------------------------------------------------------------------===//
static void printGEPOp(OpAsmPrinter *p, GEPOp &op) {
- SmallVector<Type, 8> types;
- for (auto *operand : op.getOperands())
- types.push_back(operand->getType());
- auto funcTy =
- FunctionType::get(types, op.getResult()->getType(), op.getContext());
+ SmallVector<Type, 8> types(op.getOperandTypes());
+ auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
*p << op.getOperationName() << ' ' << *op.base() << '[';
p->printOperands(std::next(op.operand_begin()), op.operand_end());
p->printOptionalAttrDict(op.getAttrs(), {"callee"});
// Reconstruct the function MLIR function type from operand and result types.
- SmallVector<Type, 1> resultTypes(op.getOperation()->getResultTypes());
- SmallVector<Type, 8> argTypes;
- argTypes.reserve(op.getNumOperands());
- for (auto *operand : llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1))
- argTypes.push_back(operand->getType());
+ SmallVector<Type, 1> resultTypes(op.getResultTypes());
+ SmallVector<Type, 8> argTypes(
+ llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
*p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
}
LLVM::LLVMDialect &dialect;
};
-// Given a range of MLIR typed objects, return a list of their types.
-template <typename T>
-SmallVector<Type, 4> getTypes(llvm::iterator_range<T> range) {
- SmallVector<Type, 4> types;
- types.reserve(llvm::size(range));
- for (auto operand : range) {
- types.push_back(operand->getType());
- }
- return types;
-}
-
// Basic lowering implementation for one-to-one rewriting from Standard Ops to
// LLVM Dialect Ops.
template <typename SourceOp, typename TargetOp>
Type packedType;
if (numResults != 0) {
- packedType =
- this->lowering.packFunctionResults(getTypes(op->getResults()));
+ packedType = this->lowering.packFunctionResults(
+ llvm::to_vector<4>(op->getResultTypes()));
assert(packedType && "type conversion failed, such operation should not "
"have been matched");
}
// Otherwise, we need to pack the arguments into an LLVM struct type before
// returning.
- auto packedType = lowering.packFunctionResults(getTypes(op->getOperands()));
+ auto packedType =
+ lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes()));
Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
for (unsigned i = 0; i < numArguments; ++i) {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
- auto elementTy = lowering.convertType(*op->getResultTypes().begin());
+ auto elementTy = lowering.convertType(*op->result_type_begin());
Value *viewDescriptor = operands[0];
ArrayRef<Value *> indices = operands.drop_front();
auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
"operand count does not equal dimension plus symbol operand count");
// Verify that all operands are of type Index.
- for (auto *operand : op.getOperands())
- if (!operand->getType().isIndex())
+ for (auto operandType : op.getOperandTypes())
+ if (!operandType.isIndex())
return op.emitOpError("requires operands to be of type Index");
return success();
}
}
FunctionType CallOp::getCalleeType() {
- SmallVector<Type, 4> resultTypes(getOperation()->getResultTypes());
- SmallVector<Type, 8> argTypes;
- argTypes.reserve(getNumOperands());
- for (auto *operand : getArgOperands())
- argTypes.push_back(operand->getType());
+ SmallVector<Type, 4> resultTypes(getResultTypes());
+ SmallVector<Type, 8> argTypes(getOperandTypes());
return FunctionType::get(argTypes, resultTypes, getContext());
}