Summary: The new internal representation of operation results now allows for accessing the result types to be more efficient. Changing the API to ArrayRef is more efficient and removes the need to explicitly materialize vectors in several places.
Differential Revision: https://reviews.llvm.org/D73429
interleaveComma(types, p);
return p;
}
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
+ interleaveComma(types, p);
+ return p;
+}
//===----------------------------------------------------------------------===//
// OpAsmParser
/// Support result type iteration.
using result_type_iterator = result_range::type_iterator;
- using result_type_range = iterator_range<result_type_iterator>;
- result_type_iterator result_type_begin() { return result_begin(); }
- result_type_iterator result_type_end() { return result_end(); }
- result_type_range getResultTypes() { return getResults().getTypes(); }
+ using result_type_range = ArrayRef<Type>;
+ result_type_iterator result_type_begin() { return getResultTypes().begin(); }
+ result_type_iterator result_type_end() { return getResultTypes().end(); }
+ result_type_range getResultTypes();
//===--------------------------------------------------------------------===//
// Attributes
ResultRange(Operation *op);
/// Returns the types of the values within this range.
- using type_iterator = ValueTypeIterator<iterator>;
- iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
+ using type_iterator = ArrayRef<Type>::iterator;
+ ArrayRef<Type> getTypes() const;
private:
/// See `indexed_accessor_range` for details.
op->getOperands(), op->getAttrs(),
op->getRegions(), inferedReturnTypes)))
return failure();
- SmallVector<Type, 4> resultTypes(op->getResultTypes());
- if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
+ if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes,
+ op->getResultTypes()))
return op->emitOpError(
"inferred type incompatible with return type of operation");
return success();
Type packedType;
if (numResults != 0) {
- packedType = this->lowering.packFunctionResults(
- llvm::to_vector<4>(op->getResultTypes()));
+ packedType = this->lowering.packFunctionResults(op->getResultTypes());
if (!packedType)
return this->matchFailure();
}
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
// Reconstruct the function MLIR function type from operand and result types.
- SmallVector<Type, 1> resultTypes(op.getResultTypes());
SmallVector<Type, 8> argTypes(
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
- p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
+ p << " : "
+ << FunctionType::get(argTypes, op.getResultTypes(), op.getContext());
}
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
SmallVector<Type, 4> argTypes(functionCallOp.getOperandTypes());
- SmallVector<Type, 1> resultTypes(functionCallOp.getResultTypes());
- Type functionType =
- FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());
+ Type functionType = FunctionType::get(
+ argTypes, functionCallOp.getResultTypes(), functionCallOp.getContext());
printer << spirv::FunctionCallOp::getOperationName() << ' '
<< functionCallOp.getAttr(kCallee) << '('
auto funcName = op.callee();
uint32_t resTypeID = 0;
- SmallVector<Type, 1> resultTypes(op.getResultTypes());
- if (failed(processType(op.getLoc(),
- (resultTypes.empty() ? getVoidType() : resultTypes[0]),
- resTypeID))) {
+ Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
+ if (failed(processType(op.getLoc(), resultTy, resTypeID)))
return failure();
- }
auto funcID = getOrCreateFunctionID(funcName);
auto funcCallID = getNextID();
operands.push_back(valueID);
}
- if (!resultTypes.empty()) {
+ if (!resultTy.isa<NoneType>())
valueIDMap[op.getResult(0)] = funcCallID;
- }
return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall,
operands);
}
FunctionType CallOp::getCalleeType() {
- SmallVector<Type, 4> resultTypes(getResultTypes());
SmallVector<Type, 8> argTypes(getOperandTypes());
- return FunctionType::get(argTypes, resultTypes, getContext());
+ return FunctionType::get(argTypes, getResultTypes(), getContext());
}
//===----------------------------------------------------------------------===//
return matchFailure();
// Replace with a direct call.
- SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
- rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, callResults,
+ rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
+ indirectCall.getResultTypes(),
indirectCall.getArgOperands());
return matchSuccess();
}
return hasSingleResult ? 1 : resultType.cast<TupleType>().size();
}
+auto Operation::getResultTypes() -> result_type_range {
+ if (!resultType)
+ return llvm::None;
+ if (hasSingleResult)
+ return resultType;
+ return resultType.cast<TupleType>().getTypes();
+}
+
void Operation::setSuccessor(Block *block, unsigned index) {
assert(index < getNumSuccessors());
getBlockOperands()[index].set(block);
}
}
- SmallVector<Type, 8> resultTypes(getResultTypes());
unsigned numRegions = getNumRegions();
auto *newOp =
- Operation::create(getLoc(), getName(), resultTypes, operands, attrs,
+ Operation::create(getLoc(), getName(), getResultTypes(), operands, attrs,
successors, numRegions, hasResizableOperandsList());
// Remember the mapping of any results.
auto type = op->getResult(0).getType();
auto elementType = getElementTypeOrSelf(type);
- for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
+ for (auto resultType : op->getResultTypes().drop_front(1)) {
if (getElementTypeOrSelf(resultType) != elementType ||
failed(verifyCompatibleShape(resultType, type)))
return op->emitOpError()
ResultRange::ResultRange(Operation *op)
: ResultRange(op, /*startIndex=*/0, op->getNumResults()) {}
+ArrayRef<Type> ResultRange::getTypes() const {
+ return getBase()->getResultTypes();
+}
+
/// See `indexed_accessor_range` for details.
OpResult ResultRange::dereference(Operation *op, ptrdiff_t index) {
return op->getResult(index);
// - Attributes
// - Result Types
// - Operands
- return hash_combine(
- op->getName(), op->getAttrList().getDictionary(),
- hash_combine_range(op->result_type_begin(), op->result_type_end()),
- hash_combine_range(op->operand_begin(), op->operand_end()));
+ return llvm::hash_combine(
+ op->getName(), op->getAttrList().getDictionary(), op->getResultTypes(),
+ llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
}
static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
auto *lhs = const_cast<Operation *>(lhsC);
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// If the type is I32, change the type to F32.
- if (!(*op->result_type_begin()).isInteger(32))
+ if (!Type(*op->result_type_begin()).isInteger(32))
return matchFailure();
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
return matchSuccess();
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// If the type is F32, change the type to F64.
- if (!(*op->result_type_begin()).isF32())
+ if (!Type(*op->result_type_begin()).isF32())
return matchFailure();
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
return matchSuccess();
remappedOperands.push_back(rewriter.getRemappedValue(origOp));
remappedOperands.push_back(rewriter.getRemappedValue(origOp));
- SmallVector<Type, 1> resultTypes(op.getResultTypes());
- rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, resultTypes,
+ rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
remappedOperands);
return matchSuccess();
}