From 9826fe5c9fb65da8f1d53b21348f013c58c09791 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 10 Dec 2019 17:12:11 -0800 Subject: [PATCH] [VectorOps] Add lowering of vector.insert to LLVM IR For example, an insert %0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32> becomes %0 = llvm.mlir.constant(3 : i32) : !llvm.i32 %1 = llvm.insertelement %arg0, %arg1[%0 : !llvm.i32] : !llvm<"<4 x float>"> A more elaborate example, inserting an element in a higher dimension vector %0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32, 15 : i32] : f32 into vector<4x8x16xf32> becomes %0 = llvm.extractvalue %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]"> %1 = llvm.mlir.constant(15 : i32) : !llvm.i32 %2 = llvm.insertelement %arg0, %0[%1 : !llvm.i32] : !llvm<"<16 x float>"> %3 = llvm.insertvalue %2, %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]"> PiperOrigin-RevId: 284882443 --- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 124 +++++++++++++++++---- .../Conversion/VectorToLLVM/vector-to-llvm.mlir | 53 +++++++++ 2 files changed, 153 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 8adc415..17fb933 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -49,6 +49,18 @@ static LLVM::LLVMType getPtrToElementType(T containerType, .getPointerTo(); } +// Helper to reduce vector type by one rank at front. +static VectorType reducedVectorTypeFront(VectorType tp) { + assert((tp.getRank() > 1) && "unlowerable vector type"); + return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); +} + +// Helper to reduce vector type by *all* but one rank at back. +static VectorType reducedVectorTypeBack(VectorType tp) { + assert((tp.getRank() > 1) && "unlowerable vector type"); + return VectorType::get(tp.getShape().take_back(), tp.getElementType()); +} + class VectorBroadcastOpConversion : public LLVMOpLowering { public: explicit VectorBroadcastOpConversion(MLIRContext *context, @@ -135,8 +147,9 @@ private: return rewriter.create( loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); } - Value *expand = expandRanks(value, loc, srcVectorType, - reducedVectorType(dstVectorType), rewriter); + Value *expand = + expandRanks(value, loc, srcVectorType, + reducedVectorTypeFront(dstVectorType), rewriter); Value *result = rewriter.create(loc, llvmType); for (int64_t d = 0; d < dim; ++d) { result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); @@ -183,8 +196,8 @@ private: result = insertOne(result, one, loc, llvmType, rank, d, rewriter); } } else { - VectorType redSrcType = reducedVectorType(srcVectorType); - VectorType redDstType = reducedVectorType(dstVectorType); + VectorType redSrcType = reducedVectorTypeFront(srcVectorType); + VectorType redDstType = reducedVectorTypeFront(dstVectorType); Type redLlvmType = lowering.convertType(redSrcType); for (int64_t d = 0; d < dim; ++d) { int64_t pos = atStretch ? 0 : d; @@ -226,18 +239,12 @@ private: return rewriter.create(loc, llvmType, value, rewriter.getI64ArrayAttr(pos)); } - - // Helper to reduce vector type by one rank. - static VectorType reducedVectorType(VectorType tp) { - assert((tp.getRank() > 1) && "unlowerable vector type"); - return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); - } }; -class VectorExtractElementOpConversion : public LLVMOpLowering { +class VectorExtractOpConversion : public LLVMOpLowering { public: - explicit VectorExtractElementOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) + explicit VectorExtractOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) : LLVMOpLowering(vector::ExtractOp::getOperationName(), context, typeConverter) {} @@ -247,11 +254,15 @@ public: auto loc = op->getLoc(); auto adaptor = vector::ExtractOpOperandAdaptor(operands); auto extractOp = cast(op); - auto vectorType = extractOp.vector()->getType().cast(); + auto vectorType = extractOp.getVectorType(); auto resultType = extractOp.getResult()->getType(); auto llvmResultType = lowering.convertType(resultType); - auto positionArrayAttr = extractOp.position(); + + // Bail if result type cannot be lowered. + if (!llvmResultType) + return matchFailure(); + // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa()) { Value *extracted = rewriter.create( @@ -260,15 +271,12 @@ public: return matchSuccess(); } - // Potential extraction of 1-D vector from struct. + // Potential extraction of 1-D vector from array. auto *context = op->getContext(); Value *extracted = adaptor.vector(); auto positionAttrs = positionArrayAttr.getValue(); - auto i32Type = rewriter.getIntegerType(32); if (positionAttrs.size() > 1) { - auto nDVectorType = vectorType; - auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(), - nDVectorType.getElementType()); + auto oneDVectorType = reducedVectorTypeBack(vectorType); auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); extracted = rewriter.create( @@ -278,8 +286,8 @@ public: // Remaining extraction of element from 1-D LLVM vector auto position = positionAttrs.back().cast(); - auto constant = rewriter.create( - loc, lowering.convertType(i32Type), position); + auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect()); + auto constant = rewriter.create(loc, i32Type, position); extracted = rewriter.create(loc, extracted, constant); rewriter.replaceOp(op, extracted); @@ -288,6 +296,73 @@ public: } }; +class VectorInsertOpConversion : public LLVMOpLowering { +public: + explicit VectorInsertOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::InsertOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto adaptor = vector::InsertOpOperandAdaptor(operands); + auto insertOp = cast(op); + auto sourceType = insertOp.getSourceType(); + auto destVectorType = insertOp.getDestVectorType(); + auto llvmResultType = lowering.convertType(destVectorType); + auto positionArrayAttr = insertOp.position(); + + // Bail if result type cannot be lowered. + if (!llvmResultType) + return matchFailure(); + + // One-shot insertion of a vector into an array (only requires insertvalue). + if (sourceType.isa()) { + Value *inserted = rewriter.create( + loc, llvmResultType, adaptor.dest(), adaptor.source(), + positionArrayAttr); + rewriter.replaceOp(op, inserted); + return matchSuccess(); + } + + // Potential extraction of 1-D vector from array. + auto *context = op->getContext(); + Value *extracted = adaptor.dest(); + auto positionAttrs = positionArrayAttr.getValue(); + auto position = positionAttrs.back().cast(); + auto oneDVectorType = destVectorType; + if (positionAttrs.size() > 1) { + oneDVectorType = reducedVectorTypeBack(destVectorType); + auto nMinusOnePositionAttrs = + ArrayAttr::get(positionAttrs.drop_back(), context); + extracted = rewriter.create( + loc, lowering.convertType(oneDVectorType), extracted, + nMinusOnePositionAttrs); + } + + // Insertion of an element into a 1-D LLVM vector. + auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect()); + auto constant = rewriter.create(loc, i32Type, position); + Value *inserted = rewriter.create( + loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(), + constant); + + // Potential insertion of resulting 1-D vector into array. + if (positionAttrs.size() > 1) { + auto nMinusOnePositionAttrs = + ArrayAttr::get(positionAttrs.drop_back(), context); + inserted = rewriter.create(loc, llvmResultType, + adaptor.dest(), inserted, + nMinusOnePositionAttrs); + } + + rewriter.replaceOp(op, inserted); + return matchSuccess(); + } +}; + class VectorOuterProductOpConversion : public LLVMOpLowering { public: explicit VectorOuterProductOpConversion(MLIRContext *context, @@ -431,8 +506,9 @@ public: /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - patterns.insert( + patterns.insert( converter.getDialect()->getContext(), converter); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 0802799..28c21f6 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -230,6 +230,15 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector // CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> // CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> +func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 { + %0 = vector.extract %arg0[15 : i32]: vector<16xf32> + return %0 : f32 +} +// CHECK-LABEL: extract_element_from_vec_1d +// CHECK: llvm.mlir.constant(15 : i32) : !llvm.i32 +// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>"> +// CHECK: llvm.return {{.*}} : !llvm.float + func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> { %0 = vector.extract %arg0[0 : i32]: vector<4x3x16xf32> return %0 : vector<3x16xf32> @@ -238,6 +247,14 @@ func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> // CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]"> // CHECK: llvm.return {{.*}} : !llvm<"[3 x <16 x float>]"> +func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf32> { + %0 = vector.extract %arg0[0 : i32, 0 : i32]: vector<4x3x16xf32> + return %0 : vector<16xf32> +} +// CHECK-LABEL: extract_vec_1d_from_vec_3d +// CHECK: llvm.extractvalue {{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"<16 x float>"> + func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 { %0 = vector.extract %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32> return %0 : f32 @@ -248,6 +265,42 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 { // CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>"> // CHECK: llvm.return {{.*}} : !llvm.float +func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { + %0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32> + return %0 : vector<4xf32> +} +// CHECK-LABEL: insert_element_into_vec_1d +// CHECK: llvm.mlir.constant(3 : i32) : !llvm.i32 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<4 x float>"> +// CHECK: llvm.return {{.*}} : !llvm<"<4 x float>"> + +func @insert_vec_2d_into_vec_3d(%arg0: vector<8x16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> { + %0 = vector.insert %arg0, %arg1[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32> + return %0 : vector<4x8x16xf32> +} +// CHECK-LABEL: insert_vec_2d_into_vec_3d +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3 : i32] : !llvm<"[4 x [8 x <16 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]"> + +func @insert_vec_1d_into_vec_3d(%arg0: vector<16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> { + %0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32] : vector<16xf32> into vector<4x8x16xf32> + return %0 : vector<4x8x16xf32> +} +// CHECK-LABEL: insert_vec_1d_into_vec_3d +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]"> + +func @insert_element_into_vec_3d(%arg0: f32, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> { + %0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32, 15 : i32] : f32 into vector<4x8x16xf32> + return %0 : vector<4x8x16xf32> +} +// CHECK-LABEL: insert_element_into_vec_3d +// CHECK: llvm.extractvalue {{.*}}[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]"> +// CHECK: llvm.mlir.constant(15 : i32) : !llvm.i32 +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]"> + func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref> { %0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref> return %0 : memref> -- 2.7.4