Operation *newOp;
};
-static llvm::Optional<vector::CombiningKind>
-getKindForOp(Operation *reductionOp) {
- if (!reductionOp)
+llvm::Optional<vector::CombiningKind>
+mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
+ using ::mlir::vector::CombiningKind;
+
+ if (!combinerOp)
return llvm::None;
- return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
- reductionOp)
+ return llvm::TypeSwitch<Operation *, llvm::Optional<CombiningKind>>(
+ combinerOp)
.Case<arith::AddIOp, arith::AddFOp>(
- [&](auto op) { return vector::CombiningKind::ADD; })
- .Case<arith::AndIOp>([&](auto op) { return vector::CombiningKind::AND; })
- .Case<arith::MaxSIOp>(
- [&](auto op) { return vector::CombiningKind::MAXSI; })
- .Case<arith::MaxFOp>([&](auto op) { return vector::CombiningKind::MAXF; })
- .Case<arith::MinSIOp>(
- [&](auto op) { return vector::CombiningKind::MINSI; })
- .Case<arith::MinFOp>([&](auto op) { return vector::CombiningKind::MINF; })
+ [&](auto op) { return CombiningKind::ADD; })
+ .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
+ .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
+ .Case<arith::MaxFOp>([&](auto op) { return CombiningKind::MAXF; })
+ .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
+ .Case<arith::MinFOp>([&](auto op) { return CombiningKind::MINF; })
.Case<arith::MulIOp, arith::MulFOp>(
- [&](auto op) { return vector::CombiningKind::MUL; })
- .Case<arith::OrIOp>([&](auto op) { return vector::CombiningKind::OR; })
- .Case<arith::XOrIOp>([&](auto op) { return vector::CombiningKind::XOR; })
+ [&](auto op) { return CombiningKind::MUL; })
+ .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
+ .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
.Default([&](auto op) { return llvm::None; });
}
static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
Value valueToReduce,
const SmallVector<bool> &reductionMask) {
- auto maybeKind = getKindForOp(reduceOp);
+ auto maybeKind = getCombinerOpKind(reduceOp);
assert(maybeKind && "Failed precondition: could not get reduction kind");
return b.create<vector::MultiDimReductionOp>(
reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
}
for (OpOperand *opOperand : op.getOutputOperands()) {
Operation *reduceOp = matchLinalgReduction(opOperand);
- if (!reduceOp || !getKindForOp(reduceOp)) {
+ if (!reduceOp || !getCombinerOpKind(reduceOp)) {
LDBG("reduction precondition failed: reduction detection failed");
return failure();
}
if (!reduceOp)
return;
llvm::Optional<vector::CombiningKind> maybeKind;
- maybeKind = getKindForOp(reduceOp);
+ maybeKind = getCombinerOpKind(reduceOp);
if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
return;
- maybeKind = getKindForOp(&(linalgOp->getRegion(0).front().front()));
+ maybeKind = getCombinerOpKind(&(linalgOp->getRegion(0).front().front()));
if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
return;