.getPointerTo();
}
+// Helper to reduce vector type by one rank at front.
+static VectorType reducedVectorTypeFront(VectorType tp) {
+ assert((tp.getRank() > 1) && "unlowerable vector type");
+ return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
+}
+
+// Helper to reduce vector type by *all* but one rank at back.
+static VectorType reducedVectorTypeBack(VectorType tp) {
+ assert((tp.getRank() > 1) && "unlowerable vector type");
+ return VectorType::get(tp.getShape().take_back(), tp.getElementType());
+}
+
class VectorBroadcastOpConversion : public LLVMOpLowering {
public:
explicit VectorBroadcastOpConversion(MLIRContext *context,
return rewriter.create<LLVM::ShuffleVectorOp>(
loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
}
- Value *expand = expandRanks(value, loc, srcVectorType,
- reducedVectorType(dstVectorType), rewriter);
+ Value *expand =
+ expandRanks(value, loc, srcVectorType,
+ reducedVectorTypeFront(dstVectorType), rewriter);
Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
for (int64_t d = 0; d < dim; ++d) {
result = insertOne(result, expand, loc, llvmType, rank, d, rewriter);
result = insertOne(result, one, loc, llvmType, rank, d, rewriter);
}
} else {
- VectorType redSrcType = reducedVectorType(srcVectorType);
- VectorType redDstType = reducedVectorType(dstVectorType);
+ VectorType redSrcType = reducedVectorTypeFront(srcVectorType);
+ VectorType redDstType = reducedVectorTypeFront(dstVectorType);
Type redLlvmType = lowering.convertType(redSrcType);
for (int64_t d = 0; d < dim; ++d) {
int64_t pos = atStretch ? 0 : d;
return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, value,
rewriter.getI64ArrayAttr(pos));
}
-
- // Helper to reduce vector type by one rank.
- static VectorType reducedVectorType(VectorType tp) {
- assert((tp.getRank() > 1) && "unlowerable vector type");
- return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
- }
};
-class VectorExtractElementOpConversion : public LLVMOpLowering {
+class VectorExtractOpConversion : public LLVMOpLowering {
public:
- explicit VectorExtractElementOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
+ explicit VectorExtractOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
: LLVMOpLowering(vector::ExtractOp::getOperationName(), context,
typeConverter) {}
auto loc = op->getLoc();
auto adaptor = vector::ExtractOpOperandAdaptor(operands);
auto extractOp = cast<vector::ExtractOp>(op);
- auto vectorType = extractOp.vector()->getType().cast<VectorType>();
+ auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult()->getType();
auto llvmResultType = lowering.convertType(resultType);
-
auto positionArrayAttr = extractOp.position();
+
+ // Bail if result type cannot be lowered.
+ if (!llvmResultType)
+ return matchFailure();
+
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
return matchSuccess();
}
- // Potential extraction of 1-D vector from struct.
+ // Potential extraction of 1-D vector from array.
auto *context = op->getContext();
Value *extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
- auto i32Type = rewriter.getIntegerType(32);
if (positionAttrs.size() > 1) {
- auto nDVectorType = vectorType;
- auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(),
- nDVectorType.getElementType());
+ auto oneDVectorType = reducedVectorTypeBack(vectorType);
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
- auto constant = rewriter.create<LLVM::ConstantOp>(
- loc, lowering.convertType(i32Type), position);
+ auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
+ auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(op, extracted);
}
};
+class VectorInsertOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorInsertOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::InsertOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto adaptor = vector::InsertOpOperandAdaptor(operands);
+ auto insertOp = cast<vector::InsertOp>(op);
+ auto sourceType = insertOp.getSourceType();
+ auto destVectorType = insertOp.getDestVectorType();
+ auto llvmResultType = lowering.convertType(destVectorType);
+ auto positionArrayAttr = insertOp.position();
+
+ // Bail if result type cannot be lowered.
+ if (!llvmResultType)
+ return matchFailure();
+
+ // One-shot insertion of a vector into an array (only requires insertvalue).
+ if (sourceType.isa<VectorType>()) {
+ Value *inserted = rewriter.create<LLVM::InsertValueOp>(
+ loc, llvmResultType, adaptor.dest(), adaptor.source(),
+ positionArrayAttr);
+ rewriter.replaceOp(op, inserted);
+ return matchSuccess();
+ }
+
+ // Potential extraction of 1-D vector from array.
+ auto *context = op->getContext();
+ Value *extracted = adaptor.dest();
+ auto positionAttrs = positionArrayAttr.getValue();
+ auto position = positionAttrs.back().cast<IntegerAttr>();
+ auto oneDVectorType = destVectorType;
+ if (positionAttrs.size() > 1) {
+ oneDVectorType = reducedVectorTypeBack(destVectorType);
+ auto nMinusOnePositionAttrs =
+ ArrayAttr::get(positionAttrs.drop_back(), context);
+ extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, lowering.convertType(oneDVectorType), extracted,
+ nMinusOnePositionAttrs);
+ }
+
+ // Insertion of an element into a 1-D LLVM vector.
+ auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
+ auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
+ Value *inserted = rewriter.create<LLVM::InsertElementOp>(
+ loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(),
+ constant);
+
+ // Potential insertion of resulting 1-D vector into array.
+ if (positionAttrs.size() > 1) {
+ auto nMinusOnePositionAttrs =
+ ArrayAttr::get(positionAttrs.drop_back(), context);
+ inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
+ adaptor.dest(), inserted,
+ nMinusOnePositionAttrs);
+ }
+
+ rewriter.replaceOp(op, inserted);
+ return matchSuccess();
+ }
+};
+
class VectorOuterProductOpConversion : public LLVMOpLowering {
public:
explicit VectorOuterProductOpConversion(MLIRContext *context,
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- patterns.insert<VectorBroadcastOpConversion, VectorExtractElementOpConversion,
- VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
+ patterns.insert<VectorBroadcastOpConversion, VectorExtractOpConversion,
+ VectorInsertOpConversion, VectorOuterProductOpConversion,
+ VectorTypeCastOpConversion>(
converter.getDialect()->getContext(), converter);
}
// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
+func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
+ %0 = vector.extract %arg0[15 : i32]: vector<16xf32>
+ return %0 : f32
+}
+// CHECK-LABEL: extract_element_from_vec_1d
+// CHECK: llvm.mlir.constant(15 : i32) : !llvm.i32
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
+// CHECK: llvm.return {{.*}} : !llvm.float
+
func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
%0 = vector.extract %arg0[0 : i32]: vector<4x3x16xf32>
return %0 : vector<3x16xf32>
// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
// CHECK: llvm.return {{.*}} : !llvm<"[3 x <16 x float>]">
+func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf32> {
+ %0 = vector.extract %arg0[0 : i32, 0 : i32]: vector<4x3x16xf32>
+ return %0 : vector<16xf32>
+}
+// CHECK-LABEL: extract_vec_1d_from_vec_3d
+// CHECK: llvm.extractvalue {{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
+// CHECK: llvm.return {{.*}} : !llvm<"<16 x float>">
+
func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
%0 = vector.extract %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32>
return %0 : f32
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
// CHECK: llvm.return {{.*}} : !llvm.float
+func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32>
+ return %0 : vector<4xf32>
+}
+// CHECK-LABEL: insert_element_into_vec_1d
+// CHECK: llvm.mlir.constant(3 : i32) : !llvm.i32
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<4 x float>">
+// CHECK: llvm.return {{.*}} : !llvm<"<4 x float>">
+
+func @insert_vec_2d_into_vec_3d(%arg0: vector<8x16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+ %0 = vector.insert %arg0, %arg1[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
+ return %0 : vector<4x8x16xf32>
+}
+// CHECK-LABEL: insert_vec_2d_into_vec_3d
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
+// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]">
+
+func @insert_vec_1d_into_vec_3d(%arg0: vector<16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+ %0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32] : vector<16xf32> into vector<4x8x16xf32>
+ return %0 : vector<4x8x16xf32>
+}
+// CHECK-LABEL: insert_vec_1d_into_vec_3d
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
+// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]">
+
+func @insert_element_into_vec_3d(%arg0: f32, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+ %0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32, 15 : i32] : f32 into vector<4x8x16xf32>
+ return %0 : vector<4x8x16xf32>
+}
+// CHECK-LABEL: insert_element_into_vec_3d
+// CHECK: llvm.extractvalue {{.*}}[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
+// CHECK: llvm.mlir.constant(15 : i32) : !llvm.i32
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
+// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]">
+
func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
%0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>
return %0 : memref<vector<8x8x8xf32>>