results.add<SimplifyAffineOp<AffineApplyOp>>(context);
}
-//===----------------------------------------------------------------------===//
-// Common canonicalization pattern support logic
-//===----------------------------------------------------------------------===//
-
-/// This is a common class used for patterns of the form
-/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
-/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
- bool folded = false;
- for (OpOperand &operand : op->getOpOperands()) {
- auto cast = operand.get().getDefiningOp<memref::CastOp>();
- if (cast && operand.get() != ignore &&
- !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
- operand.set(cast.getOperand());
- folded = true;
- }
- }
- return success(folded);
-}
-
//===----------------------------------------------------------------------===//
// AffineDmaStartOp
//===----------------------------------------------------------------------===//
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// dma_start(memrefcast) -> dma_start
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// dma_wait(memrefcast) -> dma_wait
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
/// load(memrefcast) -> load
- if (succeeded(foldMemRefCast(*this)))
+ if (succeeded(memref::foldMemRefCast(*this)))
return getResult();
// Fold load from a global constant memref.
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
- return foldMemRefCast(*this, getValueToStore());
+ return memref::foldMemRefCast(*this, getValueToStore());
}
//===----------------------------------------------------------------------===//
LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// prefetch(memrefcast) -> prefetch
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
return success();
}
-/// This is a common class used for patterns of the form
-/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
-/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
- bool folded = false;
- for (OpOperand &operand : op->getOpOperands()) {
- auto cast = operand.get().getDefiningOp<mlir::memref::CastOp>();
- if (cast) {
- operand.set(cast.getOperand());
- folded = true;
- }
- }
- return success(folded);
-}
-
LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<::mlir::OpFoldResult> &results) {
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<::mlir::OpFoldResult> &results) {
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// Region is elided.
}
-/// This is a common class used for patterns of the form
-/// ```
-/// someop(memrefcast(%src)) -> someop(%src)
-/// ```
-/// It folds the source of the memref.cast into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
- bool folded = false;
- for (OpOperand &operand : op->getOpOperands()) {
- auto castOp = operand.get().getDefiningOp<memref::CastOp>();
- if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
- operand.set(castOp.getOperand());
- folded = true;
- }
- }
- return success(folded);
-}
-
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
LogicalResult GenericOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
[&](Twine t) { return emitOpError(t); });
}
-/// This is a common class used for patterns of the form
-/// ```
-/// someop(memrefcast) -> someop
-/// ```
-/// It folds the source of the memref.cast into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
- bool folded = false;
- for (OpOperand &operand : op->getOpOperands()) {
- auto castOp = operand.get().getDefiningOp<memref::CastOp>();
- if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
- operand.set(castOp.getOperand());
- folded = true;
- }
- }
- return success(folded);
-}
-
-static LogicalResult foldTensorCast(Operation *op) {
- bool folded = false;
- for (OpOperand &operand : op->getOpOperands()) {
- auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
- if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
- operand.set(castOp.getOperand());
- folded = true;
- }
- }
- return success(folded);
-}
-
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
// TODO: support more aggressive createOrFold on:
/// transfer_read(memrefcast) -> transfer_read
if (succeeded(foldTransferInBoundsAttribute(*this)))
return getResult();
- if (succeeded(foldMemRefCast(*this)))
+ if (succeeded(memref::foldMemRefCast(*this)))
return getResult();
- if (succeeded(foldTensorCast(*this)))
+ if (succeeded(tensor::foldTensorCast(*this)))
return getResult();
return OpFoldResult();
}
return success();
if (succeeded(foldTransferInBoundsAttribute(*this)))
return success();
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
}
OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
- if (succeeded(foldMemRefCast(*this)))
+ if (succeeded(memref::foldMemRefCast(*this)))
return getResult();
return OpFoldResult();
}
LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
}
OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
- if (succeeded(foldMemRefCast(*this)))
+ if (succeeded(memref::foldMemRefCast(*this)))
return getResult();
return OpFoldResult();
}
LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
const char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{
- return foldMemRefCast(*this);
+ return memref::foldMemRefCast(*this);
}
void {0}::getEffects(SmallVectorImpl<
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{