[mlir][Vector] Switch ShuffleOp to the declarative assembly format
authorBenjamin Kramer <benny.kra@googlemail.com>
Fri, 18 Feb 2022 00:35:25 +0000 (01:35 +0100)
committerBenjamin Kramer <benny.kra@googlemail.com>
Fri, 18 Feb 2022 00:46:58 +0000 (01:46 +0100)
This also requires implementing return type deduction.

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir

index c9370a0..1e16dbb 100644 (file)
@@ -447,7 +447,8 @@ def Vector_ShuffleOp :
      PredOpTrait<"first operand v1 and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      PredOpTrait<"second operand v2 and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 1>>]>,
+                 TCresVTEtIsSameAsOpBase<0, 1>>,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
      Arguments<(ins AnyVector:$v1, AnyVector:$v2, I64ArrayAttr:$mask)>,
      Results<(outs AnyVector:$vector)> {
   let summary = "shuffle operation";
@@ -496,7 +497,7 @@ def Vector_ShuffleOp :
       return vector().getType().cast<VectorType>();
     }
   }];
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = "operands $mask attr-dict `:` type(operands)";
   let hasVerifier = 1;
 }
 
index fc472f8..5607464 100644 (file)
@@ -1723,19 +1723,7 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
                       Value v2, ArrayRef<int64_t> mask) {
-  result.addOperands({v1, v2});
-  auto maskAttr = getVectorSubscriptAttr(builder, mask);
-  auto v1Type = v1.getType().cast<VectorType>();
-  auto shape = llvm::to_vector<4>(v1Type.getShape());
-  shape[0] = mask.size();
-  result.addTypes(VectorType::get(shape, v1Type.getElementType()));
-  result.addAttribute(getMaskAttrStrName(), maskAttr);
-}
-
-void ShuffleOp::print(OpAsmPrinter &p) {
-  p << " " << v1() << ", " << v2() << " " << mask();
-  p.printOptionalAttrDict((*this)->getAttrs(), {ShuffleOp::getMaskAttrName()});
-  p << " : " << v1().getType() << ", " << v2().getType();
+  build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
 }
 
 LogicalResult ShuffleOp::verify() {
@@ -1759,6 +1747,8 @@ LogicalResult ShuffleOp::verify() {
   // Verify mask length.
   auto maskAttr = mask().getValue();
   int64_t maskLength = maskAttr.size();
+  if (maskLength <= 0)
+    return emitOpError("invalid mask length");
   if (maskLength != resultType.getDimSize(0))
     return emitOpError("mask length mismatch");
   // Verify all indices.
@@ -1771,36 +1761,21 @@ LogicalResult ShuffleOp::verify() {
   return success();
 }
 
-ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::OperandType v1, v2;
-  Attribute attr;
-  VectorType v1Type, v2Type;
-  if (parser.parseOperand(v1) || parser.parseComma() ||
-      parser.parseOperand(v2) ||
-      parser.parseAttribute(attr, ShuffleOp::getMaskAttrStrName(),
-                            result.attributes) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(v1Type) || parser.parseComma() ||
-      parser.parseType(v2Type) ||
-      parser.resolveOperand(v1, v1Type, result.operands) ||
-      parser.resolveOperand(v2, v2Type, result.operands))
-    return failure();
+LogicalResult
+ShuffleOp::inferReturnTypes(MLIRContext *, Optional<Location>,
+                            ValueRange operands, DictionaryAttr attributes,
+                            RegionRange,
+                            SmallVectorImpl<Type> &inferredReturnTypes) {
+  ShuffleOp::Adaptor op(operands, attributes);
+  auto v1Type = op.v1().getType().cast<VectorType>();
   // Construct resulting type: leading dimension matches mask length,
   // all trailing dimensions match the operands.
-  auto maskAttr = attr.dyn_cast<ArrayAttr>();
-  if (!maskAttr)
-    return parser.emitError(parser.getNameLoc(), "missing mask attribute");
-  int64_t maskLength = maskAttr.size();
-  if (maskLength <= 0)
-    return parser.emitError(parser.getNameLoc(), "invalid mask length");
-  int64_t v1Rank = v1Type.getRank();
   SmallVector<int64_t, 4> shape;
-  shape.reserve(v1Rank);
-  shape.push_back(maskLength);
-  for (int64_t r = 1; r < v1Rank; ++r)
-    shape.push_back(v1Type.getDimSize(r));
-  VectorType resType = VectorType::get(shape, v1Type.getElementType());
-  parser.addTypeToList(resType, result.types);
+  shape.reserve(v1Type.getRank());
+  shape.push_back(std::max<size_t>(1, op.mask().size()));
+  llvm::append_range(shape, v1Type.getShape().drop_front());
+  inferredReturnTypes.push_back(
+      VectorType::get(shape, v1Type.getElementType()));
   return success();
 }
 
index 54697d1..bc75e0b 100644 (file)
@@ -73,7 +73,7 @@ func @shuffle_index_out_of_range(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
 // -----
 
 func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
-  // expected-error@+1 {{'vector.shuffle' invalid mask length}}
+  // expected-error@+1 {{'vector.shuffle' op invalid mask length}}
   %1 = vector.shuffle %arg0, %arg1 [] : vector<2xf32>, vector<2xf32>
 }