[mlir][Vector] Switch ExtractOp to the declarative assembly format
authorBenjamin Kramer <benny.kra@googlemail.com>
Fri, 18 Feb 2022 10:39:48 +0000 (11:39 +0100)
committerBenjamin Kramer <benny.kra@googlemail.com>
Fri, 18 Feb 2022 10:45:59 +0000 (11:45 +0100)
This is a bit awkward since ExtractOp allows both `f32` and
`vector<1xf32>` results for a scalar extraction. Allow both, but make
inference return the scalar to make this as NFC as possible.

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir

index 1e16dbb..4a20ea0 100644 (file)
@@ -551,7 +551,8 @@ def Vector_ExtractElementOp :
 def Vector_ExtractOp :
   Vector_Op<"extract", [NoSideEffect,
      PredOpTrait<"operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+                 TCresVTEtIsSameAsOpBase<0, 0>>,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnyVector:$vector, I64ArrayAttr:$position)>,
     Results<(outs AnyType)> {
   let summary = "extract operation";
@@ -577,9 +578,10 @@ def Vector_ExtractOp :
     VectorType getVectorType() {
       return vector().getType().cast<VectorType>();
     }
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
+  let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)";
   let hasCanonicalizer = 1;
-  let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 }
index 5607464..4ffb2b8 100644 (file)
@@ -940,21 +940,9 @@ LogicalResult vector::ExtractElementOp::verify() {
 // ExtractOp
 //===----------------------------------------------------------------------===//
 
-static Type inferExtractOpResultType(VectorType vectorType,
-                                     ArrayAttr position) {
-  if (static_cast<int64_t>(position.size()) == vectorType.getRank())
-    return vectorType.getElementType();
-  return VectorType::get(vectorType.getShape().drop_front(position.size()),
-                         vectorType.getElementType());
-}
-
 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                               Value source, ArrayRef<int64_t> position) {
-  result.addOperands(source);
-  auto positionAttr = getVectorSubscriptAttr(builder, position);
-  result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
-                                           positionAttr));
-  result.addAttribute(getPositionAttrStrName(), positionAttr);
+  build(builder, result, source, getVectorSubscriptAttr(builder, position));
 }
 
 // Convenience builder which assumes the values are constant indices.
@@ -967,40 +955,34 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, source, positionConstants);
 }
 
-void vector::ExtractOp::print(OpAsmPrinter &p) {
-  p << " " << vector() << position();
-  p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
-  p << " : " << vector().getType();
+LogicalResult
+ExtractOp::inferReturnTypes(MLIRContext *, Optional<Location>,
+                            ValueRange operands, DictionaryAttr attributes,
+                            RegionRange,
+                            SmallVectorImpl<Type> &inferredReturnTypes) {
+  ExtractOp::Adaptor op(operands, attributes);
+  auto vectorType = op.vector().getType().cast<VectorType>();
+  if (static_cast<int64_t>(op.position().size()) == vectorType.getRank()) {
+    inferredReturnTypes.push_back(vectorType.getElementType());
+  } else {
+    auto n = std::min<size_t>(op.position().size(), vectorType.getRank() - 1);
+    inferredReturnTypes.push_back(VectorType::get(
+        vectorType.getShape().drop_front(n), vectorType.getElementType()));
+  }
+  return success();
 }
 
-ParseResult vector::ExtractOp::parse(OpAsmParser &parser,
-                                     OperationState &result) {
-  SMLoc attributeLoc, typeLoc;
-  NamedAttrList attrs;
-  OpAsmParser::OperandType vector;
-  Type type;
-  Attribute attr;
-  if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
-      parser.parseAttribute(attr, "position", attrs) ||
-      parser.parseOptionalAttrDict(attrs) ||
-      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
-    return failure();
-
-  auto vectorType = type.dyn_cast<VectorType>();
-  if (!vectorType)
-    return parser.emitError(typeLoc, "expected vector type");
-
-  auto positionAttr = attr.dyn_cast<ArrayAttr>();
-  if (!positionAttr ||
-      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
-    return parser.emitError(
-        attributeLoc,
-        "expected position attribute of rank smaller than vector rank");
-
-  Type resType = inferExtractOpResultType(vectorType, positionAttr);
-  result.attributes = attrs;
-  return failure(parser.resolveOperand(vector, type, result.operands) ||
-                 parser.addTypeToList(resType, result.types));
+bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  // Allow extracting 1-element vectors instead of scalars.
+  auto isCompatible = [](TypeRange l, TypeRange r) {
+    auto vectorType = l.front().dyn_cast<VectorType>();
+    return vectorType && vectorType.getShape().equals({1}) &&
+           vectorType.getElementType() == r.front();
+  };
+  if (l.size() == 1 && r.size() == 1 &&
+      (isCompatible(l, r) || isCompatible(r, l)))
+    return true;
+  return l == r;
 }
 
 LogicalResult vector::ExtractOp::verify() {
index bc75e0b..2e224f7 100644 (file)
@@ -104,7 +104,7 @@ func @extract_element(%arg0: vector<4x4xf32>) {
 // -----
 
 func @extract_vector_type(%arg0: index) {
-  // expected-error@+1 {{expected vector type}}
+  // expected-error@+1 {{invalid kind of type specified}}
   %1 = vector.extract %arg0[] : index
 }