[mlir] Create a generic reduction detection utility
authorDiego Caballero <diegocaballero@google.com>
Wed, 22 Sep 2021 17:11:45 +0000 (17:11 +0000)
committerDiego Caballero <diegocaballero@google.com>
Fri, 24 Sep 2021 20:45:59 +0000 (20:45 +0000)
This patch introduces a generic reduction detection utility that works
across different dialecs. It is mostly a generalization of the reduction
detection algorithm in Affine. The reduction detection logic in Affine,
Linalg and SCFToOpenMP have been replaced with this new generic utility.

The utility takes some basic components of the potential reduction and
returns: 1) the reduced value, and 2) a list with the combiner operations.
The logic to match reductions involving multiple combiner operations disabled
until we can properly test it.

Reviewed By: ftynse, bondhugula, nicolasvasilache, pifon2a

Differential Revision: https://reviews.llvm.org/D110303

14 files changed:
mlir/include/mlir/Analysis/LoopAnalysis.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/IR/Diagnostics.h
mlir/lib/Analysis/AffineAnalysis.cpp
mlir/lib/Analysis/LoopAnalysis.cpp
mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/IR/Diagnostics.cpp
mlir/test/Analysis/test-match-reduction.mlir [new file with mode: 0644]
mlir/test/lib/Analysis/CMakeLists.txt
mlir/test/lib/Analysis/TestMatchReduction.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/mlir-opt.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index b9c7d5e..145019f 100644 (file)
@@ -22,6 +22,7 @@ namespace mlir {
 class AffineExpr;
 class AffineForOp;
 class AffineMap;
+class BlockArgument;
 class MemRefType;
 class NestedPattern;
 class Operation;
@@ -83,6 +84,37 @@ bool isVectorizableLoopBody(AffineForOp loop, int *memRefDim,
 // TODO: extend this to check for memory-based dependence violation when we have
 // the support.
 bool isOpwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts);
+
+/// Utility to match a generic reduction given a list of iteration-carried
+/// arguments, `iterCarriedArgs` and the position of the potential reduction
+/// argument within the list, `redPos`. If a reduction is matched, returns the
+/// reduced value and the topologically-sorted list of combiner operations
+/// involved in the reduction. Otherwise, returns a null value.
+///
+/// The matching algorithm relies on the following invariants, which are subject
+/// to change:
+///  1. The first combiner operation must be a binary operation with the
+///     iteration-carried value and the reduced value as operands.
+///  2. The iteration-carried value and combiner operations must be side
+///     effect-free, have single result and a single use.
+///  3. Combiner operations must be immediately nested in the region op
+///     performing the reduction.
+///  4. Reduction def-use chain must end in a terminator op that yields the
+///     next iteration/output values in the same order as the iteration-carried
+///     values in `iterCarriedArgs`.
+///  5. `iterCarriedArgs` must contain all the iteration-carried/output values
+///     of the region op performing the reduction.
+///
+/// This utility is generic enough to detect reductions involving multiple
+/// combiner operations (disabled for now) across multiple dialects, including
+/// Linalg, Affine and SCF. For the sake of genericity, it does not return
+/// specific enum values for the combiner operations since its goal is also
+/// matching reductions without pre-defined semantics in core MLIR. It's up to
+/// each client to make sense out of the list of combiner operations. It's also
+/// up to each client to check for additional invariants on the expected
+/// reductions not covered by this generic matching.
+Value matchReduction(ArrayRef<BlockArgument> iterCarriedArgs, unsigned redPos,
+                     SmallVectorImpl<Operation *> &combinerOps);
 } // end namespace mlir
 
 #endif // MLIR_ANALYSIS_LOOP_ANALYSIS_H
index 6aa2fb7..6977b26 100644 (file)
@@ -597,6 +597,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
+        Return the output block arguments of the region.
+      }],
+      /*retTy=*/"Block::BlockArgListType",
+      /*methodName=*/"getRegionOutputArgs",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        Block &entryBlock = this->getOperation()->getRegion(0).front();
+        return entryBlock.getArguments().take_back(this->getNumOutputs());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
         Return the `opOperand` shape or an empty vector for scalars.
       }],
       /*retTy=*/"ArrayRef<int64_t>",
