[mlir] Expose a function to get vector::CombiningKind from Operation*.
authorAlexander Belyaev <pifon@google.com>
Fri, 14 Jan 2022 07:23:27 +0000 (08:23 +0100)
committerAlexander Belyaev <pifon@google.com>
Fri, 14 Jan 2022 07:28:18 +0000 (08:28 +0100)
Differential Revision: https://reviews.llvm.org/D117283

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

index 739d1b6..8144744 100644 (file)
@@ -920,6 +920,9 @@ private:
   LinalgTransformationFilter filter;
 };
 
+/// Return vector::CombiningKind for the given op.
+llvm::Optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
+
 //===----------------------------------------------------------------------===//
 // Transformation and lowering options exposed as auxiliary structs.
 //===----------------------------------------------------------------------===//
index f78e179..86eaed9 100644 (file)
@@ -109,25 +109,25 @@ struct VectorizationResult {
   Operation *newOp;
 };
 
-static llvm::Optional<vector::CombiningKind>
-getKindForOp(Operation *reductionOp) {
-  if (!reductionOp)
+llvm::Optional<vector::CombiningKind>
+mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
+  using ::mlir::vector::CombiningKind;
+
+  if (!combinerOp)
     return llvm::None;
-  return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
-             reductionOp)
+  return llvm::TypeSwitch<Operation *, llvm::Optional<CombiningKind>>(
+             combinerOp)
       .Case<arith::AddIOp, arith::AddFOp>(
-          [&](auto op) { return vector::CombiningKind::ADD; })
-      .Case<arith::AndIOp>([&](auto op) { return vector::CombiningKind::AND; })
-      .Case<arith::MaxSIOp>(
-          [&](auto op) { return vector::CombiningKind::MAXSI; })
-      .Case<arith::MaxFOp>([&](auto op) { return vector::CombiningKind::MAXF; })
-      .Case<arith::MinSIOp>(
-          [&](auto op) { return vector::CombiningKind::MINSI; })
-      .Case<arith::MinFOp>([&](auto op) { return vector::CombiningKind::MINF; })
+          [&](auto op) { return CombiningKind::ADD; })
+      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
+      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
+      .Case<arith::MaxFOp>([&](auto op) { return CombiningKind::MAXF; })
+      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
+      .Case<arith::MinFOp>([&](auto op) { return CombiningKind::MINF; })
       .Case<arith::MulIOp, arith::MulFOp>(
-          [&](auto op) { return vector::CombiningKind::MUL; })
-      .Case<arith::OrIOp>([&](auto op) { return vector::CombiningKind::OR; })
-      .Case<arith::XOrIOp>([&](auto op) { return vector::CombiningKind::XOR; })
+          [&](auto op) { return CombiningKind::MUL; })
+      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
+      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
       .Default([&](auto op) { return llvm::None; });
 }
 
@@ -174,7 +174,7 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
 static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                  Value valueToReduce,
                                  const SmallVector<bool> &reductionMask) {
-  auto maybeKind = getKindForOp(reduceOp);
+  auto maybeKind = getCombinerOpKind(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
   return b.create<vector::MultiDimReductionOp>(
       reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
@@ -589,7 +589,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   }
   for (OpOperand *opOperand : op.getOutputOperands()) {
     Operation *reduceOp = matchLinalgReduction(opOperand);
-    if (!reduceOp || !getKindForOp(reduceOp)) {
+    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
       LDBG("reduction precondition failed: reduction detection failed");
       return failure();
     }
@@ -1458,10 +1458,10 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     if (!reduceOp)
       return;
     llvm::Optional<vector::CombiningKind> maybeKind;
-    maybeKind = getKindForOp(reduceOp);
+    maybeKind = getCombinerOpKind(reduceOp);
     if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
       return;
-    maybeKind = getKindForOp(&(linalgOp->getRegion(0).front().front()));
+    maybeKind = getCombinerOpKind(&(linalgOp->getRegion(0).front().front()));
     if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
       return;