[mlir][Shape][NFC] Migrate shape dialect to the new fold API
authorMarkus Böck <markus.boeck02@gmail.com>
Tue, 10 Jan 2023 19:03:08 +0000 (20:03 +0100)
committerMarkus Böck <markus.boeck02@gmail.com>
Thu, 12 Jan 2023 08:52:14 +0000 (09:52 +0100)
See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context

Changes are mostly mechanical in nature. The code nevertheless became more expressive in a lot of places thanks to the use of the new getters!

Differential Revision: https://reviews.llvm.org/D141501

mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
mlir/lib/Dialect/Shape/IR/Shape.cpp

index 9c02579..e13bf4b 100644 (file)
@@ -41,6 +41,7 @@ def ShapeDialect : Dialect {
   let useDefaultTypePrinterParser = 1;
   let hasConstantMaterializer = 1;
   let hasOperationAttrVerify = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 class Shape_Type<string name, string typeMnemonic> : TypeDef<ShapeDialect, name> {
index c34d509..f28cbca 100644 (file)
@@ -215,10 +215,10 @@ LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
 
 // TODO: Canonicalization should be implemented for shapes that can be
 // determined through mixtures of the known dimensions of the inputs.
-OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
   // Only the last operand is checked because AnyOp is commutative.
-  if (operands.back())
-    return operands.back();
+  if (adaptor.getInputs().back())
+    return adaptor.getInputs().back();
 
   return nullptr;
 }
@@ -410,13 +410,14 @@ bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
 }
 
-OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
   // add(x, 0) -> x
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
 
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
+      adaptor.getOperands(),
+      [](APInt a, const APInt &b) { return std::move(a) + b; });
 }
 
 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
@@ -604,11 +605,11 @@ void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
            RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
 }
 
-OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
   // Iterate in reverse to first handle all constant operands. They are
   // guaranteed to be the tail of the inputs because this is commutative.
