PatternRewriter &rewriter) const override;
};
+ /// Helper struct to return the results of `substituteMin`.
+struct AffineMapAndOperands {
+ AffineMap map;
+ SmallVector<Value> dims;
+ SmallVector<Value> symbols;
+};
+/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
+/// induction variables by new expressions involving the lower or upper bound:
+/// - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
+/// replaced by the loop upper bound.
+/// - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
+/// replaced by the loop lower bound.
+/// All loop induction variables are iteratively replaced, unless a
+/// `substituteOperation` hook is passed to more finely determine which
+/// operations are substituted.
+/// This is used as an intermediate step in computing bounding boxes and
+/// canonicalize AffineMinOps. All dim and symbol operands are assumed to have
+/// positive values (positive orthant assumptions).
+/// Return a new AffineMap, dims and symbols that have been canonicalized and
+/// simplified.
+AffineMapAndOperands substituteMin(
+ AffineMinOp affineMinOp,
+ llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
+
/// Converts Convolution op into vector contraction.
///
/// Conversion expects ConvOp to have dimensions marked in the *mask* as
/// Traverse the `dims` and substitute known min or max expressions in place of
/// induction variables in `exprs`.
-static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
- SmallVectorImpl<Value> &symbols) {
+static AffineMap substitute(
+ AffineMap map, SmallVectorImpl<Value> &dims,
+ SmallVectorImpl<Value> &symbols,
+ llvm::function_ref<bool(Operation *)> substituteOperation = nullptr) {
auto exprs = llvm::to_vector<4>(map.getResults());
for (AffineExpr &expr : exprs) {
bool substituted = true;
LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
AffineExpr substitutedExpr;
if (auto forOp = scf::getForInductionVarOwner(dim))
- substitutedExpr = substituteLoopInExpr(
- expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
- forOp.step(), dims, symbols);
+ if (!substituteOperation || substituteOperation(forOp))
+ substitutedExpr = substituteLoopInExpr(
+ expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
+ forOp.step(), dims, symbols);
if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
- for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
- ++idx)
- substitutedExpr = substituteLoopInExpr(
- expr, dimExpr, parallelForOp.lowerBound()[idx],
- parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
- dims, symbols);
+ if (!substituteOperation || substituteOperation(parallelForOp))
+ for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
+ ++idx)
+ substitutedExpr = substituteLoopInExpr(
+ expr, dimExpr, parallelForOp.lowerBound()[idx],
+ parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
+ dims, symbols);
if (!substitutedExpr)
continue;
exprs.front().getContext());
LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n");
+ LLVM_DEBUG(DBGS() << "Operands:\n");
+ for (Value v : operands)
+ LLVM_DEBUG(DBGS() << v << "\n");
// Pull in affine.apply operations and compose them fully into the
// result.
return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}
+/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
+/// induction variables by new expressions involving the lower or upper bound:
+/// - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
+/// replaced by the loop upper bound.
+/// - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
+/// replaced by the loop lower bound.
+/// All loop induction variables are iteratively replaced, unless a
+/// `substituteOperation` hook is passed to more finely determine which
+/// operations are substituted.
+/// This is used as an intermediate step in computing bounding boxes and
+/// canonicalize AffineMinOps. All dim and symbol operands are assumed to have
+/// positive values (positive orthant assumptions).
+/// Return a new AffineMap, dims and symbols that have been canonicalized and
+/// simplified.
+AffineMapAndOperands mlir::linalg::substituteMin(
+ AffineMinOp affineMinOp,
+ llvm::function_ref<bool(Operation *)> substituteOperation) {
+ AffineMapAndOperands res{affineMinOp.getAffineMap(),
+ SmallVector<Value>(affineMinOp.getDimOperands()),
+ SmallVector<Value>(affineMinOp.getSymbolOperands())};
+ res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
+ substituteOperation);
+ return res;
+}
+
LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
AffineMinOp minOp, PatternRewriter &rewriter) const {
LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
<< "\n");
- SmallVector<Value, 4> dims(minOp.getDimOperands()),
- symbols(minOp.getSymbolOperands());
- AffineMap map = substitute(minOp.getAffineMap(), dims, symbols);
+ auto affineMapAndOperands = substituteMin(minOp);
+ AffineMap map = affineMapAndOperands.map;
LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");
rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
} else {
auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
- SmallVector<Value, 4> resultOperands = dims;
- resultOperands.append(symbols.begin(), symbols.end());
+ SmallVector<Value> resultOperands = affineMapAndOperands.dims;
+ llvm::append_range(resultOperands, affineMapAndOperands.symbols);
canonicalizeMapAndOperands(&resultMap, &resultOperands);
resultMap = simplifyAffineMap(resultMap);
rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,