[VectorOps] Add lowering of vector.insert to LLVM IR
authorAart Bik <ajcbik@google.com>
Wed, 11 Dec 2019 01:12:11 +0000 (17:12 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 11 Dec 2019 01:12:49 +0000 (17:12 -0800)
For example, an insert

  %0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32>

becomes

  %0 = llvm.mlir.constant(3 : i32) : !llvm.i32
  %1 = llvm.insertelement %arg0, %arg1[%0 : !llvm.i32] : !llvm<"<4 x float>">

A more elaborate example, inserting an element in a higher dimension
vector

  %0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32, 15 : i32] : f32 into vector<4x8x16xf32>

becomes

  %0 = llvm.extractvalue %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
  %1 = llvm.mlir.constant(15 : i32) : !llvm.i32
  %2 = llvm.insertelement %arg0, %0[%1 : !llvm.i32] : !llvm<"<16 x float>">
  %3 = llvm.insertvalue %2, %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">

PiperOrigin-RevId: 284882443

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

index 8adc415..17fb933 100644 (file)
@@ -49,6 +49,18 @@ static LLVM::LLVMType getPtrToElementType(T containerType,
       .getPointerTo();
 }
 
+// Helper to reduce vector type by one rank at front.
+static VectorType reducedVectorTypeFront(VectorType tp) {
+  assert((tp.getRank() > 1) && "unlowerable vector type");
+  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
+}
+
+// Helper to reduce vector type by *all* but one rank at back.
+static VectorType reducedVectorTypeBack(VectorType tp) {
+  assert((tp.getRank() > 1) && "unlowerable vector type");
+  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
+}
+
 class VectorBroadcastOpConversion : public LLVMOpLowering {
 public:
   explicit VectorBroadcastOpConversion(MLIRContext *context,
@@ -135,8 +147,9 @@ private:
       return rewriter.create<LLVM::ShuffleVectorOp>(
           loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
     }
-    Value *expand = expandRanks(value, loc, srcVectorType,
-                                reducedVectorType(dstVectorType), rewriter);
+    Value *expand =
+        expandRanks(value, loc, srcVectorType,
+                    reducedVectorTypeFront(dstVectorType), rewriter);
     Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
     for (int64_t d = 0; d < dim; ++d) {
       result = insertOne(result, expand, loc, llvmType, rank, d, rewriter);
@@ -183,8 +196,8 @@ private:
         result = insertOne(result, one, loc, llvmType, rank, d, rewriter);
       }
     } else {
-      VectorType redSrcType = reducedVectorType(srcVectorType);
-      VectorType redDstType = reducedVectorType(dstVectorType);
+      VectorType redSrcType = reducedVectorTypeFront(srcVectorType);
+      VectorType redDstType = reducedVectorTypeFront(dstVectorType);
       Type redLlvmType = lowering.convertType(redSrcType);
       for (int64_t d = 0; d < dim; ++d) {
         int64_t pos = atStretch ? 0 : d;
@@ -226,18 +239,12 @@ private:
     return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, value,
                                                  rewriter.getI64ArrayAttr(pos));
   }
-
-  // Helper to reduce vector type by one rank.
-  static VectorType reducedVectorType(VectorType tp) {
-    assert((tp.getRank() > 1) && "unlowerable vector type");
-    return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
-  }
 };
 