index 89f50ab..7bf35c7 100644 (file)
@@ -30,6 +30,7 @@ class MLIRContext;
 class Operation;
 class OperationName;
 class Type;
+class Value;
 
 namespace detail {
 struct DiagnosticEngineImpl;
@@ -218,6 +219,9 @@ public:
     return *this << *val;
   }
 
+  /// Stream in a Value.
+  Diagnostic &operator<<(Value val);
+
   /// Stream in a range.
   template <typename T, typename ValueT = llvm::detail::ValueOfRange<T>>
   std::enable_if_t<!std::is_constructible<DiagnosticArgument, T>::value,
index 01209cf..1966305 100644 (file)
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -22,7 +23,6 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -33,29 +33,6 @@ using namespace mlir;
 
 using llvm::dbgs;
 
-/// Returns true if `value` (transitively) depends on iteration arguments of the
-/// given `forOp`.
-static bool dependsOnIterArgs(Value value, AffineForOp forOp) {
-  // Compute the backward slice of the value.
-  SetVector<Operation *> slice;
-  getBackwardSlice(value, &slice,
-                   [&](Operation *op) { return !forOp->isAncestor(op); });
-
-  // Check that none of the operands of the operations in the backward slice are
-  // loop iteration arguments, and neither is the value itself.
-  auto argRange = forOp.getRegionIterArgs();
-  llvm::SmallPtrSet<Value, 8> iterArgs(argRange.begin(), argRange.end());
-  if (iterArgs.contains(value))
-    return true;
-
-  for (Operation *op : slice)
-    for (Value operand : op->getOperands())
-      if (iterArgs.contains(operand))
-        return true;
-
-  return false;
-}
-
 /// Get the value that is being reduced by `pos`-th reduction in the loop if
 /// such a reduction can be performed by affine parallel loops. This assumes
 /// floating-point operations are commutative. On success, `kind` will be the
@@ -63,18 +40,19 @@ static bool dependsOnIterArgs(Value value, AffineForOp forOp) {
 /// reduction is not supported, returns null.
 static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
                                    AtomicRMWKind &kind) {
-  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->back());
-  Value yielded = yieldOp.operands()[pos];
-  Operation *definition = yielded.getDefiningOp();
-  if (!definition)
+  SmallVector<Operation *> combinerOps;
+  Value reducedVal =
+      matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
+  if (!reducedVal)
     return nullptr;
-  if (!forOp.getRegionIterArgs()[pos].hasOneUse())
-    return nullptr;
-  if (!yielded.hasOneUse())
+
+  // Expected only one combiner operation.
+  if (combinerOps.size() > 1)
     return nullptr;
 
+  Operation *combinerOp = combinerOps.back();
   Optional<AtomicRMWKind> maybeKind =
-      TypeSwitch<Operation *, Optional<AtomicRMWKind>>(definition)
+      TypeSwitch<Operation *, Optional<AtomicRMWKind>>(combinerOp)
           .Case<AddFOp>([](Operation *) { return AtomicRMWKind::addf; })
           .Case<MulFOp>([](Operation *) { return AtomicRMWKind::mulf; })
           .Case<AddIOp>([](Operation *) { return AtomicRMWKind::addi; })
@@ -88,14 +66,7 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
     return nullptr;
 
   kind = *maybeKind;
-  if (definition->getOperand(0) == forOp.getRegionIterArgs()[pos] &&
-      !dependsOnIterArgs(definition->getOperand(1), forOp))
-    return definition->getOperand(1);
-  if (definition->getOperand(1) == forOp.getRegionIterArgs()[pos] &&
-      !dependsOnIterArgs(definition->getOperand(0), forOp))
-    return definition->getOperand(0);
-
-  return nullptr;
+  return reducedVal;
 }
 
 /// Returns true if `forOp' is a parallel loop. If `parallelReductions` is
index 932559c..f26875b 100644 (file)
 #include "mlir/Analysis/AffineAnalysis.h"
 #include "mlir/Analysis/AffineStructures.h"
 #include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Support/MathExtras.h"
 
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallString.h"
 #include <type_traits>
 
@@ -392,3 +394,105 @@ bool mlir::isOpwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts) {
   }
   return true;
 }
