[mlir] Delete dup code and use unified methods.
authorHanhan Wang <hanchung@google.com>
Fri, 21 Oct 2022 23:51:11 +0000 (16:51 -0700)
committerHanhan Wang <hanchung@google.com>
Fri, 21 Oct 2022 23:51:44 +0000 (16:51 -0700)
The foldMemRefCast method is defined in memref namespace; the
foldTensorCast method is defined in tensor namespace. This revision
deletes the dup code and use the unified methods.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D136379

mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

index 1d05601..01cc8b3 100644 (file)
@@ -1326,26 +1326,6 @@ void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 //===----------------------------------------------------------------------===//
-// Common canonicalization pattern support logic
-//===----------------------------------------------------------------------===//
-
-/// This is a common class used for patterns of the form
-/// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
-/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
-  bool folded = false;
-  for (OpOperand &operand : op->getOpOperands()) {
-    auto cast = operand.get().getDefiningOp<memref::CastOp>();
-    if (cast && operand.get() != ignore &&
-        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
-      operand.set(cast.getOperand());
-      folded = true;
-    }
-  }
-  return success(folded);
-}
-
-//===----------------------------------------------------------------------===//
 // AffineDmaStartOp
 //===----------------------------------------------------------------------===//
 
@@ -1511,7 +1491,7 @@ LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                      SmallVectorImpl<OpFoldResult> &results) {
   /// dma_start(memrefcast) -> dma_start
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1589,7 +1569,7 @@ LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
   /// dma_wait(memrefcast) -> dma_wait
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2821,7 +2801,7 @@ void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
   /// load(memrefcast) -> load
-  if (succeeded(foldMemRefCast(*this)))
+  if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
 
   // Fold load from a global constant memref.
@@ -2939,7 +2919,7 @@ void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
 LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
                                   SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
-  return foldMemRefCast(*this, getValueToStore());
+  return memref::foldMemRefCast(*this, getValueToStore());
 }
 
 //===----------------------------------------------------------------------===//
@@ -3392,7 +3372,7 @@ void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
 LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                                      SmallVectorImpl<OpFoldResult> &results) {
   /// prefetch(memrefcast) -> prefetch
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
index bfdcedf..4f5cfb6 100644 (file)
@@ -1266,29 +1266,14 @@ LogicalResult SubgroupMmaComputeOp::verify() {
   return success();
 }
 
-/// This is a common class used for patterns of the form
-/// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
-/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
-  bool folded = false;
-  for (OpOperand &operand : op->getOpOperands()) {
-    auto cast = operand.get().getDefiningOp<mlir::memref::CastOp>();
-    if (cast) {
-      operand.set(cast.getOperand());
-      folded = true;
-    }
-  }
-  return success(folded);
-}
-
 LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<::mlir::OpFoldResult> &results) {
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<::mlir::OpFoldResult> &results) {
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
index 82e5024..5d6dd37 100644 (file)
@@ -254,23 +254,6 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
   // Region is elided.
 }
 
-/// This is a common class used for patterns of the form
-/// ```
-///    someop(memrefcast(%src)) -> someop(%src)
-/// ```
-/// It folds the source of the memref.cast into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
-  bool folded = false;
-  for (OpOperand &operand : op->getOpOperands()) {
-    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
-    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
-      operand.set(castOp.getOperand());
-      folded = true;
-    }
-  }
-  return success(folded);
-}
-
 //===----------------------------------------------------------------------===//
 // Region builder helper.
 // TODO: Move this to a utility library.
@@ -1290,7 +1273,7 @@ void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
                               SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
index 230c585..c25bfc6 100644 (file)
@@ -3085,35 +3085,6 @@ LogicalResult TransferReadOp::verify() {
                               [&](Twine t) { return emitOpError(t); });
 }
 
-/// This is a common class used for patterns of the form
-/// ```
-///    someop(memrefcast) -> someop
-/// ```
-/// It folds the source of the memref.cast into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
-  bool folded = false;
-  for (OpOperand &operand : op->getOpOperands()) {
-    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
-    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
-      operand.set(castOp.getOperand());
-      folded = true;
-    }
-  }
-  return success(folded);
-}
-
-static LogicalResult foldTensorCast(Operation *op) {
-  bool folded = false;
-  for (OpOperand &operand : op->getOpOperands()) {
-    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
-    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
-      operand.set(castOp.getOperand());
-      folded = true;
-    }
-  }
-  return success(folded);
-}
-
 template <typename TransferOp>
 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
   // TODO: support more aggressive createOrFold on:
@@ -3198,9 +3169,9 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
   /// transfer_read(memrefcast) -> transfer_read
   if (succeeded(foldTransferInBoundsAttribute(*this)))
     return getResult();
-  if (succeeded(foldMemRefCast(*this)))
+  if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
-  if (succeeded(foldTensorCast(*this)))
+  if (succeeded(tensor::foldTensorCast(*this)))
     return getResult();
   return OpFoldResult();
 }
@@ -3648,7 +3619,7 @@ LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
     return success();
   if (succeeded(foldTransferInBoundsAttribute(*this)))
     return success();
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
@@ -3948,7 +3919,7 @@ LogicalResult vector::LoadOp::verify() {
 }
 
 OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
-  if (succeeded(foldMemRefCast(*this)))
+  if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
   return OpFoldResult();
 }
@@ -3982,7 +3953,7 @@ LogicalResult vector::StoreOp::verify() {
 
 LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<OpFoldResult> &results) {
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4034,7 +4005,7 @@ void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
-  if (succeeded(foldMemRefCast(*this)))
+  if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
   return OpFoldResult();
 }
@@ -4086,7 +4057,7 @@ void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
                                   SmallVectorImpl<OpFoldResult> &results) {
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 
 //===----------------------------------------------------------------------===//
index d8e10ef..d531d8b 100644 (file)
@@ -655,7 +655,7 @@ ArrayAttr {0}::getIndexingMaps() {{
 const char structuredOpFoldersFormat[] = R"FMT(
 LogicalResult {0}::fold(ArrayRef<Attribute>,
                         SmallVectorImpl<OpFoldResult> &) {{
-  return foldMemRefCast(*this);
+  return memref::foldMemRefCast(*this);
 }
 void {0}::getEffects(SmallVectorImpl<
     SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{