From: Benjamin Kramer Date: Fri, 18 Feb 2022 00:35:25 +0000 (+0100) Subject: [mlir][Vector] Switch ShuffleOp to the declarative assembly format X-Git-Tag: upstream/15.0.7~15983 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f0dd818be389e193cbe5817996e434696f37a702;p=platform%2Fupstream%2Fllvm.git [mlir][Vector] Switch ShuffleOp to the declarative assembly format This also requires implementing return type deduction. --- diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index c9370a0..1e16dbb 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -447,7 +447,8 @@ def Vector_ShuffleOp : PredOpTrait<"first operand v1 and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, PredOpTrait<"second operand v2 and result have same element type", - TCresVTEtIsSameAsOpBase<0, 1>>]>, + TCresVTEtIsSameAsOpBase<0, 1>>, + DeclareOpInterfaceMethods]>, Arguments<(ins AnyVector:$v1, AnyVector:$v2, I64ArrayAttr:$mask)>, Results<(outs AnyVector:$vector)> { let summary = "shuffle operation"; @@ -496,7 +497,7 @@ def Vector_ShuffleOp : return vector().getType().cast(); } }]; - let hasCustomAssemblyFormat = 1; + let assemblyFormat = "operands $mask attr-dict `:` type(operands)"; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index fc472f8..5607464 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1723,19 +1723,7 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1, Value v2, ArrayRef mask) { - result.addOperands({v1, v2}); - auto maskAttr = getVectorSubscriptAttr(builder, mask); - auto v1Type = v1.getType().cast(); - auto shape = llvm::to_vector<4>(v1Type.getShape()); - shape[0] = mask.size(); - result.addTypes(VectorType::get(shape, v1Type.getElementType())); - result.addAttribute(getMaskAttrStrName(), maskAttr); -} - -void ShuffleOp::print(OpAsmPrinter &p) { - p << " " << v1() << ", " << v2() << " " << mask(); - p.printOptionalAttrDict((*this)->getAttrs(), {ShuffleOp::getMaskAttrName()}); - p << " : " << v1().getType() << ", " << v2().getType(); + build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask)); } LogicalResult ShuffleOp::verify() { @@ -1759,6 +1747,8 @@ LogicalResult ShuffleOp::verify() { // Verify mask length. auto maskAttr = mask().getValue(); int64_t maskLength = maskAttr.size(); + if (maskLength <= 0) + return emitOpError("invalid mask length"); if (maskLength != resultType.getDimSize(0)) return emitOpError("mask length mismatch"); // Verify all indices. @@ -1771,36 +1761,21 @@ LogicalResult ShuffleOp::verify() { return success(); } -ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType v1, v2; - Attribute attr; - VectorType v1Type, v2Type; - if (parser.parseOperand(v1) || parser.parseComma() || - parser.parseOperand(v2) || - parser.parseAttribute(attr, ShuffleOp::getMaskAttrStrName(), - result.attributes) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(v1Type) || parser.parseComma() || - parser.parseType(v2Type) || - parser.resolveOperand(v1, v1Type, result.operands) || - parser.resolveOperand(v2, v2Type, result.operands)) - return failure(); +LogicalResult +ShuffleOp::inferReturnTypes(MLIRContext *, Optional, + ValueRange operands, DictionaryAttr attributes, + RegionRange, + SmallVectorImpl &inferredReturnTypes) { + ShuffleOp::Adaptor op(operands, attributes); + auto v1Type = op.v1().getType().cast(); // Construct resulting type: leading dimension matches mask length, // all trailing dimensions match the operands. - auto maskAttr = attr.dyn_cast(); - if (!maskAttr) - return parser.emitError(parser.getNameLoc(), "missing mask attribute"); - int64_t maskLength = maskAttr.size(); - if (maskLength <= 0) - return parser.emitError(parser.getNameLoc(), "invalid mask length"); - int64_t v1Rank = v1Type.getRank(); SmallVector shape; - shape.reserve(v1Rank); - shape.push_back(maskLength); - for (int64_t r = 1; r < v1Rank; ++r) - shape.push_back(v1Type.getDimSize(r)); - VectorType resType = VectorType::get(shape, v1Type.getElementType()); - parser.addTypeToList(resType, result.types); + shape.reserve(v1Type.getRank()); + shape.push_back(std::max(1, op.mask().size())); + llvm::append_range(shape, v1Type.getShape().drop_front()); + inferredReturnTypes.push_back( + VectorType::get(shape, v1Type.getElementType())); return success(); } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 54697d1..bc75e0b 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -73,7 +73,7 @@ func @shuffle_index_out_of_range(%arg0: vector<2xf32>, %arg1: vector<2xf32>) { // ----- func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) { - // expected-error@+1 {{'vector.shuffle' invalid mask length}} + // expected-error@+1 {{'vector.shuffle' op invalid mask length}} %1 = vector.shuffle %arg0, %arg1 [] : vector<2xf32>, vector<2xf32> }