/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
-class VectorMatmulOpConversion : public ConvertToLLVMPattern {
+class VectorMatmulOpConversion
+ : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
- explicit VectorMatmulOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto matmulOp = cast<vector::MatmulOp>(op);
auto adaptor = vector::MatmulOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
- op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(),
- adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
- matmulOp.rhs_columns());
+ matmulOp, typeConverter->convertType(matmulOp.res().getType()),
+ adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
+ matmulOp.lhs_columns(), matmulOp.rhs_columns());
return success();
}
};
/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
-class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
+class VectorFlatTransposeOpConversion
+ : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
- explicit VectorFlatTransposeOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
- context, typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto transOp = cast<vector::FlatTransposeOp>(op);
auto adaptor = vector::FlatTransposeOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
transOp, typeConverter->convertType(transOp.res().getType()),
};
/// Conversion pattern for a vector.maskedload.
-class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
+class VectorMaskedLoadOpConversion
+ : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
public:
- explicit VectorMaskedLoadOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto load = cast<vector::MaskedLoadOp>(op);
+ auto loc = load->getLoc();
auto adaptor = vector::MaskedLoadOpAdaptor(operands);
// Resolve alignment.
};
/// Conversion pattern for a vector.maskedstore.
-class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
+class VectorMaskedStoreOpConversion
+ : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
public:
- explicit VectorMaskedStoreOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto store = cast<vector::MaskedStoreOp>(op);
+ auto loc = store->getLoc();
auto adaptor = vector::MaskedStoreOpAdaptor(operands);
// Resolve alignment.
};
/// Conversion pattern for a vector.gather.
-class VectorGatherOpConversion : public ConvertToLLVMPattern {
+class VectorGatherOpConversion
+ : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
- explicit VectorGatherOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto gather = cast<vector::GatherOp>(op);
+ auto loc = gather->getLoc();
auto adaptor = vector::GatherOpAdaptor(operands);
// Resolve alignment.
};
/// Conversion pattern for a vector.scatter.
-class VectorScatterOpConversion : public ConvertToLLVMPattern {
+class VectorScatterOpConversion
+ : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
- explicit VectorScatterOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto scatter = cast<vector::ScatterOp>(op);
+ auto loc = scatter->getLoc();
auto adaptor = vector::ScatterOpAdaptor(operands);
// Resolve alignment.
};
/// Conversion pattern for a vector.expandload.
-class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
+class VectorExpandLoadOpConversion
+ : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
- explicit VectorExpandLoadOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto expand = cast<vector::ExpandLoadOp>(op);
+ auto loc = expand->getLoc();
auto adaptor = vector::ExpandLoadOpAdaptor(operands);
Value ptr;
auto vType = expand.getResultVectorType();
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
- op, typeConverter->convertType(vType), ptr, adaptor.mask(),
+ expand, typeConverter->convertType(vType), ptr, adaptor.mask(),
adaptor.pass_thru());
return success();
}
};
/// Conversion pattern for a vector.compressstore.
-class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
+class VectorCompressStoreOpConversion
+ : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
- explicit VectorCompressStoreOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
- context, typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto compress = cast<vector::CompressStoreOp>(op);
+ auto loc = compress->getLoc();
auto adaptor = vector::CompressStoreOpAdaptor(operands);
Value ptr;
return failure();
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
- op, adaptor.value(), ptr, adaptor.mask());
+ compress, adaptor.value(), ptr, adaptor.mask());
return success();
}
};
/// Conversion pattern for all vector reductions.
-class VectorReductionOpConversion : public ConvertToLLVMPattern {
+class VectorReductionOpConversion
+ : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
- explicit VectorReductionOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter,
+ explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
bool reassociateFPRed)
- : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
- typeConverter),
+ : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
reassociateFPReductions(reassociateFPRed) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto reductionOp = cast<vector::ReductionOp>(op);
auto kind = reductionOp.kind();
Type eltType = reductionOp.dest().getType();
Type llvmType = typeConverter->convertType(eltType);
// Integer reductions: add/mul/min/max/and/or/xor.
if (kind == "add")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "mul")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "min" &&
(eltType.isIndex() || eltType.isUnsignedInteger()))
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "min")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "max" &&
(eltType.isIndex() || eltType.isUnsignedInteger()))
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "max")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "and")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "or")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "xor")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else
return failure();
return success();
// Optional accumulator (or zero).
Value acc = operands.size() > 1 ? operands[1]
: rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), llvmType,
+ reductionOp->getLoc(), llvmType,
rewriter.getZeroAttr(eltType));
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
- op, llvmType, acc, operands[0],
+ reductionOp, llvmType, acc, operands[0],
rewriter.getBoolAttr(reassociateFPReductions));
} else if (kind == "mul") {
// Optional accumulator (or one).
Value acc = operands.size() > 1
? operands[1]
: rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), llvmType,
+ reductionOp->getLoc(), llvmType,
rewriter.getFloatAttr(eltType, 1.0));
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
- op, llvmType, acc, operands[0],
+ reductionOp, llvmType, acc, operands[0],
rewriter.getBoolAttr(reassociateFPReductions));
} else if (kind == "min")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(op, llvmType,
- operands[0]);
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
+ reductionOp, llvmType, operands[0]);
else if (kind == "max")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(op, llvmType,
- operands[0]);
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
+ reductionOp, llvmType, operands[0]);
else
return failure();
return success();
};
/// Conversion pattern for a vector.create_mask (1-D only).
-class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
+class VectorCreateMaskOpConversion
+ : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
- explicit VectorCreateMaskOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter,
+ explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
bool enableIndexOpt)
- : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
- typeConverter),
+ : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
enableIndexOptimizations(enableIndexOpt) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dstType = op->getResult(0).getType().cast<VectorType>();
int64_t rank = dstType.getRank();
const bool enableIndexOptimizations;
};
-class VectorShuffleOpConversion : public ConvertToLLVMPattern {
+class VectorShuffleOpConversion
+ : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
- explicit VectorShuffleOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
+ auto loc = shuffleOp->getLoc();
auto adaptor = vector::ShuffleOpAdaptor(operands);
- auto shuffleOp = cast<vector::ShuffleOp>(op);
auto v1Type = shuffleOp.getV1VectorType();
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getVectorType();
// For rank 1, where both operands have *exactly* the same vector type,
// there is direct shuffle support in LLVM. Use it!
if (rank == 1 && v1Type == v2Type) {
- Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
+ Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
- rewriter.replaceOp(op, shuffle);
+ rewriter.replaceOp(shuffleOp, llvmShuffleOp);
return success();
}
insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
llvmType, rank, insPos++);
}
- rewriter.replaceOp(op, insert);
+ rewriter.replaceOp(shuffleOp, insert);
return success();
}
};
-class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
+class VectorExtractElementOpConversion
+ : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
- explicit VectorExtractElementOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
- context, typeConverter) {}
+ using ConvertOpToLLVMPattern<
+ vector::ExtractElementOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ExtractElementOp extractEltOp,
+ ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::ExtractElementOpAdaptor(operands);
- auto extractEltOp = cast<vector::ExtractElementOp>(op);
auto vectorType = extractEltOp.getVectorType();
auto llvmType = typeConverter->convertType(vectorType.getElementType());
return failure();
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- op, llvmType, adaptor.vector(), adaptor.position());
+ extractEltOp, llvmType, adaptor.vector(), adaptor.position());
return success();
}
};
-class VectorExtractOpConversion : public ConvertToLLVMPattern {
+class VectorExtractOpConversion
+ : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
- explicit VectorExtractOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
+ auto loc = extractOp->getLoc();
auto adaptor = vector::ExtractOpAdaptor(operands);
- auto extractOp = cast<vector::ExtractOp>(op);
auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
if (resultType.isa<VectorType>()) {
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmResultType, adaptor.vector(), positionArrayAttr);
- rewriter.replaceOp(op, extracted);
+ rewriter.replaceOp(extractOp, extracted);
return success();
}
// Potential extraction of 1-D vector from array.
- auto *context = op->getContext();
+ auto *context = extractOp->getContext();
Value extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
if (positionAttrs.size() > 1) {
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
- rewriter.replaceOp(op, extracted);
+ rewriter.replaceOp(extractOp, extracted);
return success();
}
/// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
/// -> !llvm<"<8 x float>">
/// ```
-class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
+class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
- explicit VectorFMAOp1DConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::FMAOpAdaptor(operands);
- vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
VectorType vType = fmaOp.getVectorType();
if (vType.getRank() != 1)
return failure();
- rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
+ rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
adaptor.rhs(), adaptor.acc());
return success();
}
};
-class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
+class VectorInsertElementOpConversion
+ : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
- explicit VectorInsertElementOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
- context, typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::InsertElementOpAdaptor(operands);
- auto insertEltOp = cast<vector::InsertElementOp>(op);
auto vectorType = insertEltOp.getDestVectorType();
auto llvmType = typeConverter->convertType(vectorType);
return failure();
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
+ insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
+ adaptor.position());
return success();
}
};
-class VectorInsertOpConversion : public ConvertToLLVMPattern {
+class VectorInsertOpConversion
+ : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
- explicit VectorInsertOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
+ auto loc = insertOp->getLoc();
auto adaptor = vector::InsertOpAdaptor(operands);
- auto insertOp = cast<vector::InsertOp>(op);
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, llvmResultType, adaptor.dest(), adaptor.source(),
positionArrayAttr);
- rewriter.replaceOp(op, inserted);
+ rewriter.replaceOp(insertOp, inserted);
return success();
}
// Potential extraction of 1-D vector from array.
- auto *context = op->getContext();
+ auto *context = insertOp->getContext();
Value extracted = adaptor.dest();
auto positionAttrs = positionArrayAttr.getValue();
auto position = positionAttrs.back().cast<IntegerAttr>();
nMinusOnePositionAttrs);
}
- rewriter.replaceOp(op, inserted);
+ rewriter.replaceOp(insertOp, inserted);
return success();
}
};
return strides;
}
-class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
+class VectorTypeCastOpConversion
+ : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
- explicit VectorTypeCastOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
+ auto loc = castOp->getLoc();
MemRefType sourceMemRefType =
castOp.getOperand().getType().cast<MemRefType>();
MemRefType targetMemRefType =
desc.setStride(rewriter, loc, index, stride);
}
- rewriter.replaceOp(op, {desc});
+ rewriter.replaceOp(castOp, {desc});
return success();
}
};
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
template <typename ConcreteOp>
-class VectorTransferConversion : public ConvertToLLVMPattern {
+class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
- explicit VectorTransferConversion(MLIRContext *context,
- LLVMTypeConverter &typeConv,
+ explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
bool enableIndexOpt)
- : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
+ : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
enableIndexOptimizations(enableIndexOpt) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto xferOp = cast<ConcreteOp>(op);
auto adaptor = getTransferOpAdapter(xferOp, operands);
if (xferOp.getVectorType().getRank() > 1 ||
if (xferOp.permutation_map() !=
AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
xferOp.getVectorType().getRank(),
- op->getContext()))
+ xferOp->getContext()))
return failure();
// Only contiguous source tensors supported atm.
auto strides = computeContiguousStrides(xferOp.getMemRefType());
if (!strides)
return failure();
- auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
+ auto toLLVMTy = [&](Type t) {
+ return this->getTypeConverter()->convertType(t);
+ };
- Location loc = op->getLoc();
+ Location loc = xferOp->getLoc();
MemRefType memRefType = xferOp.getMemRefType();
if (auto memrefVectorElementType =
// addrspacecast shall be used when source/dst memrefs are not on
// address space 0.
// TODO: support alignment when possible.
- Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
- adaptor.indices(), rewriter);
+ Value dataPtr = this->getStridedElementPtr(
+ loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
auto vecTy =
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
Value vectorDataPtr;
loc, vecTy.getPointerTo(), dataPtr);
if (!xferOp.isMaskedDim(0))
- return replaceTransferOpWithLoadOrStore(
- rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr);
+ return replaceTransferOpWithLoadOrStore(rewriter,
+ *this->getTypeConverter(), loc,
+ xferOp, operands, vectorDataPtr);
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = xferOp.indices()[lastIndex];
Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
- Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
- vecWidth, dim, &off);
+ Value mask = buildVectorComparison(
+ rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
// 5. Rewrite as a masked read / write.
- return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc,
+ return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
xferOp, operands, vectorDataPtr, mask);
}
const bool enableIndexOptimizations;
};
-class VectorPrintOpConversion : public ConvertToLLVMPattern {
+class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
- explicit VectorPrintOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
// Proof-of-concept lowering implementation that relies on a small
// runtime support library, which only needs to provide a few
// TODO: rely solely on libc in future? something else?
//
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto printOp = cast<vector::PrintOp>(op);
auto adaptor = vector::PrintOpAdaptor(operands);
Type printType = printOp.getPrintType();
Type eltType = vectorType ? vectorType.getElementType() : printType;
Operation *printer;
if (eltType.isF32()) {
- printer = getPrintFloat(op);
+ printer = getPrintFloat(printOp);
} else if (eltType.isF64()) {
- printer = getPrintDouble(op);
+ printer = getPrintDouble(printOp);
} else if (eltType.isIndex()) {
- printer = getPrintU64(op);
+ printer = getPrintU64(printOp);
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
- printer = getPrintU64(op);
+ printer = getPrintU64(printOp);
} else {
return failure();
}
conversion = PrintConversion::ZeroExt64;
else if (width < 64)
conversion = PrintConversion::SignExt64;
- printer = getPrintI64(op);
+ printer = getPrintI64(printOp);
} else {
return failure();
}
// Unroll vector into elementary print calls.
int64_t rank = vectorType ? vectorType.getRank() : 0;
- emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
+ emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
conversion);
- emitCall(rewriter, op->getLoc(), getPrintNewline(op));
- rewriter.eraseOp(op);
+ emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
+ rewriter.eraseOp(printOp);
return success();
}
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorExtractStridedSliceOpConversion>(ctx);
patterns.insert<VectorReductionOpConversion>(
- ctx, converter, reassociateFPReductions);
+ converter, reassociateFPReductions);
patterns.insert<VectorCreateMaskOpConversion,
VectorTransferConversion<TransferReadOp>,
VectorTransferConversion<TransferWriteOp>>(
- ctx, converter, enableIndexOptimizations);
+ converter, enableIndexOptimizations);
patterns
.insert<VectorShuffleOpConversion,
VectorExtractElementOpConversion,
VectorGatherOpConversion,
VectorScatterOpConversion,
VectorExpandLoadOpConversion,
- VectorCompressStoreOpConversion>(ctx, converter);
+ VectorCompressStoreOpConversion>(converter);
// clang-format on
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- MLIRContext *ctx = converter.getDialect()->getContext();
- patterns.insert<VectorMatmulOpConversion>(ctx, converter);
- patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
+ patterns.insert<VectorMatmulOpConversion>(converter);
+ patterns.insert<VectorFlatTransposeOpConversion>(converter);
}