+
+/// Returns true if `value` (transitively) depends on iteration-carried values
+/// of the given `ancestorOp`.
+static bool dependsOnCarriedVals(Value value,
+                                 ArrayRef<BlockArgument> iterCarriedArgs,
+                                 Operation *ancestorOp) {
+  // Compute the backward slice of the value.
+  SetVector<Operation *> slice;
+  getBackwardSlice(value, &slice,
+                   [&](Operation *op) { return !ancestorOp->isAncestor(op); });
+
+  // Check that none of the operands of the operations in the backward slice are
+  // loop iteration arguments, and neither is the value itself.
+  SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
+                                          iterCarriedArgs.end());
+  if (iterCarriedValSet.contains(value))
+    return true;
+
+  for (Operation *op : slice)
+    for (Value operand : op->getOperands())
+      if (iterCarriedValSet.contains(operand))
+        return true;
+
+  return false;
+}
+
+/// Utility to match a generic reduction given a list of iteration-carried
+/// arguments, `iterCarriedArgs` and the position of the potential reduction
+/// argument within the list, `redPos`. If a reduction is matched, returns the
+/// reduced value and the topologically-sorted list of combiner operations
+/// involved in the reduction. Otherwise, returns a null value.
+///
+/// The matching algorithm relies on the following invariants, which are subject
+/// to change:
+///  1. The first combiner operation must be a binary operation with the
+///     iteration-carried value and the reduced value as operands.
+///  2. The iteration-carried value and combiner operations must be side
+///     effect-free, have single result and a single use.
+///  3. Combiner operations must be immediately nested in the region op
+///     performing the reduction.
+///  4. Reduction def-use chain must end in a terminator op that yields the
+///     next iteration/output values in the same order as the iteration-carried
+///     values in `iterCarriedArgs`.
+///  5. `iterCarriedArgs` must contain all the iteration-carried/output values
+///     of the region op performing the reduction.
+///
+/// This utility is generic enough to detect reductions involving multiple
+/// combiner operations (disabled for now) across multiple dialects, including
+/// Linalg, Affine and SCF. For the sake of genericity, it does not return
+/// specific enum values for the combiner operations since its goal is also
+/// matching reductions without pre-defined semantics in core MLIR. It's up to
+/// each client to make sense out of the list of combiner operations. It's also
+/// up to each client to check for additional invariants on the expected
+/// reductions not covered by this generic matching.
+Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
+                           unsigned redPos,
+                           SmallVectorImpl<Operation *> &combinerOps) {
+  assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
+
+  BlockArgument redCarriedVal = iterCarriedArgs[redPos];
+  if (!redCarriedVal.hasOneUse())
+    return nullptr;
+
+  // For now, the first combiner op must be a binary op.
+  Operation *combinerOp = *redCarriedVal.getUsers().begin();
+  if (combinerOp->getNumOperands() != 2)
+    return nullptr;
+  Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
+                         ? combinerOp->getOperand(1)
+                         : combinerOp->getOperand(0);
+
+  Operation *redRegionOp =
+      iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
+  if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
+    return nullptr;
+
+  // Traverse the def-use chain starting from the first combiner op until a
+  // terminator is found. Gather all the combiner ops along the way in
+  // topological order.
+  while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
+    if (!MemoryEffectOpInterface::hasNoEffect(combinerOp) ||
+        combinerOp->getNumResults() != 1 || !combinerOp->hasOneUse() ||
+        combinerOp->getParentOp() != redRegionOp)
+      return nullptr;
+
+    combinerOps.push_back(combinerOp);
+    combinerOp = *combinerOp->getUsers().begin();
+  }
+
+  // Limit matching to single combiner op until we can properly test reductions
+  // involving multiple combiners.
+  if (combinerOps.size() != 1)
+    return nullptr;
+
+  // Check that the yielded value is in the same position as in
+  // `iterCarriedArgs`.
+  Operation *terminatorOp = combinerOp;
+  if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
+    return nullptr;
+
+  return reducedVal;
+}
index 1a75a35..ce0fd9a 100644 (file)
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRSCFToOpenMP
   Core
 
   LINK_LIBS PUBLIC
+  MLIRAnalysis
   MLIRLLVMIR
   MLIROpenMP
   MLIRSCF
index 9c6fc6f..ddd92b1 100644 (file)
@@ -13,6 +13,7 @@
 
 #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
 #include "../PassDetail.h"
