[mlir] Remove some code duplication between `Builders.cpp` and `FoldUtils.cpp`
authorMatthias Springer <me@m-sp.org>
Thu, 20 Jul 2023 08:20:36 +0000 (10:20 +0200)
committerMatthias Springer <me@m-sp.org>
Thu, 20 Jul 2023 08:27:14 +0000 (10:27 +0200)
Also update the documentation of `Operation::fold`, which did not take into account in-place foldings.

Differential Revision: https://reviews.llvm.org/D155691

mlir/include/mlir/IR/Operation.h
mlir/lib/IR/Builders.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/Transforms/Utils/FoldUtils.cpp

index ec6d4ca..fbded01 100644 (file)
@@ -679,11 +679,25 @@ public:
 
   /// Attempt to fold this operation with the specified constant operand values
   /// - the elements in "operands" will correspond directly to the operands of
-  /// the operation, but may be null if non-constant. If folding is successful,
-  /// this fills in the `results` vector. If not, `results` is unspecified.
+  /// the operation, but may be null if non-constant.
+  ///
+  /// If folding was successful, this function returns "success".
+  /// * If this operation was modified in-place (but not folded away),
+  ///   `results` is empty.
+  /// * Otherwise, `results` is filled with the folded results.
+  /// If folding was unsuccessful, this function returns "failure".
   LogicalResult fold(ArrayRef<Attribute> operands,
                      SmallVectorImpl<OpFoldResult> &results);
 
+  /// Attempt to fold this operation.
+  ///
+  /// If folding was successful, this function returns "success".
+  /// * If this operation was modified in-place (but not folded away),
+  ///   `results` is empty.
+  /// * Otherwise, `results` is filled with the folded results.
+  /// If folding was unsuccessful, this function returns "failure".
+  LogicalResult fold(SmallVectorImpl<OpFoldResult> &results);
+
   /// Returns true if the operation was registered with a particular trait, e.g.
   /// hasTrait<OperandsAreSignlessIntegerLike>().
   template <template <typename T> class Trait>
index 0f1aceb..b8a98ed 100644 (file)
@@ -475,15 +475,9 @@ LogicalResult OpBuilder::tryFold(Operation *op,
   if (matchPattern(op, m_Constant()))
     return cleanupFailure();
 
-  // Check to see if any operands to the operation is constant and whether
-  // the operation knows how to constant fold itself.
-  SmallVector<Attribute, 4> constOperands(op->getNumOperands());
-  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
-    matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
-
   // Try to fold the operation.
   SmallVector<OpFoldResult, 4> foldResults;
-  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
+  if (failed(op->fold(foldResults)) || foldResults.empty())
     return cleanupFailure();
 
   // A temporary builder used for creating constants during folding.
index efce8d9..1c70bc3 100644 (file)
@@ -628,6 +628,15 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
   return interface->fold(this, operands, results);
 }
 
+LogicalResult Operation::fold(SmallVectorImpl<OpFoldResult> &results) {
+  // Check if any operands are constants.
+  SmallVector<Attribute> constants;
+  constants.assign(getNumOperands(), Attribute());
+  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
+    matchPattern(getOperand(i), m_Constant(&constants[i]));
+  return fold(constants, results);
+}
+
 /// Emit an error with the op name prefixed, like "'dim' op " which is
 /// convenient for verifiers.
 InFlightDiagnostic Operation::emitOpError(const Twine &message) {
index ad1e043..90ee5ba 100644 (file)
@@ -215,19 +215,8 @@ bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
 /// `results` with the results of the folding.
 LogicalResult OperationFolder::tryToFold(Operation *op,
                                          SmallVectorImpl<Value> &results) {
-  SmallVector<Attribute, 8> operandConstants;
-
-  // Check to see if any operands to the operation is constant and whether
-  // the operation knows how to constant fold itself.
-  operandConstants.assign(op->getNumOperands(), Attribute());
-  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
-    matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
-
-  // Attempt to constant fold the operation. If we failed, check to see if we at
-  // least updated the operands of the operation. We treat this as an in-place
-  // fold.
   SmallVector<OpFoldResult, 8> foldResults;
-  if (failed(op->fold(operandConstants, foldResults)) ||
+  if (failed(op->fold(foldResults)) ||
       failed(processFoldResults(op, results, foldResults)))
     return failure();
   return success();