}
};
+template <typename T>
+struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
+ using OpRewritePattern<T>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(T affineOp,
+ PatternRewriter &rewriter) const override {
+ if (affineOp.map().getNumResults() != 1)
+ return failure();
+ rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.map(),
+ affineOp.getOperands());
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// AffineMinOp
//===----------------------------------------------------------------------===//
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<DeduplicateAffineMinMaxExpressions<AffineMinOp>,
+ patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
+ DeduplicateAffineMinMaxExpressions<AffineMinOp>,
MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>>(
context);
}
void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
+ patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
+ DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>>(
context);
}
return %1: index
}
-
// -----
// CHECK-LABEL: func @dont_merge_affine_max_if_not_single_sym
return
}
+// -----
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 + 16)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 * 4)>
+
+// CHECK: func @canonicalize_single_min_max
+// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index)
+func @canonicalize_single_min_max(%i0: index, %i1: index) -> (index, index) {
+ // CHECK-NOT: affine.min
+ // CHECK-NEXT: affine.apply #[[$MAP0]]()[%[[I0]]]
+ %0 = affine.min affine_map<()[s0] -> (s0 + 16)> ()[%i0]
+
+ // CHECK-NOT: affine.max
+ // CHECK-NEXT: affine.apply #[[$MAP1]]()[%[[I1]]]
+ %1 = affine.min affine_map<()[s0] -> (s0 * 4)> ()[%i1]
+
+ return %0, %1: index, index
+}