[mlir][vector] Add vector.scalable.insert/extract ops
authorJavier Setoain <javier.setoain@gmail.com>
Wed, 15 Jun 2022 13:27:25 +0000 (14:27 +0100)
committerJavier Setoain <javier.setoain@gmail.com>
Tue, 8 Nov 2022 08:51:15 +0000 (08:51 +0000)
These new operations match the semantics of
llvm.experimental.vector.insert and llvm.experimental.vector.extract.

`vector.scalable.insert` and `vector.scalable.extract` allow,
respectively, insert vectors into scalable vectors, and extract vectors
from scalable vectors.

The discussion about the inclusion of these operations is here:
https://discourse.llvm.org/t/rfc-interfacing-between-fixed-length-and-scalable-vectors-for-vls-vector-code-on-scalable-vector-architectures

Differential Revision: https://reviews.llvm.org/D127875

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir

index e952284..758e7c1 100644 (file)
@@ -727,6 +727,114 @@ def Vector_InsertOp :
   let hasVerifier = 1;
 }
 
+def Vector_ScalableInsertOp :
+  Vector_Op<"scalable.insert", [Pure,
+       AllElementTypesMatch<["source", "dest"]>,
+       AllTypesMatch<["dest", "res"]>,
+       PredOpTrait<"position is a multiple of the source length.",
+        CPred<
+          "(getPos() % getSourceVectorType().getNumElements()) == 0"
+        >>]>,
+     Arguments<(ins VectorOfRank<[1]>:$source,
+                    ScalableVectorOfRank<[1]>:$dest,
+                    I64Attr:$pos)>,
+     Results<(outs ScalableVectorOfRank<[1]>:$res)> {
+  let summary = "insert subvector into scalable vector operation";
+  // NOTE: This operation is designed to map to `llvm.vector.insert`, and its
+  //       documentation should be kept aligned with LLVM IR:
+  //       https://llvm.org/docs/LangRef.html#llvm-vector-insert-intrinsic
+  let description = [{
+    This operations takes a rank-1 fixed-length or scalable subvector and
+    inserts it within the destination scalable vector starting from the
+    position specificed by `pos`. If the source vector is scalable, the
+    insertion position will be scaled by the runtime scaling factor of the
+    source subvector.
+
+    The insertion position must be a multiple of the minimum size of the source
+    vector. For the operation to be well defined, the source vector must fit in
+    the destination vector from the specified position. Since the destination
+    vector is scalable and its runtime length is unknown, the validity of the
+    operation can't be verified nor guaranteed at compile time.
+
+    Example:
+
+    ```mlir
+    %2 = vector.scalable.insert %0, %1[8] : vector<4xf32> into vector<[16]xf32>
+    %5 = vector.scalable.insert %3, %4[0] : vector<8xf32> into vector<[4]xf32>
+    %8 = vector.scalable.insert %6, %7[0] : vector<[4]xf32> into vector<[8]xf32>
+    ```
+
+    Invalid example:
+    ```mlir
+    %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return getSource().getType().cast<VectorType>();
+    }
+    VectorType getDestVectorType() {
+      return getDest().getType().cast<VectorType>();
+    }
+  }];
+}
+
+def Vector_ScalableExtractOp :
+  Vector_Op<"scalable.extract", [Pure,
+       AllElementTypesMatch<["source", "res"]>,
+       PredOpTrait<"position is a multiple of the result length.",
+        CPred<
+          "(getPos() % getResultVectorType().getNumElements()) == 0"
+        >>]>,
+     Arguments<(ins ScalableVectorOfRank<[1]>:$source,
+                    I64Attr:$pos)>,
+     Results<(outs VectorOfRank<[1]>:$res)> {
+  let summary = "extract subvector from scalable vector operation";
+  // NOTE: This operation is designed to map to `llvm.vector.extract`, and its
+  //       documentation should be kept aligned with LLVM IR:
+  //       https://llvm.org/docs/LangRef.html#llvm-vector-extract-intrinsic
+  let description = [{
+    Takes rank-1 source vector and a position `pos` within the source
+    vector, and extracts a subvector starting from that position.
+
+    The extraction position must be a multiple of the minimum size of the result
+    vector. For the operation to be well defined, the destination vector must
+    fit within the source vector from the specified position. Since the source
+    vector is scalable and its runtime length is unknown, the validity of the
+    operation can't be verified nor guaranteed at compile time.
+
+    Example:
+
+    ```mlir
+    %1 = vector.scalable.extract %0[8] : vector<4xf32> from vector<[8]xf32>
+    %3 = vector.scalable.extract %2[0] : vector<[4]xf32> from vector<[8]xf32>
+    ```
+
+    Invalid example:
+    ```mlir
+    %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return getSource().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return getRes().getType().cast<VectorType>();
+    }
+  }];
+}
+
 def Vector_InsertStridedSliceOp :
   Vector_Op<"insert_strided_slice", [Pure,
     PredOpTrait<"operand #0 and result have same element type",
index 27e55ad..da0e2de 100644 (file)
@@ -579,11 +579,39 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
                            == }]
                          # allowedlength>)>]>;
 
