[mlir][Tensor][NFC] Migrate Tensor dialect to the new fold API
authorMarkus Böck <markus.boeck02@gmail.com>
Tue, 10 Jan 2023 19:40:25 +0000 (20:40 +0100)
committerMarkus Böck <markus.boeck02@gmail.com>
Tue, 17 Jan 2023 12:22:11 +0000 (13:22 +0100)
See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context

Differential Revision: https://reviews.llvm.org/D141530

mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

index fe49f8d..b27b6ea 100644 (file)
@@ -47,6 +47,7 @@ def Tensor_Dialect : Dialect {
 
   let hasCanonicalizer = 1;
   let hasConstantMaterializer = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
   let dependentDialects = [
     "AffineDialect",
     "arith::ArithDialect",
index 18ffbe3..d8c337d 100644 (file)
@@ -418,9 +418,9 @@ LogicalResult DimOp::verify() {
   return success();
 }
 
-OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   // All forms of folding require a known index.
-  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
+  auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
   if (!index)
     return {};
 
@@ -763,16 +763,16 @@ LogicalResult ExtractOp::verify() {
   return success();
 }
 
-OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // If this is a splat elements attribute, simply return the value. All of
   // the elements of a splat attribute are the same.
-  if (Attribute tensor = operands.front())
+  if (Attribute tensor = adaptor.getTensor())
     if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
       return splatTensor.getSplatValue<Attribute>();
 
   // Collect the constant indices into the tensor.
   SmallVector<uint64_t, 8> indices;
-  for (Attribute indice : llvm::drop_begin(operands, 1)) {
+  for (Attribute indice : adaptor.getIndices()) {
     if (!indice || !indice.isa<IntegerAttr>())
       return {};
     indices.push_back(indice.cast<IntegerAttr>().getInt());
@@ -800,7 +800,7 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
   }
 
   // If this is an elements attribute, query the value at the given indices.
-  if (Attribute tensor = operands.front()) {
+  if (Attribute tensor = adaptor.getTensor()) {
     auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
     if (elementsAttr && elementsAttr.isValidIndex(indices))
       return elementsAttr.getValues<Attribute>()[indices];
@@ -837,9 +837,9 @@ void FromElementsOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, resultType, elements);
 }
 
-OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
-  if (!llvm::is_contained(operands, nullptr))
-    return DenseElementsAttr::get(getType(), operands);
+OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
+  if (!llvm::is_contained(adaptor.getElements(), nullptr))
+    return DenseElementsAttr::get(getType(), adaptor.getElements());
   return {};
 }
 
@@ -996,9 +996,9 @@ LogicalResult InsertOp::verify() {
   return success();
 }
 
-OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
-  Attribute scalar = operands[0];
-  Attribute dest = operands[1];
+OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
+  Attribute scalar = adaptor.getScalar();
+  Attribute dest = adaptor.getDest();
   if (scalar && dest)
     if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
       if (scalar == splatDest.getSplatValue<Attribute>())
@@ -1178,7 +1178,7 @@ void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "rank");
 }
 
-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
   // Constant fold rank when the rank of the operand is known.
   auto type = getOperand().getType();
   auto shapedType = type.dyn_cast<ShapedType>();
@@ -1558,12 +1558,14 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
           context);
 }
 
-OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
-  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
+OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
+                                                       adaptor.getOperands());
 }
 
-OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
-  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
+OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
+                                                       adaptor.getOperands());
 }
 
 //===----------------------------------------------------------------------===//
@@ -2050,8 +2052,8 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
   return {};
 }
 
-OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
-  if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
+OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
+  if (auto splat = adaptor.getSource().dyn_cast_or_null<SplatElementsAttr>()) {
     auto resultType = getResult().getType().cast<ShapedType>();
     if (resultType.hasStaticShape())
       return splat.resizeSplat(resultType);
@@ -2197,7 +2199,7 @@ static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
   return extractOp.getSource();
 }
 
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
+OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
       getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
@@ -2869,7 +2871,7 @@ Value PadOp::getConstantPaddingValue() {
   return padValue;
 }
 
-OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult PadOp::fold(FoldAdaptor) {
   if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
       !getNofold())
     return getSource();
@@ -3004,8 +3006,8 @@ void SplatOp::getAsmResultNames(
   setNameFn(getResult(), "splat");
 }
 
-OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
+OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
+  auto constOperand = adaptor.getInput();
   if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
     return {};