ArrayRef<StringRef> generatedNames = {})
: OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
benefit, context, generatedNames) {}
+ /// Generic adaptor around the root op of this pattern using the converted
+ /// operands. Importantly, each operand is represented as a *range* of values,
+ /// namely the N values each original operand gets converted to. Concretely,
+ /// this makes the result type of the accessor functions of the adaptor class
+ /// be a `ValueRange`.
+ class OpAdaptor
+ : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> {
+ public:
+ using RangeT = ArrayRef<ValueRange>;
+ using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
+
+ OpAdaptor(const OneToNTypeMapping *operandMapping,
+ const OneToNTypeMapping *resultMapping,
+ const ValueRange *convertedOperands, RangeT values,
+ DictionaryAttr attrs = nullptr, RegionRange regions = {})
+ : BaseT(values, attrs, regions), operandMapping(operandMapping),
+ resultMapping(resultMapping), convertedOperands(convertedOperands) {}
+
+ /// Get the type mapping of the original operands to the converted operands.
+ const OneToNTypeMapping &getOperandMapping() const {
+ return *operandMapping;
+ }
+
+ /// Get the type mapping of the original results to the converted results.
+ const OneToNTypeMapping &getResultMapping() const { return *resultMapping; }
+
+ /// Get a flat range of all converted operands. Unlike `getOperands`, which
+ /// returns an `ArrayRef` with one `ValueRange` for each original operand,
+ /// this function returns a `ValueRange` that contains all converted
+ /// operands irrespectively of which operand they originated from.
+ ValueRange getFlatOperands() const { return *convertedOperands; }
+
+ private:
+ const OneToNTypeMapping *operandMapping;
+ const OneToNTypeMapping *resultMapping;
+ const ValueRange *convertedOperands;
+ };
using OneToNConversionPattern::matchAndRewrite;
/// Overload that derived classes have to override for their op type.
- virtual LogicalResult matchAndRewrite(SourceOp op,
- OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const = 0;
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const = 0;
LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
const OneToNTypeMapping &operandMapping,
const OneToNTypeMapping &resultMapping,
ValueRange convertedOperands) const final {
- return matchAndRewrite(cast<SourceOp>(op), rewriter, operandMapping,
- resultMapping, convertedOperands);
+ // Wrap converted operands and type mappings into an adaptor.
+ SmallVector<ValueRange> valueRanges;
+ for (int64_t i = 0; i < op->getNumOperands(); i++) {
+ auto values = operandMapping.getConvertedValues(convertedOperands, i);
+ valueRanges.push_back(values);
+ }
+ OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
+ valueRanges, op->getAttrDictionary(), op->getRegions());
+
+ // Call overload implemented by the derived class.
+ return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
}
};
public:
using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;
- LogicalResult matchAndRewrite(CallOp op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const override {
+ LogicalResult
+ matchAndRewrite(CallOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
// Nothing to do if the op doesn't have any non-identity conversions for its
// operands or results.
- if (!operandMapping.hasNonIdentityConversion() &&
+ if (!adaptor.getOperandMapping().hasNonIdentityConversion() &&
!resultMapping.hasNonIdentityConversion())
return failure();
// Create new CallOp.
auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
- convertedOperands);
+ adaptor.getFlatOperands());
newOp->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newOp->getResults(), resultMapping);
using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;
LogicalResult
- matchAndRewrite(FuncOp op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping & /*operandMapping*/,
- const OneToNTypeMapping & /*resultMapping*/,
- ValueRange /*convertedOperands*/) const override {
+ matchAndRewrite(FuncOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
// Construct mapping for function arguments.
public:
using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
- LogicalResult matchAndRewrite(ReturnOp op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping & /*resultMapping*/,
- ValueRange convertedOperands) const override {
+ LogicalResult
+ matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
// Nothing to do if there is no non-identity conversion.
- if (!operandMapping.hasNonIdentityConversion())
+ if (!adaptor.getOperandMapping().hasNonIdentityConversion())
return failure();
// Convert operands.
- rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+ rewriter.updateRootInPlace(
+ op, [&] { op->setOperands(adaptor.getFlatOperands()); });
return success();
}
using OneToNOpConversionPattern<IfOp>::OneToNOpConversionPattern;
LogicalResult
- matchAndRewrite(IfOp op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping & /*operandMapping*/,
- const OneToNTypeMapping &resultMapping,
- const ValueRange /*convertedOperands*/) const override {
+ matchAndRewrite(IfOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
// Nothing to do if there is no non-identity conversion.
if (!resultMapping.hasNonIdentityConversion())
using OneToNOpConversionPattern<WhileOp>::OneToNOpConversionPattern;
LogicalResult
- matchAndRewrite(WhileOp op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- const ValueRange convertedOperands) const override {
+ matchAndRewrite(WhileOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
+ const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+
// Nothing to do if the op doesn't have any non-identity conversions for its
// operands or results.
if (!operandMapping.hasNonIdentityConversion() &&
// Create new WhileOp.
TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
- auto newOp =
- rewriter.create<WhileOp>(loc, convertedResultTypes, convertedOperands);
+ auto newOp = rewriter.create<WhileOp>(loc, convertedResultTypes,
+ adaptor.getFlatOperands());
newOp->setAttrs(op->getAttrs());
// Update block signatures.
using OneToNOpConversionPattern<YieldOp>::OneToNOpConversionPattern;
LogicalResult
- matchAndRewrite(YieldOp op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping & /*resultMapping*/,
- const ValueRange convertedOperands) const override {
+ matchAndRewrite(YieldOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
// Nothing to do if there is no non-identity conversion.
- if (!operandMapping.hasNonIdentityConversion())
+ if (!adaptor.getOperandMapping().hasNonIdentityConversion())
return failure();
// Convert operands.
- rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+ rewriter.updateRootInPlace(
+ op, [&] { op->setOperands(adaptor.getFlatOperands()); });
return success();
}
using OneToNOpConversionPattern<ConditionOp>::OneToNOpConversionPattern;
LogicalResult
- matchAndRewrite(ConditionOp op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping & /*resultMapping*/,
- const ValueRange convertedOperands) const override {
+ matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
// Nothing to do if there is no non-identity conversion.
- if (!operandMapping.hasNonIdentityConversion())
+ if (!adaptor.getOperandMapping().hasNonIdentityConversion())
return failure();
// Convert operands.
- rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+ rewriter.updateRootInPlace(
+ op, [&] { op->setOperands(adaptor.getFlatOperands()); });
return success();
}
using OneToNOpConversionPattern<
::test::MakeTupleOp>::OneToNOpConversionPattern;
- LogicalResult matchAndRewrite(::test::MakeTupleOp op,
- OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const override {
+ LogicalResult
+ matchAndRewrite(::test::MakeTupleOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
// Simply replace the current op with the converted operands.
- rewriter.replaceOp(op, convertedOperands, resultMapping);
+ rewriter.replaceOp(op, adaptor.getFlatOperands(),
+ adaptor.getResultMapping());
return success();
}
};
using OneToNOpConversionPattern<
::test::GetTupleElementOp>::OneToNOpConversionPattern;
- LogicalResult matchAndRewrite(::test::GetTupleElementOp op,
- OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const override {
+ LogicalResult
+ matchAndRewrite(::test::GetTupleElementOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
// Construct mapping for tuple element types.
auto stateType = op->getOperand(0).getType().cast<TupleType>();
TypeRange originalElementTypes = stateType.getTypes();
return failure();
// Compute converted operands corresponding to original input tuple.
- ValueRange convertedTuple =
- operandMapping.getConvertedValues(convertedOperands, 0);
+ assert(adaptor.getOperands().size() == 1 &&
+ "expected 'get_tuple_element' to have one operand");
+ ValueRange convertedTuple = adaptor.getOperands()[0];
- // Got those converted operands that correspond to the index-th element of
+ // Got those converted operands that correspond to the index-th element ofq
// the original input tuple.
size_t index = op.getIndex();
ValueRange extractedElement =
elementMapping.getConvertedValues(convertedTuple, index);
- rewriter.replaceOp(op, extractedElement, resultMapping);
+ rewriter.replaceOp(op, extractedElement, adaptor.getResultMapping());
return success();
}