-class VectorExtractElementOpConversion : public LLVMOpLowering {
+class VectorExtractOpConversion : public LLVMOpLowering {
 public:
-  explicit VectorExtractElementOpConversion(MLIRContext *context,
-                                            LLVMTypeConverter &typeConverter)
+  explicit VectorExtractOpConversion(MLIRContext *context,
+                                     LLVMTypeConverter &typeConverter)
       : LLVMOpLowering(vector::ExtractOp::getOperationName(), context,
                        typeConverter) {}
 
@@ -247,11 +254,15 @@ public:
     auto loc = op->getLoc();
     auto adaptor = vector::ExtractOpOperandAdaptor(operands);
     auto extractOp = cast<vector::ExtractOp>(op);
-    auto vectorType = extractOp.vector()->getType().cast<VectorType>();
+    auto vectorType = extractOp.getVectorType();
     auto resultType = extractOp.getResult()->getType();
     auto llvmResultType = lowering.convertType(resultType);
-
     auto positionArrayAttr = extractOp.position();
+
+    // Bail if result type cannot be lowered.
+    if (!llvmResultType)
+      return matchFailure();
+
     // One-shot extraction of vector from array (only requires extractvalue).
     if (resultType.isa<VectorType>()) {
       Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
@@ -260,15 +271,12 @@ public:
       return matchSuccess();
     }
 
-    // Potential extraction of 1-D vector from struct.
+    // Potential extraction of 1-D vector from array.
     auto *context = op->getContext();
     Value *extracted = adaptor.vector();
     auto positionAttrs = positionArrayAttr.getValue();
-    auto i32Type = rewriter.getIntegerType(32);
     if (positionAttrs.size() > 1) {
-      auto nDVectorType = vectorType;
-      auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(),
-                                            nDVectorType.getElementType());
+      auto oneDVectorType = reducedVectorTypeBack(vectorType);
       auto nMinusOnePositionAttrs =
           ArrayAttr::get(positionAttrs.drop_back(), context);
       extracted = rewriter.create<LLVM::ExtractValueOp>(
@@ -278,8 +286,8 @@ public:
 
     // Remaining extraction of element from 1-D LLVM vector
     auto position = positionAttrs.back().cast<IntegerAttr>();
-    auto constant = rewriter.create<LLVM::ConstantOp>(
-        loc, lowering.convertType(i32Type), position);
+    auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
+    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
     extracted =
         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
     rewriter.replaceOp(op, extracted);
@@ -288,6 +296,73 @@ public:
   }
 };
 
+class VectorInsertOpConversion : public LLVMOpLowering {
+public:
+  explicit VectorInsertOpConversion(MLIRContext *context,
+                                    LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::InsertOp::getOperationName(), context,
+                       typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto adaptor = vector::InsertOpOperandAdaptor(operands);
+    auto insertOp = cast<vector::InsertOp>(op);
+    auto sourceType = insertOp.getSourceType();
+    auto destVectorType = insertOp.getDestVectorType();
+    auto llvmResultType = lowering.convertType(destVectorType);
+    auto positionArrayAttr = insertOp.position();
+
+    // Bail if result type cannot be lowered.
+    if (!llvmResultType)
+      return matchFailure();
+
+    // One-shot insertion of a vector into an array (only requires insertvalue).
+    if (sourceType.isa<VectorType>()) {
+      Value *inserted = rewriter.create<LLVM::InsertValueOp>(
+          loc, llvmResultType, adaptor.dest(), adaptor.source(),
+          positionArrayAttr);
+      rewriter.replaceOp(op, inserted);
+      return matchSuccess();
+    }
+
+    // Potential extraction of 1-D vector from array.
+    auto *context = op->getContext();
+    Value *extracted = adaptor.dest();
+    auto positionAttrs = positionArrayAttr.getValue();
+    auto position = positionAttrs.back().cast<IntegerAttr>();
+    auto oneDVectorType = destVectorType;
+    if (positionAttrs.size() > 1) {
+      oneDVectorType = reducedVectorTypeBack(destVectorType);
+      auto nMinusOnePositionAttrs =
+          ArrayAttr::get(positionAttrs.drop_back(), context);
+      extracted = rewriter.create<LLVM::ExtractValueOp>(
+          loc, lowering.convertType(oneDVectorType), extracted,
+          nMinusOnePositionAttrs);
+    }
+
+    // Insertion of an element into a 1-D LLVM vector.
+    auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
+    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
+    Value *inserted = rewriter.create<LLVM::InsertElementOp>(
+        loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(),
+        constant);
+
+    // Potential insertion of resulting 1-D vector into array.
+    if (positionAttrs.size() > 1) {
+      auto nMinusOnePositionAttrs =
+          ArrayAttr::get(positionAttrs.drop_back(), context);
+      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
+                                                      adaptor.dest(), inserted,
+                                                      nMinusOnePositionAttrs);
+    }
+
+    rewriter.replaceOp(op, inserted);
+    return matchSuccess();
+  }
+};
+
 class VectorOuterProductOpConversion : public LLVMOpLowering {
 public:
   explicit VectorOuterProductOpConversion(MLIRContext *context,
@@ -431,8 +506,9 @@ public:
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  patterns.insert<VectorBroadcastOpConversion, VectorExtractElementOpConversion,
-                  VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
+  patterns.insert<VectorBroadcastOpConversion, VectorExtractOpConversion,
+                  VectorInsertOpConversion, VectorOuterProductOpConversion,
+                  VectorTypeCastOpConversion>(
       converter.getDialect()->getContext(), converter);
 }
 
index 0802799..28c21f6 100644 (file)
@@ -230,6 +230,15 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
 //          CHECK:   llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
 //          CHECK:   llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
 
+func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
+  %0 = vector.extract %arg0[15 : i32]: vector<16xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: extract_element_from_vec_1d
+//       CHECK:   llvm.mlir.constant(15 : i32) : !llvm.i32
+//       CHECK:   llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
+//       CHECK:   llvm.return {{.*}} : !llvm.float
+
 func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
   %0 = vector.extract %arg0[0 : i32]: vector<4x3x16xf32>
   return %0 : vector<3x16xf32>
@@ -238,6 +247,14 @@ func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32>
 //       CHECK:   llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
 //       CHECK:   llvm.return {{.*}} : !llvm<"[3 x <16 x float>]">
 
+func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf32> {
+  %0 = vector.extract %arg0[0 : i32, 0 : i32]: vector<4x3x16xf32>
+  return %0 : vector<16xf32>
+}
+// CHECK-LABEL: extract_vec_1d_from_vec_3d
+//       CHECK:   llvm.extractvalue {{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
+//       CHECK:   llvm.return {{.*}} : !llvm<"<16 x float>">
+
 func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
   %0 = vector.extract %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32>
   return %0 : f32
@@ -248,6 +265,42 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
 //       CHECK:   llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
 //       CHECK:   llvm.return {{.*}} : !llvm.float
 
+func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+// CHECK-LABEL: insert_element_into_vec_1d
+//       CHECK:   llvm.mlir.constant(3 : i32) : !llvm.i32
+//       CHECK:   llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<4 x float>">
+//       CHECK:   llvm.return {{.*}} : !llvm<"<4 x float>">
+
+func @insert_vec_2d_into_vec_3d(%arg0: vector<8x16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+  %0 = vector.insert %arg0, %arg1[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
+  return %0 : vector<4x8x16xf32>
+}
+// CHECK-LABEL: insert_vec_2d_into_vec_3d
+//       CHECK:   llvm.insertvalue {{.*}}, {{.*}}[3 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
+//       CHECK:   llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]">
+
+func @insert_vec_1d_into_vec_3d(%arg0: vector<16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+  %0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32] : vector<16xf32> into vector<4x8x16xf32>
+  return %0 : vector<4x8x16xf32>
+}
+// CHECK-LABEL: insert_vec_1d_into_vec_3d
+//       CHECK:   llvm.insertvalue {{.*}}, {{.*}}[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
+//       CHECK:   llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]">
+
+func @insert_element_into_vec_3d(%arg0: f32, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+  %0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32, 15 : i32] : f32 into vector<4x8x16xf32>
+  return %0 : vector<4x8x16xf32>
+}
+// CHECK-LABEL: insert_element_into_vec_3d
+//       CHECK:   llvm.extractvalue {{.*}}[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
+//       CHECK:   llvm.mlir.constant(15 : i32) : !llvm.i32
+//       CHECK:   llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
+//       CHECK:   llvm.insertvalue {{.*}}, {{.*}}[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
+//       CHECK:   llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]">
+
 func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
   %0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>
   return %0 : memref<vector<8x8x8xf32>>