Currently canonicalizations of a store and a cast try to fold all casts into the store.
In the case where the operand being stored is itself a cast, this is illegal as the type of the value being stored
will change. This PR fixes this by not checking the value for folding with a cast.
Depends on https://reviews.llvm.org/
D103828
Differential Revision: https://reviews.llvm.org/
D103829
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
+static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto cast = operand.get().getDefiningOp<memref::CastOp>();
- if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
+ if (cast && operand.get() != ignore &&
+ !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
operand.set(cast.getOperand());
folded = true;
}
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
- return foldMemRefCast(*this);
+ return foldMemRefCast(*this, getValueToStore());
}
//===----------------------------------------------------------------------===//
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
+static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto cast = operand.get().getDefiningOp<CastOp>();
- if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
+ if (cast && operand.get() != inner &&
+ !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
operand.set(cast.getOperand());
folded = true;
}
LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
- return foldMemRefCast(*this);
+ return foldMemRefCast(*this, getValueToStore());
}
//===----------------------------------------------------------------------===//
}
return
}
+
+// -----
+
+// CHECK-LABEL: func @no_fold_of_store
+// CHECK: %[[cst:.+]] = memref.cast %arg
+// CHECK: affine.store %[[cst]]
+func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
+ %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
+ affine.store %0, %holder[] : memref<memref<?xi8>>
+ return
+}
+
return %1 : index
}
+// -----
+
+// CHECK-LABEL: func @no_fold_of_store
+// CHECK: %[[cst:.+]] = memref.cast %arg
+// CHECK: memref.store %[[cst]]
+func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
+ %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
+ memref.store %0, %holder[] : memref<memref<?xi8>>
+ return
+}