for this operation. If it is `1`, then `::getCanonicalizationPatterns()` should
be defined.
-### `hasConstantFolder`
-
-This boolean field indicate whether constant folding rules have been defined
-for this operation. If it is `1`, then `::constantFold()` should be defined.
-
### `hasFolder`
This boolean field indicate whether general folding rules have been defined
Operations can also have custom parser, printer, builder, verifier, constant
folder, or canonicalizer. These require specifying additional C++ methods to
invoke for additional functionality. For example, if an operation is marked to
-have a constant folder, the constant folder also needs to be added, e.g.,:
+have a folder, the constant folder also needs to be added, e.g.,:
```c++
-Attribute SpecificOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult SpecificOp::fold(ArrayRef<Attribute> constOperands) {
if (unable_to_fold)
return {};
....
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
- Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
+ OpFoldResult fold(ArrayRef<Attribute> operands);
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
namespace llvm {
-// Attribute hash just like pointers
+// Attribute hash just like pointers.
template <> struct DenseMapInfo<mlir::Attribute> {
static mlir::Attribute getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
}
};
+/// Allow LLVM to steal the low bits of Attributes.
+template <> struct PointerLikeTypeTraits<mlir::Attribute> {
+public:
+ static inline void *getAsVoidPointer(mlir::Attribute attr) {
+ return const_cast<void *>(attr.getAsOpaquePointer());
+ }
+ static inline mlir::Attribute getFromVoidPointer(void *ptr) {
+ return mlir::Attribute::getFromOpaquePointer(ptr);
+ }
+ enum { NumLowBitsAvailable = 3 };
+};
+
} // namespace llvm
#endif
#ifndef MLIR_MATCHERS_H
#define MLIR_MATCHERS_H
-#include "mlir/IR/Operation.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include <type_traits>
if (!op->hasNoSideEffect())
return false;
- SmallVector<Attribute, 1> foldedAttr;
- if (succeeded(op->constantFold(/*operands=*/llvm::None, foldedAttr))) {
- *bind_value = foldedAttr.front();
+ SmallVector<OpFoldResult, 1> foldedOp;
+ if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
+ *bind_value = foldedOp.front().dyn_cast<Attribute>();
return true;
}
return false;
// and C++ implementations.
bit hasCanonicalizer = 0;
- // Whether this op has a constant folder.
- bit hasConstantFolder = 0;
-
// Whether this op has a folder.
bit hasFolder = 0;
return lhs.getOperation() != rhs.getOperation();
}
-/// This template defines the constantFoldHook and foldHook as used by
-/// AbstractOperation.
+/// This class represents a single result from folding an operation.
+class OpFoldResult : public llvm::PointerUnion<Attribute, Value *> {
+ using llvm::PointerUnion<Attribute, Value *>::PointerUnion;
+};
+
+/// This template defines the foldHook as used by AbstractOperation.
///
-/// The default implementation uses a general constantFold/fold method that can
-/// be defined on custom ops which can return multiple results.
+/// The default implementation uses a general fold method that can be defined on
+/// custom ops which can return multiple results.
template <typename ConcreteType, bool isSingleResult, typename = void>
class FoldingHook {
public:
/// This is an implementation detail of the constant folder hook for
/// AbstractOperation.
- static LogicalResult constantFoldHook(Operation *op,
- ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results) {
- return cast<ConcreteType>(op).constantFold(operands, results,
- op->getContext());
- }
-
- /// Op implementations can implement this hook. It should attempt to constant
- /// fold this operation with the specified constant operand values - the
- /// elements in "operands" will correspond directly to the operands of the
- /// operation, but may be null if non-constant. If constant folding is
- /// successful, this fills in the `results` vector. If not, `results` is
- /// unspecified.
- ///
- /// If not overridden, this fallback implementation always fails to fold.
- ///
- LogicalResult constantFold(ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results,
- MLIRContext *context) {
- return failure();
- }
-
- /// This is an implementation detail of the folder hook for AbstractOperation.
- static LogicalResult foldHook(Operation *op,
- SmallVectorImpl<Value *> &results) {
- return cast<ConcreteType>(op).fold(results);
+ static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ return cast<ConcreteType>(op).fold(operands, results);
}
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
- /// the FuncBuilder::foldOrCreate API and the canonicalization pass.
+ /// the Builder::foldOrCreate API and the canonicalization pass.
///
/// This is an intentionally limited interface - implementations of this hook
/// can only perform the following changes to the operation:
/// instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
- /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
- /// not allow for canonicalizations that need to introduce new operations, not
- /// even constants (e.g. "x-x -> 0" cannot be expressed).
+ /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+ /// generalized constant folding.
///
/// If not overridden, this fallback implementation always fails to fold.
///
- LogicalResult fold(SmallVectorImpl<Value *> &results) { return failure(); }
+ LogicalResult fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ return failure();
+ }
};
-/// This template specialization defines the constantFoldHook and foldHook as
-/// used by AbstractOperation for single-result operations. This gives the hook
-/// a nicer signature that is easier to implement.
+/// This template specialization defines the foldHook as used by
+/// AbstractOperation for single-result operations. This gives the hook a nicer
+/// signature that is easier to implement.
template <typename ConcreteType, bool isSingleResult>
class FoldingHook<ConcreteType, isSingleResult,
typename std::enable_if<isSingleResult>::type> {
public:
- /// If the operation returns a single value, then the Op can be implicitly
+ /// If the operation returns a single value, then the Op can be implicitly
/// converted to an Value*. This yields the value of the only result.
operator Value *() {
return static_cast<ConcreteType *>(this)->getOperation()->getResult(0);
/// This is an implementation detail of the constant folder hook for
/// AbstractOperation.
- static LogicalResult constantFoldHook(Operation *op,
- ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results) {
- auto result =
- cast<ConcreteType>(op).constantFold(operands, op->getContext());
+ static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ auto result = cast<ConcreteType>(op).fold(operands);
if (!result)
return failure();
return success();
}
- /// Op implementations can implement this hook. It should attempt to constant
- /// fold this operation with the specified constant operand values - the
- /// elements in "operands" will correspond directly to the operands of the
- /// operation, but may be null if non-constant. If constant folding is
- /// successful, this returns a non-null attribute, otherwise it returns null
- /// on failure.
- ///
- /// If not overridden, this fallback implementation always fails to fold.
- ///
- Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context) {
- return nullptr;
- }
-
- /// This is an implementation detail of the folder hook for AbstractOperation.
- static LogicalResult foldHook(Operation *op,
- SmallVectorImpl<Value *> &results) {
- auto *result = cast<ConcreteType>(op).fold();
- if (!result)
- return failure();
- if (result != op->getResult(0))
- results.push_back(result);
- return success();
- }
-
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
- /// the FuncBuilder::foldOrCreate API and the canonicalization pass.
+ /// the Builder::foldOrCreate API and the canonicalization pass.
///
/// This is an intentionally limited interface - implementations of this hook
/// can only perform the following changes to the operation:
/// remove the operation and use that result instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
- /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
- /// not allow for canonicalizations that need to introduce new operations, not
- /// even constants (e.g. "x-x -> 0" cannot be expressed).
+ /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+ /// generalized constant folding.
///
/// If not overridden, this fallback implementation always fails to fold.
///
- Value *fold() { return nullptr; }
+ OpFoldResult fold(ArrayRef<Attribute> operands) { return {}; }
};
//===----------------------------------------------------------------------===//
return getTerminatorStatus() == TerminatorStatus::NonTerminator;
}
- /// Attempt to constant fold this operation with the specified constant
- /// operand values - the elements in "operands" will correspond directly to
- /// the operands of the operation, but may be null if non-constant. If
- /// constant folding is successful, this fills in the `results` vector. If
- /// not, `results` is unspecified.
- LogicalResult constantFold(ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results);
-
- /// Attempt to fold this operation using the Op's registered foldHook.
- LogicalResult fold(SmallVectorImpl<Value *> &results);
+ /// Attempt to fold this operation with the specified constant operand values
+ /// - the elements in "operands" will correspond directly to the operands of
+ /// the operation, but may be null if non-constant. If folding is successful,
+ /// this fills in the `results` vector. If not, `results` is unspecified.
+ LogicalResult fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results);
//===--------------------------------------------------------------------===//
// Operation Walkers
class OpAsmParser;
class OpAsmParserResult;
class OpAsmPrinter;
+class OpFoldResult;
class ParseResult;
class Pattern;
class Region;
/// success if everything is ok.
LogicalResult (&verifyInvariants)(Operation *op);
- /// This hook implements a constant folder for this operation. It fills in
- /// `results` on success.
- LogicalResult (&constantFoldHook)(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results);
-
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
- /// the FuncBuilder::foldOrCreate API and the canonicalization pass.
+ /// the Builder::foldOrCreate API and the canonicalization pass.
///
/// This is an intentionally limited interface - implementations of this hook
/// can only perform the following changes to the operation:
/// instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
- /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
- /// not allow for canonicalizations that need to introduce new operations, not
- /// even constants (e.g. "x-x -> 0" cannot be expressed).
- LogicalResult (&foldHook)(Operation *op, SmallVectorImpl<Value *> &results);
+ /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+ /// generalized constant folding.
+ LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results);
/// This hook returns any canonicalization pattern rewrites that the operation
/// supports, for use by the canonicalization pass.
template <typename T> static AbstractOperation get(Dialect &dialect) {
return AbstractOperation(
T::getOperationName(), dialect, T::getOperationProperties(), T::classof,
- T::parseAssembly, T::printAssembly, T::verifyInvariants,
- T::constantFoldHook, T::foldHook, T::getCanonicalizationPatterns);
+ T::parseAssembly, T::printAssembly, T::verifyInvariants, T::foldHook,
+ T::getCanonicalizationPatterns);
}
private:
ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result),
void (&printAssembly)(Operation *op, OpAsmPrinter *p),
LogicalResult (&verifyInvariants)(Operation *op),
- LogicalResult (&constantFoldHook)(Operation *op,
- ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results),
- LogicalResult (&foldHook)(Operation *op,
- SmallVectorImpl<Value *> &results),
+ LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results),
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
MLIRContext *context))
: name(name), dialect(dialect), classof(classof),
parseAssembly(parseAssembly), printAssembly(printAssembly),
- verifyInvariants(verifyInvariants), constantFoldHook(constantFoldHook),
- foldHook(foldHook),
+ verifyInvariants(verifyInvariants), foldHook(foldHook),
getCanonicalizationPatterns(getCanonicalizationPatterns),
opProperties(opProperties) {}
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
- Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
+ OpFoldResult fold(ArrayRef<Attribute> operands);
};
/// The predicate indicates the type of the comparison to perform:
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
- Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
+ OpFoldResult fold(ArrayRef<Attribute> operands);
};
/// The "cond_br" operation represents a conditional branch operation in a
Value *getTrueValue() { return getOperand(1); }
Value *getFalseValue() { return getOperand(2); }
- Value *fold();
+ OpFoldResult fold(ArrayRef<Attribute> operands);
};
/// The "store" op writes an element to a memref specified by an index list.
def AddFOp : FloatArithmeticOp<"addf"> {
let summary = "floating point addition operation";
- let hasConstantFolder = 1;
+ let hasFolder = 1;
}
def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
let summary = "integer addition operation";
let hasFolder = 1;
- let hasConstantFolder = 1;
}
def AllocOp : Std_Op<"alloc"> {
def AndOp : IntArithmeticOp<"and", [Commutative]> {
let summary = "integer binary and";
- let hasConstantFolder = 1;
let hasFolder = 1;
}
Attribute getValue() { return getAttr("value"); }
}];
- let hasConstantFolder = 1;
+ let hasFolder = 1;
}
def DeallocOp : Std_Op<"dealloc"> {
}
}];
- let hasConstantFolder = 1;
+ let hasFolder = 1;
}
def DivFOp : FloatArithmeticOp<"divf"> {
def DivISOp : IntArithmeticOp<"divis"> {
let summary = "signed integer division operation";
- let hasConstantFolder = 1;
+ let hasFolder = 1;
}
def DivIUOp : IntArithmeticOp<"diviu"> {
let summary = "unsigned integer division operation";
- let hasConstantFolder = 1;
+ let hasFolder = 1;
}
def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
}
}];
- let hasConstantFolder = 1;
+ let hasFolder = 1;
}
def MemRefCastOp : CastOp<"memref_cast"> {
def MulFOp : FloatArithmeticOp<"mulf"> {
let summary = "foating point multiplication operation";
- let hasConstantFolder = 1;
+ let hasFolder = 1;
}
def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
let summary = "integer multiplication operation";
- let hasConstantFolder = 1;
let hasFolder = 1;
}
def OrOp : IntArithmeticOp<"or", [Commutative]> {
let summary = "integer binary or";
- let hasConstantFolder = 1;
let hasFolder = 1;
}
def RemISOp : IntArithmeticOp<"remis"> {
let summary = "signed integer division remainder operation";
- let hasConstantFolder = 1;
+ let hasFolder = 1;
}
def RemIUOp : IntArithmeticOp<"remiu"> {
let summary = "unsigned integer division remainder operation";
- let hasConstantFolder = 1;
+ let hasFolder = 1;
}
def ReturnOp : Std_Op<"return", [Terminator]> {
def SubFOp : FloatArithmeticOp<"subf"> {
let summary = "floating point subtraction operation";
- let hasConstantFolder = 1;
+ let hasFolder = 1;
}
def SubIOp : IntArithmeticOp<"subi"> {
let summary = "integer subtraction operation";
- let hasConstantFolder = 1;
- let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
def TensorCastOp : CastOp<"tensor_cast"> {
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
let summary = "integer binary xor";
- let hasConstantFolder = 1;
- let hasCanonicalizer = 1;
let hasFolder = 1;
}
-//===- ConstantFoldUtils.h - Constant Fold Utilities ------------*- C++ -*-===//
+//===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// limitations under the License.
// =============================================================================
//
-// This header file declares various constant fold utilities. These utilities
-// are intended to be used by passes to unify and simply their logic.
+// This header file declares various operation folding utilities. These
+// utilities are intended to be used by passes to unify and simply their logic.
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TRANSFORMS_CONSTANT_UTILS_H
-#define MLIR_TRANSFORMS_CONSTANT_UTILS_H
+#ifndef MLIR_TRANSFORMS_FOLDUTILS_H
+#define MLIR_TRANSFORMS_FOLDUTILS_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Types.h"
class Function;
class Operation;
-/// A helper class for constant folding operations, and unifying duplicated
-/// constants along the way.
+/// A helper class for folding operations, and unifying duplicated constants
+/// generated along the way.
///
-/// To make sure constants' proper dominance of all their uses, constants are
+/// To make sure constants properly dominate all their uses, constants are
/// moved to the beginning of the entry block of the function when tracked by
/// this class.
-class ConstantFoldHelper {
+class FoldHelper {
public:
/// Constructs an instance for managing constants in the given function `f`.
/// Constants tracked by this instance will be moved to the entry block of
/// This instance does not proactively walk the operations inside `f`;
/// instead, users must invoke the following methods to manually handle each
/// operation of interest.
- ConstantFoldHelper(Function *f);
+ FoldHelper(Function *f);
- /// Tries to perform constant folding on the given `op`, including unifying
- /// deplicated constants. If successful, calls `preReplaceAction` (if
+ /// Tries to perform folding on the given `op`, including unifying
+ /// deduplicated constants. If successful, calls `preReplaceAction` (if
/// provided) by passing in `op`, then replaces `op`'s uses with folded
- /// constants, and returns true.
- ///
- /// Note: `op` will *not* be erased to avoid invalidating potential walkers in
- /// the caller.
- bool
- tryToConstantFold(Operation *op,
- std::function<void(Operation *)> preReplaceAction = {});
+ /// results, and returns success. If the op was completely folded it is
+ /// erased.
+ LogicalResult
+ tryToFold(Operation *op,
+ std::function<void(Operation *)> preReplaceAction = {});
/// Notifies that the given constant `op` should be remove from this
- /// ConstantFoldHelper's internal bookkeeping.
+ /// FoldHelper's internal bookkeeping.
///
/// Note: this method must be called if a constant op is to be deleted
- /// externally to this ConstantFoldHelper. `op` must be a constant op.
+ /// externally to this FoldHelper. `op` must be a constant op.
void notifyRemoval(Operation *op);
private:
- /// Tries to deduplicate the given constant and returns true if that can be
+ /// Tries to deduplicate the given constant and returns success if that can be
/// done. This moves the given constant to the top of the entry block if it
/// is first seen. If there is already an existing constant that is the same,
/// this does *not* erases the given constant.
- bool tryToUnify(Operation *op);
+ LogicalResult tryToUnify(Operation *op);
/// Moves the given constant `op` to entry block to guarantee dominance.
void moveConstantToEntryBlock(Operation *op);
} // end namespace mlir
-#endif // MLIR_TRANSFORMS_CONSTANT_UTILS_H
+#endif // MLIR_TRANSFORMS_FOLDUTILS_H
[](Value *op) { return mlir::isValidSymbol(op); });
}
-Attribute AffineApplyOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
auto map = getAffineMap();
SmallVector<Attribute, 1> result;
if (failed(map.constantFold(operands, result)))
- return Attribute();
+ return {};
return result[0];
}
succOperandIndex + getNumSuccessorOperands(index))};
}
-/// Attempt to constant fold this operation with the specified constant
-/// operand values. If successful, this fills in the results vector. If not,
-/// results is unspecified.
-LogicalResult Operation::constantFold(ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results) {
- if (auto *abstractOp = getAbstractOperation()) {
- // If we have a registered operation definition matching this one, use it to
- // try to constant fold the operation.
- if (succeeded(abstractOp->constantFoldHook(this, operands, results)))
- return success();
-
- // Otherwise, fall back on the dialect hook to handle it.
- return abstractOp->dialect.constantFoldHook(this, operands, results);
- }
-
- // If this operation hasn't been registered or doesn't have abstract
- // operation, fall back to a dialect which matches the prefix.
- auto opName = getName().getStringRef();
- auto dialectPrefix = opName.split('.').first;
- if (auto *dialect = getContext()->getRegisteredDialect(dialectPrefix))
- return dialect->constantFoldHook(this, operands, results);
-
- return failure();
-}
-
/// Attempt to fold this operation using the Op's registered foldHook.
-LogicalResult Operation::fold(SmallVectorImpl<Value *> &results) {
- if (auto *abstractOp = getAbstractOperation()) {
- // If we have a registered operation definition matching this one, use it to
- // try to constant fold the operation.
- if (succeeded(abstractOp->foldHook(this, results)))
- return success();
+LogicalResult Operation::fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ // If we have a registered operation definition matching this one, use it to
+ // try to constant fold the operation.
+ auto *abstractOp = getAbstractOperation();
+ if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results)))
+ return success();
+
+ // Otherwise, fall back on the dialect hook to handle it.
+ Dialect *dialect;
+ if (abstractOp) {
+ dialect = &abstractOp->dialect;
+ } else {
+ // If this operation hasn't been registered, lookup the parent dialect.
+ auto opName = getName().getStringRef();
+ auto dialectPrefix = opName.split('.').first;
+ if (!(dialect = getContext()->getRegisteredDialect(dialectPrefix)))
+ return failure();
}
- return failure();
+
+ SmallVector<Attribute, 8> constants;
+ if (failed(dialect->constantFoldHook(this, operands, constants)))
+ return failure();
+ results.assign(constants.begin(), constants.end());
+ return success();
}
/// Emit an error with the op name prefixed, like "'dim' op " which is
// AddFOp
//===----------------------------------------------------------------------===//
-Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
return constFoldBinaryOp<FloatAttr>(
operands, [](APFloat a, APFloat b) { return a + b; });
}
// AddIOp
//===----------------------------------------------------------------------===//
-Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a + b; });
-}
-
-Value *AddIOp::fold() {
+OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
/// addi(x, 0) -> x
- if (matchPattern(getOperand(1), m_Zero()))
- return getOperand(0);
+ if (matchPattern(rhs(), m_Zero()))
+ return lhs();
- return nullptr;
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a + b; });
}
//===----------------------------------------------------------------------===//
}
// Constant folding hook for comparisons.
-Attribute CmpIOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpi takes two arguments");
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
- return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
+ return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
}
//===----------------------------------------------------------------------===//
}
// Constant folding hook for comparisons.
-Attribute CmpFOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpf takes two arguments");
auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
- return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
+ return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
}
//===----------------------------------------------------------------------===//
"requires a result type that aligns with the 'value' attribute");
}
-Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
return getValue();
}
return success();
}
-Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
// Constant fold dim when the size along the index referred to is a constant.
auto opType = getOperand()->getType();
int64_t indexSize = -1;
indexSize = memrefType.getShape()[getIndex()];
if (indexSize >= 0)
- return IntegerAttr::get(IndexType::get(context), indexSize);
+ return IntegerAttr::get(IndexType::get(getContext()), indexSize);
- return nullptr;
+ return {};
}
//===----------------------------------------------------------------------===//
// DivISOp
//===----------------------------------------------------------------------===//
-Attribute DivISOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult DivISOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "binary operation takes two operands");
- (void)context;
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
return {};
// Don't fold if it requires division by zero.
- if (rhs.getValue().isNullValue()) {
+ if (rhs.getValue().isNullValue())
return {};
- }
// Don't fold if it would overflow.
bool overflow;
// DivIUOp
//===----------------------------------------------------------------------===//
-Attribute DivIUOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult DivIUOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "binary operation takes two operands");
- (void)context;
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
return success();
}
-Attribute ExtractElementOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
assert(!operands.empty() && "extract_element takes atleast one operand");
// The aggregate operand must be a known constant.
Attribute aggregate = operands.front();
if (!aggregate)
- return Attribute();
+ return {};
// If this is a splat elements attribute, simply return the value. All of the
// elements of a splat attribute are the same.
SmallVector<uint64_t, 8> indices;
for (Attribute indice : llvm::drop_begin(operands, 1)) {
if (!indice || !indice.isa<IntegerAttr>())
- return Attribute();
+ return {};
indices.push_back(indice.cast<IntegerAttr>().getInt());
}
// If this is an elements attribute, query the value at the given indices.
if (auto elementsAttr = aggregate.dyn_cast<ElementsAttr>())
return elementsAttr.getValue(indices);
- return Attribute();
+ return {};
}
//===----------------------------------------------------------------------===//
return true;
}
-Value *MemRefCastOp::fold() { return impl::foldCastOp(*this); }
+OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
+ return impl::foldCastOp(*this);
+}
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
-Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
return constFoldBinaryOp<FloatAttr>(
operands, [](APFloat a, APFloat b) { return a * b; });
}
// MulIOp
//===----------------------------------------------------------------------===//
-Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
- // TODO: Handle the overflow case.
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a * b; });
-}
-
-Value *MulIOp::fold() {
+OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
/// muli(x, 0) -> 0
- if (matchPattern(getOperand(1), m_Zero()))
- return getOperand(1);
+ if (matchPattern(rhs(), m_Zero()))
+ return rhs();
/// muli(x, 1) -> x
- if (matchPattern(getOperand(1), m_One()))
+ if (matchPattern(rhs(), m_One()))
return getOperand(0);
- return nullptr;
+
+ // TODO: Handle the overflow case.
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a * b; });
}
//===----------------------------------------------------------------------===//
// RemISOp
//===----------------------------------------------------------------------===//
-Attribute RemISOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult RemISOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "remis takes two operands");
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
APInt(rhs.getValue().getBitWidth(), 0));
// Don't fold if it requires division by zero.
- if (rhs.getValue().isNullValue()) {
+ if (rhs.getValue().isNullValue())
return {};
- }
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
if (!lhs)
// RemIUOp
//===----------------------------------------------------------------------===//
-Attribute RemIUOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult RemIUOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "remiu takes two operands");
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
APInt(rhs.getValue().getBitWidth(), 0));
// Don't fold if it requires division by zero.
- if (rhs.getValue().isNullValue()) {
+ if (rhs.getValue().isNullValue())
return {};
- }
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
if (!lhs)
return success();
}
-Value *SelectOp::fold() {
+OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
auto *condition = getCondition();
// select true, %0, %1 => %0
// SubFOp
//===----------------------------------------------------------------------===//
-Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
return constFoldBinaryOp<FloatAttr>(
operands, [](APFloat a, APFloat b) { return a - b; });
}
// SubIOp
//===----------------------------------------------------------------------===//
-Attribute SubIOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
+OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
+ // subi(x,x) -> 0
+ if (getOperand(0) == getOperand(1))
+ return Builder(getContext()).getZeroAttr(getType());
+
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a - b; });
}
-namespace {
-/// subi(x,x) -> 0
-///
-struct SimplifyXMinusX : public RewritePattern {
- SimplifyXMinusX(MLIRContext *context)
- : RewritePattern(SubIOp::getOperationName(), 1, context) {}
-
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- auto subi = cast<SubIOp>(op);
- if (subi.getOperand(0) != subi.getOperand(1))
- return matchFailure();
-
- rewriter.replaceOpWithNewOp<ConstantOp>(
- op, subi.getType(), rewriter.getZeroAttr(subi.getType()));
- return matchSuccess();
- }
-};
-} // end anonymous namespace.
-
-void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyXMinusX>(context));
-}
-
//===----------------------------------------------------------------------===//
// AndOp
//===----------------------------------------------------------------------===//
-Attribute AndOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a & b; });
-}
-
-Value *AndOp::fold() {
+OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
/// and(x, 0) -> 0
if (matchPattern(rhs(), m_Zero()))
return rhs();
if (lhs() == rhs())
return rhs();
- return nullptr;
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a & b; });
}
//===----------------------------------------------------------------------===//
// OrOp
//===----------------------------------------------------------------------===//
-Attribute OrOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a | b; });
-}
-
-Value *OrOp::fold() {
+OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
/// or(x, 0) -> x
if (matchPattern(rhs(), m_Zero()))
return lhs();
if (lhs() == rhs())
return rhs();
- return nullptr;
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a | b; });
}
//===----------------------------------------------------------------------===//
// XOrOp
//===----------------------------------------------------------------------===//
-Attribute XOrOp::constantFold(ArrayRef<Attribute> operands,
- MLIRContext *context) {
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a ^ b; });
-}
-
-Value *XOrOp::fold() {
+OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
/// xor(x, 0) -> x
if (matchPattern(rhs(), m_Zero()))
return lhs();
+ /// xor(x,x) -> 0
+ if (lhs() == rhs())
+ return Builder(getContext()).getZeroAttr(getType());
- return nullptr;
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a ^ b; });
}
-namespace {
-/// xor(x,x) -> 0
-///
-struct SimplifyXXOrX : public RewritePattern {
- SimplifyXXOrX(MLIRContext *context)
- : RewritePattern(XOrOp::getOperationName(), 1, context) {}
-
- PatternMatchResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- auto xorOp = cast<XOrOp>(op);
- if (xorOp.lhs() != xorOp.rhs())
- return matchFailure();
-
- rewriter.replaceOpWithNewOp<ConstantOp>(
- op, xorOp.getType(), rewriter.getZeroAttr(xorOp.getType()));
- return matchSuccess();
- }
-};
-} // end anonymous namespace.
-
-void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyXXOrX>(context));
-}
//===----------------------------------------------------------------------===//
// TensorCastOp
//===----------------------------------------------------------------------===//
return true;
}
-Value *TensorCastOp::fold() { return impl::foldCastOp(*this); }
+OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
+ return impl::foldCastOp(*this);
+}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
SimplifyAffineStructures.cpp
StripDebugInfo.cpp
TestConstantFold.cpp
- Utils/ConstantFoldUtils.cpp
+ Utils/FoldUtils.cpp
Utils/GreedyPatternRewriteDriver.cpp
Utils/LoopUtils.cpp
Utils/Utils.cpp
#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
-#include "mlir/Transforms/ConstantFoldUtils.h"
+#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
struct TestConstantFold : public FunctionPass<TestConstantFold> {
// All constants in the function post folding.
SmallVector<Operation *, 8> existingConstants;
- // Operations that were folded and that need to be erased.
- std::vector<Operation *> opsToErase;
- void foldOperation(Operation *op, ConstantFoldHelper &helper);
+ void foldOperation(Operation *op, FoldHelper &helper);
void runOnFunction() override;
};
} // end anonymous namespace
-void TestConstantFold::foldOperation(Operation *op,
- ConstantFoldHelper &helper) {
+void TestConstantFold::foldOperation(Operation *op, FoldHelper &helper) {
// Attempt to fold the specified operation, including handling unused or
// duplicated constants.
- if (helper.tryToConstantFold(op)) {
- opsToErase.push_back(op);
- }
+ if (succeeded(helper.tryToFold(op)))
+ return;
+
// If this op is a constant that are used and cannot be de-duplicated,
// remember it for cleanup later.
- else if (auto constant = dyn_cast<ConstantOp>(op)) {
+ if (auto constant = dyn_cast<ConstantOp>(op))
existingConstants.push_back(op);
- }
}
// For now, we do a simple top-down pass over a function folding constants. We
// branches, or anything else fancy.
void TestConstantFold::runOnFunction() {
existingConstants.clear();
- opsToErase.clear();
auto &f = getFunction();
- ConstantFoldHelper helper(&f);
+ FoldHelper helper(&f);
// Collect and fold the operations within the function.
SmallVector<Operation *, 8> ops;
for (Operation *op : llvm::reverse(ops))
foldOperation(op, helper);
- // At this point, these operations are dead, remove them.
- for (auto *op : opsToErase) {
- assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
- op->erase();
- }
-
// By the time we are done, we may have simplified a bunch of code, leaving
// around dead constants. Check for them now and remove them.
for (auto *cst : existingConstants) {
-//===- ConstantFoldUtils.cpp ---- Constant Fold Utilities -----------------===//
+//===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// limitations under the License.
// =============================================================================
//
-// This file defines various constant fold utilities. These utilities are
+// This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simply their logic.
//
//===----------------------------------------------------------------------===//
-#include "mlir/Transforms/ConstantFoldUtils.h"
+#include "mlir/Transforms/FoldUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
using namespace mlir;
-ConstantFoldHelper::ConstantFoldHelper(Function *f) : function(f) {}
+FoldHelper::FoldHelper(Function *f) : function(f) {}
-bool ConstantFoldHelper::tryToConstantFold(
- Operation *op, std::function<void(Operation *)> preReplaceAction) {
+LogicalResult
+FoldHelper::tryToFold(Operation *op,
+ std::function<void(Operation *)> preReplaceAction) {
assert(op->getFunction() == function &&
"cannot constant fold op from another function");
// If this constant is dead, update bookkeeping and signal the caller.
if (constant.use_empty()) {
notifyRemoval(op);
- return true;
+ op->erase();
+ return success();
}
// Otherwise, try to see if we can de-duplicate it.
return tryToUnify(op);
}
- SmallVector<Attribute, 8> operandConstants, resultConstants;
+ SmallVector<Attribute, 8> operandConstants;
+ SmallVector<OpFoldResult, 8> results;
// Check to see if any operands to the operation is constant and whether
// the operation knows how to constant fold itself.
}
// Attempt to constant fold the operation.
- if (failed(op->constantFold(operandConstants, resultConstants)))
- return false;
+ if (failed(op->fold(operandConstants, results)))
+ return failure();
// Constant folding succeeded. We will start replacing this op's uses and
// eventually erase this op. Invoke the callback provided by the caller to
if (preReplaceAction)
preReplaceAction(op);
+ // Check to see if the operation was just updated in place.
+ if (results.empty())
+ return success();
+ assert(results.size() == op->getNumResults());
+
// Create the result constants and replace the results.
FuncBuilder builder(op);
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
auto *res = op->getResult(i);
if (res->use_empty()) // Ignore dead uses.
continue;
+ assert(!results[i].isNull() && "expected valid OpFoldResult");
+
+ // Check if the result was an SSA value.
+ if (auto *repl = results[i].dyn_cast<Value *>()) {
+ if (repl != res)
+ res->replaceAllUsesWith(repl);
+ continue;
+ }
// If we already have a canonicalized version of this constant, just reuse
// it. Otherwise create a new one.
+ Attribute attrRepl = results[i].get<Attribute>();
auto &constInst =
- uniquedConstants[std::make_pair(resultConstants[i], res->getType())];
+ uniquedConstants[std::make_pair(attrRepl, res->getType())];
if (!constInst) {
// TODO: Extend to support dialect-specific constant ops.
- auto newOp = builder.create<ConstantOp>(op->getLoc(), res->getType(),
- resultConstants[i]);
+ auto newOp =
+ builder.create<ConstantOp>(op->getLoc(), res->getType(), attrRepl);
// Register to the constant map and also move up to entry block to
// guarantee dominance.
constInst = newOp.getOperation();
}
res->replaceAllUsesWith(constInst->getResult(0));
}
+ op->erase();
- return true;
+ return success();
}
-void ConstantFoldHelper::notifyRemoval(Operation *op) {
+void FoldHelper::notifyRemoval(Operation *op) {
assert(op->getFunction() == function &&
"cannot remove constant from another function");
Attribute constValue;
- matchPattern(op, m_Constant(&constValue));
- assert(constValue);
+ if (!matchPattern(op, m_Constant(&constValue)))
+ return;
// This constant is dead. keep uniquedConstants up to date.
auto it = uniquedConstants.find({constValue, op->getResult(0)->getType()});
uniquedConstants.erase(it);
}
-bool ConstantFoldHelper::tryToUnify(Operation *op) {
+LogicalResult FoldHelper::tryToUnify(Operation *op) {
Attribute constValue;
matchPattern(op, m_Constant(&constValue));
assert(constValue);
if (constInst) {
// If this constant is already our uniqued one, then leave it alone.
if (constInst == op)
- return false;
+ return failure();
// Otherwise replace this redundant constant with the uniqued one. We know
// this is safe because we move constants to the top of the function when
// they are uniqued, so we know they dominate all uses.
op->getResult(0)->replaceAllUsesWith(constInst->getResult(0));
- return true;
+ op->erase();
+ return success();
}
// If we have no entry, then we should unique this constant as the
// entry block of the function.
constInst = op;
moveConstantToEntryBlock(op);
- return false;
+ return failure();
}
-void ConstantFoldHelper::moveConstantToEntryBlock(Operation *op) {
+void FoldHelper::moveConstantToEntryBlock(Operation *op) {
// Insert at the very top of the entry block.
auto &entryBB = function->front();
op->moveBefore(&entryBB, entryBB.begin());
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/StandardOps/Ops.h"
-#include "mlir/Transforms/ConstantFoldUtils.h"
+#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
/// Perform the rewrites.
bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
Function *fn = builder.getFunction();
- ConstantFoldHelper helper(fn);
+ FoldHelper helper(fn);
bool changed = false;
int i = 0;
// If the operation has no side effects, and no users, then it is
// trivially dead - remove it.
if (op->hasNoSideEffect() && op->use_empty()) {
- // Be careful to update bookkeeping in ConstantHelper to keep
- // consistency if this is a constant op.
- if (isa<ConstantOp>(op))
- helper.notifyRemoval(op);
+ // Be careful to update bookkeeping in FoldHelper to keep consistency if
+ // this is a constant op.
+ helper.notifyRemoval(op);
op->erase();
continue;
}
// Collects all the operands and result uses of the given `op` into work
// list.
- auto collectOperandsAndUses = [this](Operation *op) {
+ originalOperands.assign(op->operand_begin(), op->operand_end());
+ auto collectOperandsAndUses = [&](Operation *op) {
// Add the operands to the worklist for visitation.
- addToWorklist(op->getOperands());
+ addToWorklist(originalOperands);
+
// Add all the users of the result to the worklist so we make sure
// to revisit them.
//
// TODO: Add a result->getUsers() iterator.
- for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
+ for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
for (auto &operand : op->getResult(i)->getUses())
addToWorklist(operand.getOwner());
- }
};
- // Try to constant fold this op.
- if (helper.tryToConstantFold(op, collectOperandsAndUses)) {
- assert(op->hasNoSideEffect() &&
- "Constant folded op with side effects?");
- op->erase();
- changed |= true;
- continue;
- }
-
- // Otherwise see if we can use the generic folder API to simplify the
- // operation.
- originalOperands.assign(op->operand_begin(), op->operand_end());
- resultValues.clear();
- if (succeeded(op->fold(resultValues))) {
- // If the result was an in-place simplification (e.g. max(x,x,y) ->
- // max(x,y)) then add the original operands to the worklist so we can
- // make sure to revisit them.
- if (resultValues.empty()) {
- // Add the operands back to the worklist as there may be more
- // canonicalization opportunities now.
- addToWorklist(originalOperands);
- } else {
- // Otherwise, the operation is simplified away completely.
- assert(resultValues.size() == op->getNumResults());
-
- // Notify that we are replacing this operation.
- notifyRootReplaced(op);
-
- // Replace the result values and erase the operation.
- for (unsigned i = 0, e = resultValues.size(); i != e; ++i) {
- auto *res = op->getResult(i);
- if (!res->use_empty())
- res->replaceAllUsesWith(resultValues[i]);
- }
-
- notifyOperationRemoved(op);
- op->erase();
- }
+ // Try to fold this op.
+ if (succeeded(helper.tryToFold(op, collectOperandsAndUses))) {
changed |= true;
continue;
}
let verifier = [{ baz }];
let hasCanonicalizer = 1;
- let hasConstantFolder = 1;
let hasFolder = 1;
let extraClassDeclaration = [{
// CHECK: void print(OpAsmPrinter *p);
// CHECK: LogicalResult verify();
// CHECK: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context);
-// CHECK: LogicalResult constantFold(ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results, MLIRContext *context);
-// CHECK: bool fold(SmallVectorImpl<Value *> &results);
+// CHECK: LogicalResult fold(ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results);
// CHECK: // Display a graph for debugging purposes.
// CHECK: void displayGraph();
// CHECK: };
void OpEmitter::genFolderDecls() {
bool hasSingleResult = op.getNumResults() == 1;
- if (def.getValueAsBit("hasConstantFolder")) {
- if (hasSingleResult) {
- const char *const params =
- "ArrayRef<Attribute> operands, MLIRContext *context";
- opClass.newMethod("Attribute", "constantFold", params, OpMethod::MP_None,
- /*declOnly=*/true);
- } else {
- const char *const params =
- "ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results, "
- "MLIRContext *context";
- opClass.newMethod("LogicalResult", "constantFold", params,
- OpMethod::MP_None, /*declOnly=*/true);
- }
- }
-
if (def.getValueAsBit("hasFolder")) {
if (hasSingleResult) {
- opClass.newMethod("Value *", "fold", /*params=*/"", OpMethod::MP_None,
+ const char *const params = "ArrayRef<Attribute> operands";
+ opClass.newMethod("OpFoldResult", "fold", params, OpMethod::MP_None,
/*declOnly=*/true);
} else {
- opClass.newMethod("bool", "fold", "SmallVectorImpl<Value *> &results",
- OpMethod::MP_None,
+ const char *const params = "ArrayRef<Attribute> operands, "
+ "SmallVectorImpl<OpFoldResult> &results";
+ opClass.newMethod("LogicalResult", "fold", params, OpMethod::MP_None,
/*declOnly=*/true);
}
}