-  for (int idx = operands.size() - 1; idx >= 0; idx--) {
-    Attribute a = operands[idx];
+  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
+    Attribute a = adaptor.getInputs()[idx];
     // Cannot fold if any inputs are not constant;
     if (!a)
       return nullptr;
@@ -637,7 +638,7 @@ LogicalResult AssumingAllOp::verify() {
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   if (getShapes().size() == 1) {
     // Otherwise, we need a cast which would be a canonicalization, not folding.
     if (getShapes().front().getType() != getType())
@@ -649,12 +650,12 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   if (getShapes().size() > 2)
     return nullptr;
 
-  if (!operands[0] || !operands[1])
+  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
     return nullptr;
   auto lhsShape = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+      adaptor.getShapes()[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
   auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+      adaptor.getShapes()[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   SmallVector<int64_t, 6> resultShape;
 
   // If the shapes are not compatible, we can't fold it.
@@ -847,13 +848,13 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // ConcatOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0] || !operands[1])
+OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
+  if (!adaptor.getLhs() || !adaptor.getRhs())
     return nullptr;
   auto lhsShape = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+      adaptor.getLhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
   auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+      adaptor.getRhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
   SmallVector<int64_t, 6> resultShape;
   resultShape.append(lhsShape.begin(), lhsShape.end());
   resultShape.append(rhsShape.begin(), rhsShape.end());
@@ -903,7 +904,7 @@ ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
+OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
 
 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
@@ -966,14 +967,14 @@ static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
   return true;
 }
 
-OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
   // No broadcasting is needed if all operands but one are scalar.
-  if (hasAtMostSingleNonScalar(operands))
+  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
     return BoolAttr::get(getContext(), true);
 
   if ([&] {
         SmallVector<SmallVector<int64_t, 6>, 6> extents;
-        for (const auto &operand : operands) {
+        for (const auto &operand : adaptor.getShapes()) {
           if (!operand)
             return false;
           extents.push_back(llvm::to_vector<6>(
@@ -1018,9 +1019,10 @@ void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<CstrEqEqOps>(context);
 }
 
-OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
-  if (llvm::all_of(operands,
-                   [&](Attribute a) { return a && a == operands[0]; }))
+OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
+  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
+        return a && a == adaptor.getShapes().front();
+      }))
     return BoolAttr::get(getContext(), true);
 
   // Because a failing witness result here represents an eventual assertion
@@ -1038,7 +1040,7 @@ void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, builder.getIndexAttr(value));
 }
 
-OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
+OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
 
 void ConstSizeOp::getAsmResultNames(
     llvm::function_ref<void(Value, StringRef)> setNameFn) {
@@ -1052,16 +1054,14 @@ void ConstSizeOp::getAsmResultNames(
 // ConstWitnessOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
-  return getPassingAttr();
-}
+OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
 
 //===----------------------------------------------------------------------===//
 // CstrRequireOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
-  return operands[0];
+OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
+  return adaptor.getPred();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1076,7 +1076,7 @@ std::optional<int64_t> DimOp::getConstantIndex() {
   return std::nullopt;
 }
 
-OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   Type valType = getValue().getType();
   auto valShapedType = valType.dyn_cast<ShapedType>();
   if (!valShapedType || !valShapedType.hasRank())
@@ -1120,11 +1120,11 @@ LogicalResult mlir::shape::DimOp::verify() {
 // DivOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
-  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
   if (!lhs)
     return nullptr;
-  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
   if (!rhs)
     return nullptr;
 
@@ -1163,14 +1163,14 @@ LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
 // ShapeEqOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
   bool allSame = true;
-  if (!operands.empty() && !operands[0])
+  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
     return {};
-  for (Attribute operand : operands.drop_front(1)) {
+  for (Attribute operand : adaptor.getShapes().drop_front()) {
     if (!operand)
       return {};
-    allSame = allSame && operand == operands[0];
+    allSame = allSame && operand == adaptor.getShapes().front();
   }
   return BoolAttr::get(getContext(), allSame);
 }
@@ -1179,10 +1179,10 @@ OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
 // IndexToSizeOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
   // Constant values of both types, `shape.size` and `index`, are represented as
   // `IntegerAttr`s which makes constant folding simple.
-  if (Attribute arg = operands[0])
+  if (Attribute arg = adaptor.getArg())
     return arg;
   return {};
 }
@@ -1196,11 +1196,11 @@ void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // FromExtentsOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
-  if (llvm::any_of(operands, [](Attribute a) { return !a; }))
+OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
+  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
     return nullptr;
   SmallVector<int64_t, 6> extents;
-  for (auto attr : operands)
+  for (auto attr : adaptor.getExtents())
     extents.push_back(attr.cast<IntegerAttr>().getInt());
   Builder builder(getContext());
   return builder.getIndexTensorAttr(extents);
@@ -1335,8 +1335,8 @@ std::optional<int64_t> GetExtentOp::getConstantDim() {
   return std::nullopt;
 }
 
-OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
-  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
+  auto elements = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
   if (!elements)
     return nullptr;
   std::optional<int64_t> dim = getConstantDim();
@@ -1386,9 +1386,9 @@ void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
 }
 
-OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
   // Can always broadcast fewer than two shapes.
-  if (operands.size() < 2) {
+  if (adaptor.getShapes().size() < 2) {
     return BoolAttr::get(getContext(), true);
   }
 
@@ -1479,8 +1479,8 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 // RankOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
-  auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
+  auto shape = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
   if (!shape)
     return {};
   int64_t rank = shape.getNumElements();
@@ -1557,10 +1557,10 @@ LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
 // NumElementsOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
 
   // Fold only when argument constant.
-  Attribute shape = operands[0];
+  Attribute shape = adaptor.getShape();
   if (!shape)
     return {};
 
@@ -1596,7 +1596,7 @@ LogicalResult shape::NumElementsOp::verify() {
 // MaxOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
+OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
   // If operands are equal, just propagate one.
   if (getLhs() == getRhs())
     return getLhs();
@@ -1628,7 +1628,7 @@ bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 // MinOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
+OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
   // If operands are equal, just propagate one.
   if (getLhs() == getRhs())
     return getLhs();
@@ -1660,11 +1660,11 @@ bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 // MulOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
-  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
   if (!lhs)
     return nullptr;
-  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
   if (!rhs)
     return nullptr;
   APInt folded = lhs.getValue() * rhs.getValue();
@@ -1695,7 +1695,7 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
   auto type = getOperand().getType().dyn_cast<ShapedType>();
   if (!type || !type.hasStaticShape())
     return nullptr;
@@ -1805,10 +1805,10 @@ LogicalResult shape::ShapeOfOp::verify() {
 // SizeToIndexOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
   // Constant values of both types, `shape.size` and `index`, are represented as
   // `IntegerAttr`s which makes constant folding simple.
-  if (Attribute arg = operands[0])
+  if (Attribute arg = adaptor.getArg())
     return arg;
   return OpFoldResult();
 }
@@ -1847,14 +1847,14 @@ LogicalResult shape::YieldOp::verify() {
 // SplitAtOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
+LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
-  if (!operands[0] || !operands[1])
+  if (!adaptor.getOperand() || !adaptor.getIndex())
     return failure();
   auto shapeVec = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+      adaptor.getOperand().cast<DenseIntElementsAttr>().getValues<int64_t>());
   auto shape = llvm::ArrayRef(shapeVec);
-  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
+  auto splitPoint = adaptor.getIndex().cast<IntegerAttr>().getInt();
   // Verify that the split point is in the correct range.
   // TODO: Constant fold to an "error".
   int64_t rank = shape.size();
@@ -1862,7 +1862,7 @@ LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
     return failure();
   if (splitPoint < 0)
     splitPoint += shape.size();
-  Builder builder(operands[0].getContext());
+  Builder builder(adaptor.getOperand().getContext());
   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
   return success();
@@ -1872,12 +1872,12 @@ LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
 // ToExtentTensorOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0])
+OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
+  if (!adaptor.getInput())
     return OpFoldResult();
   Builder builder(getContext());
   auto shape = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+      adaptor.getInput().cast<DenseIntElementsAttr>().getValues<int64_t>());
   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                     builder.getIndexType());
   return DenseIntElementsAttr::get(type, shape);