def Vector_ExtractOp :
Vector_Op<"extract", [NoSideEffect,
PredOpTrait<"operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>]>,
+ TCresVTEtIsSameAsOpBase<0, 0>>,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
Arguments<(ins AnyVector:$vector, I64ArrayAttr:$position)>,
Results<(outs AnyType)> {
let summary = "extract operation";
VectorType getVectorType() {
return vector().getType().cast<VectorType>();
}
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
+ let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)";
let hasCanonicalizer = 1;
- let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
// ExtractOp
//===----------------------------------------------------------------------===//
-static Type inferExtractOpResultType(VectorType vectorType,
- ArrayAttr position) {
- if (static_cast<int64_t>(position.size()) == vectorType.getRank())
- return vectorType.getElementType();
- return VectorType::get(vectorType.getShape().drop_front(position.size()),
- vectorType.getElementType());
-}
-
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, ArrayRef<int64_t> position) {
- result.addOperands(source);
- auto positionAttr = getVectorSubscriptAttr(builder, position);
- result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
- positionAttr));
- result.addAttribute(getPositionAttrStrName(), positionAttr);
+ build(builder, result, source, getVectorSubscriptAttr(builder, position));
}
// Convenience builder which assumes the values are constant indices.
build(builder, result, source, positionConstants);
}
-void vector::ExtractOp::print(OpAsmPrinter &p) {
- p << " " << vector() << position();
- p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
- p << " : " << vector().getType();
+LogicalResult
+ExtractOp::inferReturnTypes(MLIRContext *, Optional<Location>,
+ ValueRange operands, DictionaryAttr attributes,
+ RegionRange,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ ExtractOp::Adaptor op(operands, attributes);
+ auto vectorType = op.vector().getType().cast<VectorType>();
+ if (static_cast<int64_t>(op.position().size()) == vectorType.getRank()) {
+ inferredReturnTypes.push_back(vectorType.getElementType());
+ } else {
+ auto n = std::min<size_t>(op.position().size(), vectorType.getRank() - 1);
+ inferredReturnTypes.push_back(VectorType::get(
+ vectorType.getShape().drop_front(n), vectorType.getElementType()));
+ }
+ return success();
}
-ParseResult vector::ExtractOp::parse(OpAsmParser &parser,
- OperationState &result) {
- SMLoc attributeLoc, typeLoc;
- NamedAttrList attrs;
- OpAsmParser::OperandType vector;
- Type type;
- Attribute attr;
- if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
- parser.parseAttribute(attr, "position", attrs) ||
- parser.parseOptionalAttrDict(attrs) ||
- parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
- return failure();
-
- auto vectorType = type.dyn_cast<VectorType>();
- if (!vectorType)
- return parser.emitError(typeLoc, "expected vector type");
-
- auto positionAttr = attr.dyn_cast<ArrayAttr>();
- if (!positionAttr ||
- static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
- return parser.emitError(
- attributeLoc,
- "expected position attribute of rank smaller than vector rank");
-
- Type resType = inferExtractOpResultType(vectorType, positionAttr);
- result.attributes = attrs;
- return failure(parser.resolveOperand(vector, type, result.operands) ||
- parser.addTypeToList(resType, result.types));
+bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ // Allow extracting 1-element vectors instead of scalars.
+ auto isCompatible = [](TypeRange l, TypeRange r) {
+ auto vectorType = l.front().dyn_cast<VectorType>();
+ return vectorType && vectorType.getShape().equals({1}) &&
+ vectorType.getElementType() == r.front();
+ };
+ if (l.size() == 1 && r.size() == 1 &&
+ (isCompatible(l, r) || isCompatible(r, l)))
+ return true;
+ return l == r;
}
LogicalResult vector::ExtractOp::verify() {