PredOpTrait<"first operand v1 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"second operand v2 and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 1>>]>,
+ TCresVTEtIsSameAsOpBase<0, 1>>,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
Arguments<(ins AnyVector:$v1, AnyVector:$v2, I64ArrayAttr:$mask)>,
Results<(outs AnyVector:$vector)> {
let summary = "shuffle operation";
return vector().getType().cast<VectorType>();
}
}];
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = "operands $mask attr-dict `:` type(operands)";
let hasVerifier = 1;
}
void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
Value v2, ArrayRef<int64_t> mask) {
- result.addOperands({v1, v2});
- auto maskAttr = getVectorSubscriptAttr(builder, mask);
- auto v1Type = v1.getType().cast<VectorType>();
- auto shape = llvm::to_vector<4>(v1Type.getShape());
- shape[0] = mask.size();
- result.addTypes(VectorType::get(shape, v1Type.getElementType()));
- result.addAttribute(getMaskAttrStrName(), maskAttr);
-}
-
-void ShuffleOp::print(OpAsmPrinter &p) {
- p << " " << v1() << ", " << v2() << " " << mask();
- p.printOptionalAttrDict((*this)->getAttrs(), {ShuffleOp::getMaskAttrName()});
- p << " : " << v1().getType() << ", " << v2().getType();
+ build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
}
LogicalResult ShuffleOp::verify() {
// Verify mask length.
auto maskAttr = mask().getValue();
int64_t maskLength = maskAttr.size();
+ if (maskLength <= 0)
+ return emitOpError("invalid mask length");
if (maskLength != resultType.getDimSize(0))
return emitOpError("mask length mismatch");
// Verify all indices.
return success();
}
-ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::OperandType v1, v2;
- Attribute attr;
- VectorType v1Type, v2Type;
- if (parser.parseOperand(v1) || parser.parseComma() ||
- parser.parseOperand(v2) ||
- parser.parseAttribute(attr, ShuffleOp::getMaskAttrStrName(),
- result.attributes) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(v1Type) || parser.parseComma() ||
- parser.parseType(v2Type) ||
- parser.resolveOperand(v1, v1Type, result.operands) ||
- parser.resolveOperand(v2, v2Type, result.operands))
- return failure();
+LogicalResult
+ShuffleOp::inferReturnTypes(MLIRContext *, Optional<Location>,
+ ValueRange operands, DictionaryAttr attributes,
+ RegionRange,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ ShuffleOp::Adaptor op(operands, attributes);
+ auto v1Type = op.v1().getType().cast<VectorType>();
// Construct resulting type: leading dimension matches mask length,
// all trailing dimensions match the operands.
- auto maskAttr = attr.dyn_cast<ArrayAttr>();
- if (!maskAttr)
- return parser.emitError(parser.getNameLoc(), "missing mask attribute");
- int64_t maskLength = maskAttr.size();
- if (maskLength <= 0)
- return parser.emitError(parser.getNameLoc(), "invalid mask length");
- int64_t v1Rank = v1Type.getRank();
SmallVector<int64_t, 4> shape;
- shape.reserve(v1Rank);
- shape.push_back(maskLength);
- for (int64_t r = 1; r < v1Rank; ++r)
- shape.push_back(v1Type.getDimSize(r));
- VectorType resType = VectorType::get(shape, v1Type.getElementType());
- parser.addTypeToList(resType, result.types);
+ shape.reserve(v1Type.getRank());
+ shape.push_back(std::max<size_t>(1, op.mask().size()));
+ llvm::append_range(shape, v1Type.getShape().drop_front());
+ inferredReturnTypes.push_back(
+ VectorType::get(shape, v1Type.getElementType()));
return success();
}