+// Whether the number of elements of a fixed-length vector is from the given
+// `allowedRanks` list
+class IsFixedVectorOfRankPred<list<int> allowedRanks> :
+  And<[IsFixedVectorTypePred,
+       Or<!foreach(allowedlength, allowedRanks,
+                   CPred<[{$_self.cast<::mlir::VectorType>().getRank()
+                           == }]
+                         # allowedlength>)>]>;
+
+// Whether the number of elements of a scalable vector is from the given
+// `allowedRanks` list
+class IsScalableVectorOfRankPred<list<int> allowedRanks> :
+  And<[IsScalableVectorTypePred,
+       Or<!foreach(allowedlength, allowedRanks,
+                   CPred<[{$_self.cast<::mlir::VectorType>().getRank()
+                           == }]
+                         # allowedlength>)>]>;
+
 // Any vector where the rank is from the given `allowedRanks` list
 class VectorOfRank<list<int> allowedRanks> : Type<
   IsVectorOfRankPred<allowedRanks>,
   " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
 
+// Any fixed-length vector where the rank is from the given `allowedRanks` list
+class FixedVectorOfRank<list<int> allowedRanks> : Type<
+  IsFixedVectorOfRankPred<allowedRanks>,
+  " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
+
+// Any scalable vector where the rank is from the given `allowedRanks` list
+class ScalableVectorOfRank<list<int> allowedRanks> : Type<
+  IsScalableVectorOfRankPred<allowedRanks>,
+  " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
+
 // Any vector where the rank is from the given `allowedRanks` list and the type
 // is from the given `allowedTypes` list
 class VectorOfRankAndType<list<int> allowedRanks,
index f4191e2..fd3cee5 100644 (file)
@@ -857,6 +857,37 @@ public:
   }
 };
 
+/// Lower vector.scalable.insert ops to LLVM vector.insert
+struct VectorScalableInsertOpLowering
+    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
+  using ConvertOpToLLVMPattern<
+      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
+        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
+    return success();
+  }
+};
+
+/// Lower vector.scalable.extract ops to LLVM vector.extract
+struct VectorScalableExtractOpLowering
+    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
+  using ConvertOpToLLVMPattern<
+      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
+        extOp, typeConverter->convertType(extOp.getResultVectorType()),
+        adaptor.getSource(), adaptor.getPos());
+    return success();
+  }
+};
+
 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
 ///
 /// Example:
@@ -1329,7 +1360,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
                                      vector::MaskedStoreOpAdaptor>,
            VectorGatherOpConversion, VectorScatterOpConversion,
            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
