%3 = vector.extract_slices %2, [2, 2], [1, 1]
: vector<4x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>,
- vector<2x2xf32>, vector<2x2xf32>>
+ vector<2x2xf32>, vector<2x1xf32>>
```
}];
let builders = [OpBuilder<
}];
}
+def Vector_InsertSlicesOp :
+ Vector_Op<"insert_slices", [NoSideEffect]>,
+ Arguments<(ins TupleOf<[AnyVector]>:$vectors, I64ArrayAttr:$sizes,
+ I64ArrayAttr:$strides)>,
+ Results<(outs AnyVector)> {
+ let summary = "vector insert slices operation";
+ let description = [{
+ Takes a tuple of vector slices and inserts them into the vector result
+ according to the 'sizes' and 'strides' parameters.
+
+ The arguments 'sizes' and 'strides' represent a specification for
+ generating the unrolling of 'vector' shape, which has all slices of shape
+ 'sizes' except for slices at dimension boundaries when 'vector' dimension
+ sizes are not a multiple of 'sizes'.
+
+ Each slice in 'vectors' is at the tuple element index corresponding to the
+ linear index of the slice w.r.t the unrolling scheme represented by 'sizes'.
+ Currently, only unit strides are supported.
+
+ Examples:
+ ```
+ %0 = vector.extract_slices %0, [2, 2], [1, 1]
+ : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
+
+ %1 = vector.insert_slices %0, [2, 2], [1, 1]
+ : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
+
+ // Example with partial slices at dimension boundaries.
+ %3 = vector.extract_slices %2, [2, 2], [1, 1]
+ : vector<4x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>,
+ vector<2x2xf32>, vector<2x1xf32>>
+
+ %4 = vector.insert_slices %3, [2, 2], [1, 1]
+ : tuple<vector<2x2xf32>, vector<2x1xf32>,
+ vector<2x2xf32>, vector<2x1xf32>> into vector<4x3xf32>
+ ```
+ }];
+
+ let extraClassDeclaration = [{
+ TupleType getSourceTupleType() {
+ return vectors()->getType().cast<TupleType>();
+ }
+ VectorType getResultVectorType() {
+ return getResult()->getType().cast<VectorType>();
+ }
+ void getSizes(SmallVectorImpl<int64_t> &results);
+ void getStrides(SmallVectorImpl<int64_t> &results);
+ static StringRef getSizesAttrName() { return "sizes"; }
+ static StringRef getStridesAttrName() { return "strides"; }
+ }];
+}
+
def Vector_InsertStridedSliceOp :
Vector_Op<"insert_strided_slice", [NoSideEffect,
PredOpTrait<"operand #0 and result have same element type",
}
//===----------------------------------------------------------------------===//
+// InsertSlicesOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseInsertSlicesOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType operandInfo;
+ ArrayAttr sizesAttr;
+ StringRef sizesAttrName = InsertSlicesOp::getSizesAttrName();
+ ArrayAttr stridesAttr;
+ StringRef stridesAttrName = InsertSlicesOp::getStridesAttrName();
+ TupleType tupleType;
+ VectorType resultVectorType;
+ return failure(
+ parser.parseOperand(operandInfo) || parser.parseComma() ||
+ parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
+ parser.parseComma() ||
+ parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(tupleType) ||
+ parser.parseKeywordType("into", resultVectorType) ||
+ parser.resolveOperand(operandInfo, tupleType, result.operands) ||
+ parser.addTypeToList(resultVectorType, result.types));
+}
+
+static void print(OpAsmPrinter &p, InsertSlicesOp op) {
+ p << op.getOperationName() << ' ' << *op.vectors() << ", ";
+ p << op.sizes() << ", " << op.strides();
+ p.printOptionalAttrDict(
+ op.getAttrs(),
+ /*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(),
+ InsertSlicesOp::getStridesAttrName()});
+ p << " : " << op.vectors()->getType();
+ p << " into " << op.getResultVectorType();
+}
+
+static LogicalResult verify(InsertSlicesOp op) {
+ SmallVector<int64_t, 4> sizes;
+ op.getSizes(sizes);
+ SmallVector<int64_t, 4> strides;
+ op.getStrides(strides);
+ return isValidExtractOrInsertSlicesType(
+ op.getOperation(), op.getResultVectorType(), op.getSourceTupleType(),
+ sizes, strides);
+}
+
+void InsertSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
+ populateFromInt64AttrArray(sizes(), results);
+}
+
+void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
+ populateFromInt64AttrArray(strides(), results);
+}
+
+//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//
return
}
+// -----
+
+func @insert_slices_non_unit_strides(%arg0 : tuple<vector<2x2xf32>, vector<2x2xf32>>) {
+ // expected-error@+1 {{requires unit strides}}
+ %0 = vector.insert_slices %arg0, [2, 2], [1, 3]
+ : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
+ return
+}
+
+// -----
+
+func @insert_slices_tuple_element_wrong_rank(%arg0 : tuple<vector<2x2xf32>, vector<2x2x3xf32>>) {
+ // expected-error@+1 {{requires vector tuple elements of rank 2}}
+ %0 = vector.insert_slices %arg0, [2, 2], [1, 1]
+ : tuple<vector<2x2xf32>, vector<2x2x3xf32>> into vector<4x2xf32>
+ return
+}
+// -----
+
+func @insert_slices_sizes_strides_wrong_rank(%arg0 : tuple<vector<2x2xf32>, vector<2x2xf32>>) {
+ // expected-error@+1 {{requires sizes and strides of rank}}
+ %0 = vector.insert_slices %arg0, [2, 2], [1, 1, 1]
+ : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
+ return
+}
+
+// -----
+
+func @insert_slices_invalid_tuple_element_type(%arg0 : tuple<vector<2x2xf32>, vector<4x2xf32>>) {
+ // expected-error@+1 {{invalid tuple element type}}
+ %0 = vector.insert_slices %arg0, [2, 2], [1, 1]
+ : tuple<vector<2x2xf32>, vector<4x2xf32>> into vector<4x2xf32>
+ return
+}
%3 = vector.tuple %1, %2 : vector<2x2xf32>, vector<2x2xf32>
return %3 : tuple<vector<2x2xf32>, vector<2x2xf32>>
}
+
+// CHECK-LABEL: insert_slices
+func @insert_slices(%arg0 : tuple<vector<2x2xf32>, vector<2x2xf32>>)
+ -> (vector<4x2xf32>) {
+ // CHECK: vector.insert_slices %{{.*}}, [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
+ %0 = vector.insert_slices %arg0, [2, 2], [1, 1]
+ : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
+ return %0 : vector<4x2xf32>
+}