let hasVerifier = 1;
}
+def Vector_ScalableInsertOp :
+ Vector_Op<"scalable.insert", [Pure,
+ AllElementTypesMatch<["source", "dest"]>,
+ AllTypesMatch<["dest", "res"]>,
+ PredOpTrait<"position is a multiple of the source length.",
+ CPred<
+ "(getPos() % getSourceVectorType().getNumElements()) == 0"
+ >>]>,
+ Arguments<(ins VectorOfRank<[1]>:$source,
+ ScalableVectorOfRank<[1]>:$dest,
+ I64Attr:$pos)>,
+ Results<(outs ScalableVectorOfRank<[1]>:$res)> {
+ let summary = "insert subvector into scalable vector operation";
+ // NOTE: This operation is designed to map to `llvm.vector.insert`, and its
+ // documentation should be kept aligned with LLVM IR:
+ // https://llvm.org/docs/LangRef.html#llvm-vector-insert-intrinsic
+ let description = [{
+ This operations takes a rank-1 fixed-length or scalable subvector and
+ inserts it within the destination scalable vector starting from the
+ position specificed by `pos`. If the source vector is scalable, the
+ insertion position will be scaled by the runtime scaling factor of the
+ source subvector.
+
+ The insertion position must be a multiple of the minimum size of the source
+ vector. For the operation to be well defined, the source vector must fit in
+ the destination vector from the specified position. Since the destination
+ vector is scalable and its runtime length is unknown, the validity of the
+ operation can't be verified nor guaranteed at compile time.
+
+ Example:
+
+ ```mlir
+ %2 = vector.scalable.insert %0, %1[8] : vector<4xf32> into vector<[16]xf32>
+ %5 = vector.scalable.insert %3, %4[0] : vector<8xf32> into vector<[4]xf32>
+ %8 = vector.scalable.insert %6, %7[0] : vector<[4]xf32> into vector<[8]xf32>
+ ```
+
+ Invalid example:
+ ```mlir
+ %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
+ ```
+ }];
+
+ let assemblyFormat = [{
+ $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getSourceVectorType() {
+ return getSource().getType().cast<VectorType>();
+ }
+ VectorType getDestVectorType() {
+ return getDest().getType().cast<VectorType>();
+ }
+ }];
+}
+
+def Vector_ScalableExtractOp :
+ Vector_Op<"scalable.extract", [Pure,
+ AllElementTypesMatch<["source", "res"]>,
+ PredOpTrait<"position is a multiple of the result length.",
+ CPred<
+ "(getPos() % getResultVectorType().getNumElements()) == 0"
+ >>]>,
+ Arguments<(ins ScalableVectorOfRank<[1]>:$source,
+ I64Attr:$pos)>,
+ Results<(outs VectorOfRank<[1]>:$res)> {
+ let summary = "extract subvector from scalable vector operation";
+ // NOTE: This operation is designed to map to `llvm.vector.extract`, and its
+ // documentation should be kept aligned with LLVM IR:
+ // https://llvm.org/docs/LangRef.html#llvm-vector-extract-intrinsic
+ let description = [{
+ Takes rank-1 source vector and a position `pos` within the source
+ vector, and extracts a subvector starting from that position.
+
+ The extraction position must be a multiple of the minimum size of the result
+ vector. For the operation to be well defined, the destination vector must
+ fit within the source vector from the specified position. Since the source
+ vector is scalable and its runtime length is unknown, the validity of the
+ operation can't be verified nor guaranteed at compile time.
+
+ Example:
+
+ ```mlir
+ %1 = vector.scalable.extract %0[8] : vector<4xf32> from vector<[8]xf32>
+ %3 = vector.scalable.extract %2[0] : vector<[4]xf32> from vector<[8]xf32>
+ ```
+
+ Invalid example:
+ ```mlir
+ %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
+ ```
+ }];
+
+ let assemblyFormat = [{
+ $source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getSourceVectorType() {
+ return getSource().getType().cast<VectorType>();
+ }
+ VectorType getResultVectorType() {
+ return getRes().getType().cast<VectorType>();
+ }
+ }];
+}
+
def Vector_InsertStridedSliceOp :
Vector_Op<"insert_strided_slice", [Pure,
PredOpTrait<"operand #0 and result have same element type",
== }]
# allowedlength>)>]>;
+// Whether the number of elements of a fixed-length vector is from the given
+// `allowedRanks` list
+class IsFixedVectorOfRankPred<list<int> allowedRanks> :
+ And<[IsFixedVectorTypePred,
+ Or<!foreach(allowedlength, allowedRanks,
+ CPred<[{$_self.cast<::mlir::VectorType>().getRank()
+ == }]
+ # allowedlength>)>]>;
+
+// Whether the number of elements of a scalable vector is from the given
+// `allowedRanks` list
+class IsScalableVectorOfRankPred<list<int> allowedRanks> :
+ And<[IsScalableVectorTypePred,
+ Or<!foreach(allowedlength, allowedRanks,
+ CPred<[{$_self.cast<::mlir::VectorType>().getRank()
+ == }]
+ # allowedlength>)>]>;
+
// Any vector where the rank is from the given `allowedRanks` list
class VectorOfRank<list<int> allowedRanks> : Type<
IsVectorOfRankPred<allowedRanks>,
" of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
+// Any fixed-length vector where the rank is from the given `allowedRanks` list
+class FixedVectorOfRank<list<int> allowedRanks> : Type<
+ IsFixedVectorOfRankPred<allowedRanks>,
+ " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
+
+// Any scalable vector where the rank is from the given `allowedRanks` list
+class ScalableVectorOfRank<list<int> allowedRanks> : Type<
+ IsScalableVectorOfRankPred<allowedRanks>,
+ " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
+
// Any vector where the rank is from the given `allowedRanks` list and the type
// is from the given `allowedTypes` list
class VectorOfRankAndType<list<int> allowedRanks,
}
};
+/// Lower vector.scalable.insert ops to LLVM vector.insert
+struct VectorScalableInsertOpLowering
+ : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
+ using ConvertOpToLLVMPattern<
+ vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
+ insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
+ return success();
+ }
+};
+
+/// Lower vector.scalable.extract ops to LLVM vector.extract
+struct VectorScalableExtractOpLowering
+ : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
+ using ConvertOpToLLVMPattern<
+ vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
+ extOp, typeConverter->convertType(extOp.getResultVectorType()),
+ adaptor.getSource(), adaptor.getPos());
+ return success();
+ }
+};
+
/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
vector::MaskedStoreOpAdaptor>,
VectorGatherOpConversion, VectorScatterOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
- VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
+ VectorSplatOpLowering, VectorSplatNdOpLowering,
+ VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>(
+ converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0, 0, 0, 0]
// CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[A]], %[[SPLAT]] : vector<4xf32>
// CHECK-NEXT: return %[[SCALE]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: @vector_scalable_insert
+// CHECK-SAME: %[[SUB:.*]]: vector<4xf32>, %[[SV:.*]]: vector<[4]xf32>
+func.func @vector_scalable_insert(%sub: vector<4xf32>, %dsv: vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK-NEXT: %[[TMP:.*]] = llvm.intr.vector.insert %[[SUB]], %[[SV]][0] : vector<4xf32> into vector<[4]xf32>
+ %0 = vector.scalable.insert %sub, %dsv[0] : vector<4xf32> into vector<[4]xf32>
+ // CHECK-NEXT: llvm.intr.vector.insert %[[SUB]], %[[TMP]][4] : vector<4xf32> into vector<[4]xf32>
+ %1 = vector.scalable.insert %sub, %0[4] : vector<4xf32> into vector<[4]xf32>
+ return %1 : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_scalable_extract
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> {
+ // CHECK-NEXT: %{{.*}} = llvm.intr.vector.extract %[[VEC]][0] : vector<8xf32> from vector<[4]xf32>
+ %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
+ return %0 : vector<8xf32>
+}
return
}
+// -----
+
+func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
+ // expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
+ %0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
+}
+
+// -----
+
+func.func @vector_scalable_extract_unaligned(%vec: vector<[16]xf32>) {
+ // expected-error@+1 {{op failed to verify that position is a multiple of the result length.}}
+ %0 = vector.scalable.extract %vec[5] : vector<4xf32> from vector<[16]xf32>
+}
return
}
+// CHECK-LABEL: func @vector_scalable_insert(
+// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
+// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>
+func.func @vector_scalable_insert(%sub0: vector<4xi32>, %sub1: vector<8xi32>,
+ %sub2: vector<[4]xi32>, %sv: vector<[8]xi32>) {
+ // CHECK-NEXT: vector.scalable.insert %[[SUB0]], %[[SV]][12] : vector<4xi32> into vector<[8]xi32>
+ %0 = vector.scalable.insert %sub0, %sv[12] : vector<4xi32> into vector<[8]xi32>
+ // CHECK-NEXT: vector.scalable.insert %[[SUB1]], %[[SV]][0] : vector<8xi32> into vector<[8]xi32>
+ %1 = vector.scalable.insert %sub1, %sv[0] : vector<8xi32> into vector<[8]xi32>
+ // CHECK-NEXT: vector.scalable.insert %[[SUB2]], %[[SV]][0] : vector<[4]xi32> into vector<[8]xi32>
+ %2 = vector.scalable.insert %sub2, %sv[0] : vector<[4]xi32> into vector<[8]xi32>
+ return
+ }
+
+// CHECK-LABEL: func @vector_scalable_extract(
+// CHECK-SAME: %[[SV:.*]]: vector<[8]xi32>
+func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
+ // CHECK-NEXT: vector.scalable.extract %[[SV]][0] : vector<16xi32> from vector<[8]xi32>
+ %0 = vector.scalable.extract %sv[0] : vector<16xi32> from vector<[8]xi32>
+ // CHECK-NEXT: vector.scalable.extract %[[SV]][0] : vector<[4]xi32> from vector<[8]xi32>
+ %1 = vector.scalable.extract %sv[0] : vector<[4]xi32> from vector<[8]xi32>
+ // CHECK-NEXT: vector.scalable.extract %[[SV]][4] : vector<4xi32> from vector<[8]xi32>
+ %2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
+ return
+ }