-           VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
+           VectorSplatOpLowering, VectorSplatNdOpLowering,
+           VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>(
+          converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
index 0a4732a..24697bc 100644 (file)
@@ -2140,3 +2140,25 @@ func.func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
 // CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0, 0, 0, 0]
 // CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[A]], %[[SPLAT]] : vector<4xf32>
 // CHECK-NEXT: return %[[SCALE]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: @vector_scalable_insert
+// CHECK-SAME: %[[SUB:.*]]: vector<4xf32>, %[[SV:.*]]: vector<[4]xf32>
+func.func @vector_scalable_insert(%sub: vector<4xf32>, %dsv: vector<[4]xf32>) -> vector<[4]xf32> {
+  // CHECK-NEXT: %[[TMP:.*]] = llvm.intr.vector.insert %[[SUB]], %[[SV]][0] : vector<4xf32> into vector<[4]xf32>
+  %0 = vector.scalable.insert %sub, %dsv[0] : vector<4xf32> into vector<[4]xf32>
+  // CHECK-NEXT: llvm.intr.vector.insert %[[SUB]], %[[TMP]][4] : vector<4xf32> into vector<[4]xf32>
+  %1 = vector.scalable.insert %sub, %0[4] : vector<4xf32> into vector<[4]xf32>
+  return %1 : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_scalable_extract
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> {
+  // CHECK-NEXT: %{{.*}} = llvm.intr.vector.extract %[[VEC]][0] : vector<8xf32> from vector<[4]xf32>
+  %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
+  return %0 : vector<8xf32>
+}
index 115cf41..3366b1d 100644 (file)
@@ -1632,3 +1632,16 @@ func.func @vector_mask_passthru_no_return(%val: vector<16xf32>, %t0: tensor<?xf3
   return
 }
 
+// -----
+
+func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
+  // expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
+  %0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
+}
+
+// -----
+
+func.func @vector_scalable_extract_unaligned(%vec: vector<[16]xf32>) {
+  // expected-error@+1 {{op failed to verify that position is a multiple of the result length.}}
+  %0 = vector.scalable.extract %vec[5] : vector<4xf32> from vector<[16]xf32>
+}
index 7a6a2c3..88b1abb 100644 (file)
@@ -853,3 +853,28 @@ func.func @vector_mask_tensor_return(%val: vector<16xf32>, %t0: tensor<?xf32>, %
   return
 }
 
+// CHECK-LABEL: func @vector_scalable_insert(
+// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
+// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>
+func.func @vector_scalable_insert(%sub0: vector<4xi32>, %sub1: vector<8xi32>,
+                                 %sub2: vector<[4]xi32>, %sv: vector<[8]xi32>) {
+  // CHECK-NEXT: vector.scalable.insert %[[SUB0]], %[[SV]][12] : vector<4xi32> into vector<[8]xi32>
+  %0 = vector.scalable.insert %sub0, %sv[12] : vector<4xi32> into vector<[8]xi32>
+  // CHECK-NEXT: vector.scalable.insert %[[SUB1]], %[[SV]][0] : vector<8xi32> into vector<[8]xi32>
+  %1 = vector.scalable.insert %sub1, %sv[0] : vector<8xi32> into vector<[8]xi32>
+  // CHECK-NEXT: vector.scalable.insert %[[SUB2]], %[[SV]][0] : vector<[4]xi32> into vector<[8]xi32>
+  %2 = vector.scalable.insert %sub2, %sv[0] : vector<[4]xi32> into vector<[8]xi32>
+  return
+ }
+
+// CHECK-LABEL: func @vector_scalable_extract(
+// CHECK-SAME: %[[SV:.*]]: vector<[8]xi32>
+func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
+  // CHECK-NEXT: vector.scalable.extract %[[SV]][0] : vector<16xi32> from vector<[8]xi32>
+  %0 = vector.scalable.extract %sv[0] : vector<16xi32> from vector<[8]xi32>
+  // CHECK-NEXT: vector.scalable.extract %[[SV]][0] : vector<[4]xi32> from vector<[8]xi32>
+  %1 = vector.scalable.extract %sv[0] : vector<[4]xi32> from vector<[8]xi32>
+  // CHECK-NEXT: vector.scalable.extract %[[SV]][4] : vector<4xi32> from vector<[8]xi32>
+  %2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
+  return
+ }