From: Andy Davis Date: Mon, 16 Dec 2019 20:56:06 +0000 (-0800) Subject: Add InsertSlicesOp to the VectorOps dialect. X-Git-Tag: llvmorg-11-init~1466^2~61 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=11e92875f07261c64205c8b72038abf0d65729a0;p=platform%2Fupstream%2Fllvm.git Add InsertSlicesOp to the VectorOps dialect. PiperOrigin-RevId: 285830394 --- diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index eb05821..50bf581 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -352,7 +352,7 @@ def Vector_ExtractSlicesOp : %3 = vector.extract_slices %2, [2, 2], [1, 1] : vector<4x3xf32> into tuple, vector<2x1xf32>, - vector<2x2xf32>, vector<2x2xf32>> + vector<2x2xf32>, vector<2x1xf32>> ``` }]; let builders = [OpBuilder< @@ -439,6 +439,58 @@ def Vector_InsertOp : }]; } +def Vector_InsertSlicesOp : + Vector_Op<"insert_slices", [NoSideEffect]>, + Arguments<(ins TupleOf<[AnyVector]>:$vectors, I64ArrayAttr:$sizes, + I64ArrayAttr:$strides)>, + Results<(outs AnyVector)> { + let summary = "vector insert slices operation"; + let description = [{ + Takes a tuple of vector slices and inserts them into the vector result + according to the 'sizes' and 'strides' parameters. + + The arguments 'sizes' and 'strides' represent a specification for + generating the unrolling of 'vector' shape, which has all slices of shape + 'sizes' except for slices at dimension boundaries when 'vector' dimension + sizes are not a multiple of 'sizes'. + + Each slice in 'vectors' is at the tuple element index corresponding to the + linear index of the slice w.r.t the unrolling scheme represented by 'sizes'. + Currently, only unit strides are supported. + + Examples: + ``` + %0 = vector.extract_slices %0, [2, 2], [1, 1] + : vector<4x2xf32> into tuple, vector<2x2xf32>> + + %1 = vector.insert_slices %0, [2, 2], [1, 1] + : tuple, vector<2x2xf32>> into vector<4x2xf32> + + // Example with partial slices at dimension boundaries. + %3 = vector.extract_slices %2, [2, 2], [1, 1] + : vector<4x3xf32> into tuple, vector<2x1xf32>, + vector<2x2xf32>, vector<2x1xf32>> + + %4 = vector.insert_slices %3, [2, 2], [1, 1] + : tuple, vector<2x1xf32>, + vector<2x2xf32>, vector<2x1xf32>> into vector<4x3xf32> + ``` + }]; + + let extraClassDeclaration = [{ + TupleType getSourceTupleType() { + return vectors()->getType().cast(); + } + VectorType getResultVectorType() { + return getResult()->getType().cast(); + } + void getSizes(SmallVectorImpl &results); + void getStrides(SmallVectorImpl &results); + static StringRef getSizesAttrName() { return "sizes"; } + static StringRef getStridesAttrName() { return "strides"; } + }]; +} + def Vector_InsertStridedSliceOp : Vector_Op<"insert_strided_slice", [NoSideEffect, PredOpTrait<"operand #0 and result have same element type", diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index fc8abd7..48fc0d4 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -826,6 +826,60 @@ static LogicalResult verify(InsertOp op) { } //===----------------------------------------------------------------------===// +// InsertSlicesOp +//===----------------------------------------------------------------------===// + +static ParseResult parseInsertSlicesOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType operandInfo; + ArrayAttr sizesAttr; + StringRef sizesAttrName = InsertSlicesOp::getSizesAttrName(); + ArrayAttr stridesAttr; + StringRef stridesAttrName = InsertSlicesOp::getStridesAttrName(); + TupleType tupleType; + VectorType resultVectorType; + return failure( + parser.parseOperand(operandInfo) || parser.parseComma() || + parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) || + parser.parseComma() || + parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(tupleType) || + parser.parseKeywordType("into", resultVectorType) || + parser.resolveOperand(operandInfo, tupleType, result.operands) || + parser.addTypeToList(resultVectorType, result.types)); +} + +static void print(OpAsmPrinter &p, InsertSlicesOp op) { + p << op.getOperationName() << ' ' << *op.vectors() << ", "; + p << op.sizes() << ", " << op.strides(); + p.printOptionalAttrDict( + op.getAttrs(), + /*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(), + InsertSlicesOp::getStridesAttrName()}); + p << " : " << op.vectors()->getType(); + p << " into " << op.getResultVectorType(); +} + +static LogicalResult verify(InsertSlicesOp op) { + SmallVector sizes; + op.getSizes(sizes); + SmallVector strides; + op.getStrides(strides); + return isValidExtractOrInsertSlicesType( + op.getOperation(), op.getResultVectorType(), op.getSourceTupleType(), + sizes, strides); +} + +void InsertSlicesOp::getSizes(SmallVectorImpl &results) { + populateFromInt64AttrArray(sizes(), results); +} + +void InsertSlicesOp::getStrides(SmallVectorImpl &results) { + populateFromInt64AttrArray(strides(), results); +} + +//===----------------------------------------------------------------------===// // InsertStridedSliceOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index c04c8ea..3c2dd60 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -783,4 +783,38 @@ func @tuple_get_of_non_vectors(%arg0 : tuple, index>) { return } +// ----- + +func @insert_slices_non_unit_strides(%arg0 : tuple, vector<2x2xf32>>) { + // expected-error@+1 {{requires unit strides}} + %0 = vector.insert_slices %arg0, [2, 2], [1, 3] + : tuple, vector<2x2xf32>> into vector<4x2xf32> + return +} + +// ----- + +func @insert_slices_tuple_element_wrong_rank(%arg0 : tuple, vector<2x2x3xf32>>) { + // expected-error@+1 {{requires vector tuple elements of rank 2}} + %0 = vector.insert_slices %arg0, [2, 2], [1, 1] + : tuple, vector<2x2x3xf32>> into vector<4x2xf32> + return +} +// ----- + +func @insert_slices_sizes_strides_wrong_rank(%arg0 : tuple, vector<2x2xf32>>) { + // expected-error@+1 {{requires sizes and strides of rank}} + %0 = vector.insert_slices %arg0, [2, 2], [1, 1, 1] + : tuple, vector<2x2xf32>> into vector<4x2xf32> + return +} + +// ----- + +func @insert_slices_invalid_tuple_element_type(%arg0 : tuple, vector<4x2xf32>>) { + // expected-error@+1 {{invalid tuple element type}} + %0 = vector.insert_slices %arg0, [2, 2], [1, 1] + : tuple, vector<4x2xf32>> into vector<4x2xf32> + return +} diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index 69af80f..f1db45e 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -189,3 +189,12 @@ func @extract_slices(%arg0 : vector<4x2xf32>) %3 = vector.tuple %1, %2 : vector<2x2xf32>, vector<2x2xf32> return %3 : tuple, vector<2x2xf32>> } + +// CHECK-LABEL: insert_slices +func @insert_slices(%arg0 : tuple, vector<2x2xf32>>) + -> (vector<4x2xf32>) { + // CHECK: vector.insert_slices %{{.*}}, [2, 2], [1, 1] : tuple, vector<2x2xf32>> into vector<4x2xf32> + %0 = vector.insert_slices %arg0, [2, 2], [1, 1] + : tuple, vector<2x2xf32>> into vector<4x2xf32> + return %0 : vector<4x2xf32> +}