From 1fe65688d42d1dacca528a871ac8de370043f793 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 9 Dec 2019 16:15:02 -0800 Subject: [PATCH] [VectorOps] Add a ShuffleOp to the VectorOps dialect For example %0 = vector.shuffle %x, %y [3 : i32, 2 : i32, 1 : i32, 0 : i32] : vector<2xf32>, vector<2xf32> yields a vector<4xf32> result with a permutation of the elements of %x and %y PiperOrigin-RevId: 284657191 --- mlir/include/mlir/Dialect/VectorOps/VectorOps.td | 53 +++++++++++++ mlir/lib/Dialect/VectorOps/VectorOps.cpp | 86 ++++++++++++++++++++++ .../Conversion/VectorToLLVM/vector-to-llvm.mlir | 10 +-- mlir/test/Dialect/VectorOps/invalid.mlir | 35 +++++++++ mlir/test/Dialect/VectorOps/ops.mlir | 46 ++++++++---- 5 files changed, 211 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index df0e5e9..1e84010 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -214,6 +214,59 @@ def Vector_BroadcastOp : }]; } +def Vector_ShuffleOp : + Vector_Op<"shuffle", [NoSideEffect, + PredOpTrait<"first operand v1 and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"second operand v2 and result have same element type", + TCresVTEtIsSameAsOpBase<0, 1>>]>, + Arguments<(ins AnyVector:$v1, AnyVector:$v2, I32ArrayAttr:$mask)>, + Results<(outs AnyVector:$vector)> { + let summary = "shuffle operation"; + let description = [{ + The shuffle operation constructs a permutation (or duplication) of elements + from two input vectors, returning a vector with the same element type as + the input and a length that is the same as the shuffle mask. The two input + vectors must have the same element type, rank, and trailing dimension sizes + and shuffles their values in the leading dimension (which may differ in size) + according to the given mask. The legality rules are: + * the two operands must have the same element type as the result + * the two operands and the result must have the same rank and trailing + dimension sizes, viz. given two k-D operands + v1 : and + v2 : + we have s_i = t_i for all 1 < i <= k + * the mask length equals the leading dimension size of the result + * numbering the input vector indices left to right accross the operands, all + mask values must be within range, viz. given two k-D operands v1 and v2 + above, all mask values are in the range [0,s_1+t_1) + + Examples: + ``` + %0 = vector.shuffle %a, %b[0:i32, 3:i32] + : vector<2xf32>, vector<2xf32> ; yields vector<2xf32> + %1 = vector.shuffle %c, %b[0:i32, 1:i32, 2:i32] + : vector<2x16xf32>, vector<1x16xf32> ; yields vector<3x16xf32> + %2 = vector.shuffle %a, %b[3:i32, 2:i32, 1:i32 : 0:i32] + : vector<2xf32>, vector<2xf32> ; yields vector<4xf32> + + ``` + }]; + let builders = [OpBuilder<"Builder *builder, OperationState &result, Value *v1, Value *v2, ArrayRef">]; + let extraClassDeclaration = [{ + static StringRef getMaskAttrName() { return "mask"; } + VectorType getV1VectorType() { + return v1()->getType().cast(); + } + VectorType getV2VectorType() { + return v2()->getType().cast(); + } + VectorType getVectorType() { + return vector()->getType().cast(); + } + }]; +} + def Vector_ExtractOp : Vector_Op<"extract", [NoSideEffect, PredOpTrait<"operand and result have same element type", diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 08bb8f4..7714623 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -459,6 +459,92 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser, } //===----------------------------------------------------------------------===// +// ShuffleOp +//===----------------------------------------------------------------------===// + +void ShuffleOp::build(Builder *builder, OperationState &result, Value *v1, + Value *v2, ArrayRef mask) { + result.addOperands({v1, v2}); + auto maskAttr = builder->getI32ArrayAttr(mask); + result.addTypes(v1->getType()); + result.addAttribute(getMaskAttrName(), maskAttr); +} + +static void print(OpAsmPrinter &p, ShuffleOp op) { + p << op.getOperationName() << " " << *op.v1() << ", " << *op.v2() << " " + << op.mask(); + p.printOptionalAttrDict(op.getAttrs(), {ShuffleOp::getMaskAttrName()}); + p << " : " << op.v1()->getType() << ", " << op.v2()->getType(); +} + +static LogicalResult verify(ShuffleOp op) { + VectorType resultType = op.getVectorType(); + VectorType v1Type = op.getV1VectorType(); + VectorType v2Type = op.getV2VectorType(); + // Verify ranks. + int64_t resRank = resultType.getRank(); + int64_t v1Rank = v1Type.getRank(); + int64_t v2Rank = v2Type.getRank(); + if (resRank != v1Rank || v1Rank != v2Rank) + return op.emitOpError("rank mismatch"); + // Verify all but leading dimension sizes. + for (int64_t r = 1; r < v1Rank; ++r) { + int64_t resDim = resultType.getDimSize(r); + int64_t v1Dim = v1Type.getDimSize(r); + int64_t v2Dim = v2Type.getDimSize(r); + if (resDim != v1Dim || v1Dim != v2Dim) + return op.emitOpError("dimension mismatch"); + } + // Verify mask length. + auto maskAttr = op.mask().getValue(); + int64_t maskLength = maskAttr.size(); + if (maskLength != resultType.getDimSize(0)) + return op.emitOpError("mask length mismatch"); + // Verify all indices. + int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0); + for (auto en : llvm::enumerate(maskAttr)) { + auto attr = en.value().dyn_cast(); + if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize) + return op.emitOpError("mask index #") + << (en.index() + 1) << " out of range"; + } + return success(); +} + +static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType v1, v2; + Attribute attr; + VectorType v1Type, v2Type; + if (parser.parseOperand(v1) || parser.parseComma() || + parser.parseOperand(v2) || + parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(), + result.attributes) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(v1Type) || parser.parseComma() || + parser.parseType(v2Type) || + parser.resolveOperand(v1, v1Type, result.operands) || + parser.resolveOperand(v2, v2Type, result.operands)) + return failure(); + // Construct resulting type: leading dimension matches mask length, + // all trailing dimensions match the operands. + auto maskAttr = attr.dyn_cast(); + if (!maskAttr) + return parser.emitError(parser.getNameLoc(), "missing mask attribute"); + int64_t maskLength = maskAttr.size(); + if (maskLength <= 0) + return parser.emitError(parser.getNameLoc(), "invalid mask length"); + int64_t v1Rank = v1Type.getRank(); + SmallVector shape; + shape.reserve(v1Rank); + shape.push_back(maskLength); + for (int64_t r = 1; r < v1Rank; ++r) + shape.push_back(v1Type.getDimSize(r)); + VectorType resType = VectorType::get(shape, v1Type.getElementType()); + parser.addTypeToList(resType, result.types); + return success(); +} + +//===----------------------------------------------------------------------===// // InsertOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 8f66b44..0802799 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -235,18 +235,18 @@ func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> return %0 : vector<3x16xf32> } // CHECK-LABEL: extract_vec_2d_from_vec_3d -// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]"> -// CHECK: llvm.return %{{.*}} : !llvm<"[3 x <16 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]"> +// CHECK: llvm.return {{.*}} : !llvm<"[3 x <16 x float>]"> func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 { %0 = vector.extract %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32> return %0 : f32 } // CHECK-LABEL: extract_element_from_vec_3d -// CHECK: llvm.extractvalue %{{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]"> +// CHECK: llvm.extractvalue {{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]"> // CHECK: llvm.mlir.constant(0 : i32) : !llvm.i32 -// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i32] : !llvm<"<16 x float>"> -// CHECK: llvm.return %{{.*}} : !llvm.float +// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>"> +// CHECK: llvm.return {{.*}} : !llvm.float func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref> { %0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref> diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index d6aa291..4f56e94 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -31,6 +31,41 @@ func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) { // ----- +func @shuffle_elt_type_mismatch(%arg0: vector<2xf32>, %arg1: vector<2xi32>) { + // expected-error@+1 {{'vector.shuffle' op failed to verify that second operand v2 and result have same element type}} + %1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xi32> +} + +// ----- + +func @shuffle_rank_mismatch(%arg0: vector<2xf32>, %arg1: vector<4x2xf32>) { + // expected-error@+1 {{'vector.shuffle' op rank mismatch}} + %1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<4x2xf32> +} + +// ----- + +func @shuffle_trailing_dim_size_mismatch(%arg0: vector<2x2xf32>, %arg1: vector<2x4xf32>) { + // expected-error@+1 {{'vector.shuffle' op dimension mismatch}} + %1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2x2xf32>, vector<2x4xf32> +} + +// ----- + +func @shuffle_index_out_of_range(%arg0: vector<2xf32>, %arg1: vector<2xf32>) { + // expected-error@+1 {{'vector.shuffle' op mask index #2 out of range}} + %1 = vector.shuffle %arg0, %arg1 [0 : i32, 4 : i32] : vector<2xf32>, vector<2xf32> +} + +// ----- + +func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) { + // expected-error@+1 {{custom op 'vector.shuffle' invalid mask length}} + %1 = vector.shuffle %arg0, %arg1 [] : vector<2xf32>, vector<2xf32> +} + +// ----- + func @extract_vector_type(%arg0: index) { // expected-error@+1 {{expected vector type}} %1 = vector.extract %arg0[] : index diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index b98e749..a5bafb44 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -24,20 +24,38 @@ func @vector_transfer_ops(%arg0: memref) { // CHECK-LABEL: @vector_broadcast func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> { - // CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32> + // CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32> %0 = vector.broadcast %a : f32 to vector<16xf32> - // CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32> + // CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32> %1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32> - // CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32> + // CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32> %2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32> - // CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32> + // CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32> %3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32> return %3 : vector<8x16xf32> } +// CHECK-LABEL: @shuffle1D +func @shuffle1D(%a: vector<2xf32>, %b: vector<4xf32>) -> vector<2xf32> { + // CHECK: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 1 : i32, 2 : i32, 3 : i32] : vector<2xf32>, vector<2xf32> + %1 = vector.shuffle %a, %a[0 : i32, 1 : i32, 2: i32, 3 : i32] : vector<2xf32>, vector<2xf32> + // CHECK-NEXT: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 1 : i32, 2 : i32] : vector<4xf32>, vector<4xf32> + %2 = vector.shuffle %1, %b[0 : i32, 1 : i32, 2 : i32] : vector<4xf32>, vector<4xf32> + // CHECK-NEXT: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 6 : i32] : vector<3xf32>, vector<4xf32> + %3 = vector.shuffle %2, %b[0 : i32, 6 : i32] : vector<3xf32>, vector<4xf32> + return %3 : vector<2xf32> +} + +// CHECK-LABEL: @shuffle2D +func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> { + // CHECK: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 1 : i32, 2 : i32] : vector<1x4xf32>, vector<2x4xf32> + %1 = vector.shuffle %a, %b[0 : i32, 1 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32> + return %1 : vector<3x4xf32> +} + // CHECK-LABEL: @extract func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) { - // CHECK: vector.extract {{.*}}[3 : i32] : vector<4x8x16xf32> + // CHECK: vector.extract {{.*}}[3 : i32] : vector<4x8x16xf32> %1 = vector.extract %arg0[3 : i32] : vector<4x8x16xf32> // CHECK-NEXT: vector.extract {{.*}}[3 : i32, 3 : i32] : vector<4x8x16xf32> %2 = vector.extract %arg0[3 : i32, 3 : i32] : vector<4x8x16xf32> @@ -47,35 +65,35 @@ func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f } // CHECK-LABEL: @insert -func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) { - // CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32> +func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { + // CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32> %1 = vector.insert %c, %res[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32> - // CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32> %2 = vector.insert %b, %res[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32> - // CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32> + // CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32> %3 = vector.insert %a, %res[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32> - return + return %3 : vector<4x8x16xf32> } // CHECK-LABEL: @outerproduct func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> { - // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32> + // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32> %0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32> - // CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32> + // CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32> %1 = vector.outerproduct %arg0, %arg1, %arg2 : vector<4xf32>, vector<8xf32> return %1 : vector<4x8xf32> } // CHECK-LABEL: @insert_strided_slice func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) { - // CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32> + // CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32> %1 = vector.insert_strided_slice %a, %b {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32> return } // CHECK-LABEL: @strided_slice func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> { - // CHECK: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> + // CHECK: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> %1 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32> return %1: vector<2x2x16xf32> } -- 2.7.4