+#include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -34,10 +35,21 @@ static bool matchSimpleReduction(Block &block) {
   if (block.empty() || llvm::hasSingleElement(block) ||
       std::next(block.begin(), 2) != block.end())
     return false;
-  return isa<OpTy...>(block.front()) &&
+
+  if (block.getNumArguments() != 2)
+    return false;
+
+  SmallVector<Operation *, 4> combinerOps;
+  Value reducedVal = matchReduction({block.getArguments()[1]},
+                                    /*redPos=*/0, combinerOps);
+
+  if (!reducedVal || !reducedVal.isa<BlockArgument>() ||
+      combinerOps.size() != 1)
+    return false;
+
+  return isa<OpTy...>(combinerOps[0]) &&
          isa<scf::ReduceReturnOp>(block.back()) &&
-         block.front().getOperands() == block.getArguments() &&
-         block.back().getOperand(0) == block.front().getResult(0);
+         block.front().getOperands() == block.getArguments();
 }
 
 /// Matches a block containing a select-based min/max reduction. The types of
index 913fa72..921f8ef 100644 (file)
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@@ -110,46 +111,24 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
   return VectorType::get(st.getShape(), st.getElementType());
 }
 
-/// Given an `outputOperand` of a LinalgOp, compute the intersection of the
-/// forward slice starting from `outputOperand` and the backward slice
-/// starting from the corresponding linalg.yield operand.
-/// This intersection is assumed to have a single binary operation that is
-/// the reduction operation. Multiple reduction operations would impose an
+/// Check whether `outputOperand` is a reduction with a single combiner
+/// operation. Return the combiner operation of the reduction, which is assumed
+/// to be a binary operation. Multiple reduction operations would impose an
 /// ordering between reduction dimensions and is currently unsupported in
-/// Linalg. This limitation is motivated by the fact that e.g.
-/// min(max(X)) != max(min(X))
+/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
+/// max(min(X))
 // TODO: use in LinalgOp verification, there is a circular dependency atm.
 static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) {
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
-  auto yieldOp = cast<YieldOp>(linalgOp->getRegion(0).front().getTerminator());
-  unsigned yieldNum =
+  unsigned outputPos =
       outputOperand->getOperandNumber() - linalgOp.getNumInputs();
-  llvm::SetVector<Operation *> backwardSlice, forwardSlice;
-  BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument(
-      outputOperand->getOperandNumber());
-  Value yieldVal = yieldOp->getOperand(yieldNum);
-  getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) {
-    return op->getParentOp() == linalgOp;
-  });
-  backwardSlice.insert(yieldVal.getDefiningOp());
-  getForwardSlice(bbArg, &forwardSlice,
-                  [&](Operation *op) { return op->getParentOp() == linalgOp; });
-  // Search for the (assumed unique) elementwiseMappable op at the intersection
-  // of forward and backward slices.
-  Operation *reductionOp = nullptr;
-  for (Operation *op : llvm::reverse(backwardSlice)) {
-    if (!forwardSlice.contains(op))
-      continue;
-    if (OpTrait::hasElementwiseMappableTraits(op)) {
-      if (reductionOp) {
-        // Reduction detection fails: found more than 1 elementwise-mappable op.
-        return nullptr;
-      }
-      reductionOp = op;
-    }
-  }
+  SmallVector<Operation *, 4> combinerOps;
+  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
+      combinerOps.size() != 1)
+    return nullptr;
+
   // TODO: also assert no other subsequent ops break the reduction.
-  return reductionOp;
+  return combinerOps[0];
 }
 
 /// If `value` of assumed VectorType has a shape different than `shape`, try to
index f657d08..622fe11 100644 (file)
@@ -131,6 +131,14 @@ Diagnostic &Diagnostic::operator<<(Operation &val) {
   return *this << os.str();
 }
 
+/// Stream in a Value.
+Diagnostic &Diagnostic::operator<<(Value val) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  val.print(os);
+  return *this << os.str();
+}
+
 /// Outputs this diagnostic to a stream.
 void Diagnostic::print(raw_ostream &os) const {
   for (auto &arg : getArguments())
diff --git a/mlir/test/Analysis/test-match-reduction.mlir b/mlir/test/Analysis/test-match-reduction.mlir
new file mode 100644 (file)
index 0000000..0b80e09
--- /dev/null
@@ -0,0 +1,114 @@
+// RUN: mlir-opt %s -test-match-reduction -verify-diagnostics -split-input-file
+
+// Verify that the generic reduction detection utility works on different
+// dialects.
+
+// expected-remark@below {{Testing function}}
+func @linalg_red_add(%in0t : tensor<?xf32>, %out0t : tensor<1xf32>) {
+  // expected-remark@below {{Reduction found in output #0!}}
+  // expected-remark@below {{Reduced Value: <block argument> of type 'f32' at index: 0}}
+  // expected-remark@below {{Combiner Op: %1 = addf %arg2, %arg3 : f32}}
+  %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                          affine_map<(d0) -> (0)>],
+                                          iterator_types = ["reduction"]}
+   ins(%in0t : tensor<?xf32>)
+   outs(%out0t : tensor<1xf32>) {
+    ^bb0(%in0: f32, %out0: f32):
+      %add = addf %in0, %out0 : f32
+      linalg.yield %add : f32
+    } -> tensor<1xf32>
+  return
+}
+
+// -----
+
+// expected-remark@below {{Testing function}}
+func @affine_red_add(%in: memref<256x512xf32>, %out: memref<256xf32>) {
+ %cst = constant 0.000000e+00 : f32
+ affine.for %i = 0 to 256 {
+   // expected-remark@below {{Reduction found in output #0!}}
+   // expected-remark@below {{Reduced Value: %1 = affine.load %arg0[%arg2, %arg3] : memref<256x512xf32>}}
+   // expected-remark@below {{Combiner Op: %2 = addf %arg4, %1 : f32}}
+   %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
+     %ld = affine.load %in[%i, %j] : memref<256x512xf32>
+     %add = addf %red_iter, %ld : f32
+     affine.yield %add : f32
+   }
+   affine.store %final_red, %out[%i] : memref<256xf32>
+ }
+ return
+}
+
+// -----
+
+// TODO: Iteration-carried values with multiple uses are not supported yet.
+// expected-remark@below {{Testing function}}
+func @linalg_red_max(%in0t: tensor<4x4xf32>, %out0t: tensor<4xf32>) {
+  // expected-remark@below {{Reduction NOT found in output #0!}}
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+   iterator_types = ["parallel", "reduction"]}
+   ins(%in0t : tensor<4x4xf32>)
+   outs(%out0t : tensor<4xf32>) {
+    ^bb0(%in0: f32, %out0: f32):
+      %cmp = cmpf ogt, %in0, %out0 : f32
+      %sel = select %cmp, %in0, %out0 : f32
+      linalg.yield %sel : f32
+    } -> tensor<4xf32>
+  return
+}
+
+// -----
+
+// expected-remark@below {{Testing function}}
+func @linalg_fused_red_add(%in0t: tensor<4x4xf32>, %out0t: tensor<4xf32>) {
+  // expected-remark@below {{Reduction found in output #0!}}
+  // expected-remark@below {{Reduced Value: %2 = subf %1, %arg2 : f32}}
+  // expected-remark@below {{Combiner Op: %3 = addf %2, %arg3 : f32}}
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+   iterator_types = ["parallel", "reduction"]}
+   ins(%in0t : tensor<4x4xf32>)
+   outs(%out0t : tensor<4xf32>) {
+    ^bb0(%in0: f32, %out0: f32):
+      %mul = mulf %in0, %in0 : f32
+      %sub = subf %mul, %in0 : f32
+      %add = addf %sub, %out0 : f32
+      linalg.yield %add : f32
+    } -> tensor<4xf32>
+  return
+}
+
+// -----
+
+// expected-remark@below {{Testing function}}
+func @affine_no_red_rec(%in: memref<512xf32>) {
+ %cst = constant 0.000000e+00 : f32
+ // %rec is the value loaded in the previous iteration.
+ // expected-remark@below {{Reduction NOT found in output #0!}}
+ %final_val = affine.for %j = 0 to 512 iter_args(%rec = %cst) -> (f32) {
+   %ld = affine.load %in[%j] : memref<512xf32>
+   %add = addf %ld, %rec : f32
+   affine.yield %ld : f32
+ }
+ return
+}
+
+// -----
+
+// expected-remark@below {{Testing function}}
+func @affine_output_dep(%in: memref<512xf32>) {
+ %cst = constant 0.000000e+00 : f32
+ // Reduction %red is not supported because it depends on another
+ // loop-carried dependence.
+ // expected-remark@below {{Reduction NOT found in output #0!}}
+ // expected-remark@below {{Reduction NOT found in output #1!}}
+ %final_red, %final_dep = affine.for %j = 0 to 512
+  iter_args(%red = %cst, %dep = %cst) -> (f32, f32) {
+   %ld = affine.load %in[%j] : memref<512xf32>
+   %add = addf %dep, %red : f32
+   affine.yield %add, %ld : f32, f32
+ }
+ return
+}
+
index 013d4f3..aa9eadb 100644 (file)
@@ -3,6 +3,7 @@ add_mlir_library(MLIRTestAnalysis
   TestAliasAnalysis.cpp
   TestCallGraph.cpp
   TestLiveness.cpp
+  TestMatchReduction.cpp
   TestMemRefBoundCheck.cpp
   TestMemRefDependenceCheck.cpp
   TestMemRefStrideCalculation.cpp
diff --git a/mlir/test/lib/Analysis/TestMatchReduction.cpp b/mlir/test/lib/Analysis/TestMatchReduction.cpp
new file mode 100644 (file)
index 0000000..bbc15c6
--- /dev/null
@@ -0,0 +1,86 @@
+//===- TestMatchReduction.cpp - Test the match reduction utility ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a test pass for the match reduction utility.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+void printReductionResult(Operation *redRegionOp, unsigned numOutput,
+                          Value reducedValue,
+                          ArrayRef<Operation *> combinerOps) {
+  if (reducedValue) {
+    redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!";
+    redRegionOp->emitRemark("Reduced Value: ") << reducedValue;
+    for (Operation *combOp : combinerOps)
+      redRegionOp->emitRemark("Combiner Op: ") << *combOp;
+
+    return;
+  }
+
+  redRegionOp->emitRemark("Reduction NOT found in output #")
+      << numOutput << "!";
+}
+
+struct TestMatchReductionPass
+    : public PassWrapper<TestMatchReductionPass, FunctionPass> {
+  StringRef getArgument() const final { return "test-match-reduction"; }
+  StringRef getDescription() const final {
+    return "Test the match reduction utility.";
+  }
+
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+    func->emitRemark("Testing function");
+
+    func.walk<WalkOrder::PreOrder>([](Operation *op) {
+      if (isa<FuncOp>(op))
+        return;
+
+      // Limit testing to ops with only one region.
+      if (op->getNumRegions() != 1)
+        return;
+
+      Region &region = op->getRegion(0);
+      if (!region.hasOneBlock())
+        return;
+
+      // We expect all the tested region ops to have 1 input by default. The
+      // remaining arguments are assumed to be outputs/reductions and there must
+      // be at least one.
+      // TODO: Extend it to support more generic cases.
+      Block &regionEntry = region.front();
+      auto args = regionEntry.getArguments();
+      if (args.size() < 2)
+        return;
+
+      auto outputs = args.drop_front();
+      for (int i = 0, size = outputs.size(); i < size; ++i) {
+        SmallVector<Operation *, 4> combinerOps;
+        Value reducedValue = matchReduction(outputs, i, combinerOps);
+        printReductionResult(op, i, reducedValue, combinerOps);
+      }
+    });
+  }
+};
+
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestMatchReductionPass() {
+  PassRegistration<TestMatchReductionPass>();
+}
+} // namespace test
+} // namespace mlir
index b3d9e54..ec8d002 100644 (file)
@@ -94,6 +94,7 @@ void registerTestLivenessPass();
 void registerTestLoopFusion();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
+void registerTestMatchReductionPass();
 void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
@@ -183,6 +184,7 @@ void registerTestPasses() {
   mlir::test::registerTestLoopFusion();
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();
+  mlir::test::registerTestMatchReductionPass();
   mlir::test::registerTestMathAlgebraicSimplificationPass();
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
index f4dfb37..ec5264c 100644 (file)
@@ -4336,6 +4336,7 @@ cc_library(
     hdrs = ["include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"],
     includes = ["include"],
     deps = [
+        ":Analysis",
         ":ConversionPassIncGen",
         ":IR",
         ":LLVMDialect",