From: Diego Caballero Date: Wed, 22 Sep 2021 17:11:45 +0000 (+0000) Subject: [mlir] Create a generic reduction detection utility X-Git-Tag: upstream/15.0.7~30548 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2a876a711dc7c644936017daf20e78f48bfd2270;p=platform%2Fupstream%2Fllvm.git [mlir] Create a generic reduction detection utility This patch introduces a generic reduction detection utility that works across different dialects. It is mostly a generalization of the reduction detection algorithm in Affine. The reduction detection logic in Affine, Linalg and SCFToOpenMP has been replaced with this new generic utility. The utility takes some basic components of the potential reduction and returns: 1) the reduced value, and 2) a list with the combiner operations. The logic to match reductions involving multiple combiner operations is disabled until we can properly test it. Reviewed By: ftynse, bondhugula, nicolasvasilache, pifon2a Differential Revision: https://reviews.llvm.org/D110303 --- diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index b9c7d5e..145019f 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -22,6 +22,7 @@ namespace mlir { class AffineExpr; class AffineForOp; class AffineMap; +class BlockArgument; class MemRefType; class NestedPattern; class Operation; @@ -83,6 +84,37 @@ bool isVectorizableLoopBody(AffineForOp loop, int *memRefDim, // TODO: extend this to check for memory-based dependence violation when we have // the support. bool isOpwiseShiftValid(AffineForOp forOp, ArrayRef shifts); + +/// Utility to match a generic reduction given a list of iteration-carried +/// arguments, `iterCarriedArgs` and the position of the potential reduction +/// argument within the list, `redPos`. If a reduction is matched, returns the +/// reduced value and the topologically-sorted list of combiner operations +/// involved in the reduction. 
Otherwise, returns a null value. +/// +/// The matching algorithm relies on the following invariants, which are subject +/// to change: +/// 1. The first combiner operation must be a binary operation with the +/// iteration-carried value and the reduced value as operands. +/// 2. The iteration-carried value and combiner operations must be side +/// effect-free, have single result and a single use. +/// 3. Combiner operations must be immediately nested in the region op +/// performing the reduction. +/// 4. Reduction def-use chain must end in a terminator op that yields the +/// next iteration/output values in the same order as the iteration-carried +/// values in `iterCarriedArgs`. +/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values +/// of the region op performing the reduction. +/// +/// This utility is generic enough to detect reductions involving multiple +/// combiner operations (disabled for now) across multiple dialects, including +/// Linalg, Affine and SCF. For the sake of genericity, it does not return +/// specific enum values for the combiner operations since its goal is also +/// matching reductions without pre-defined semantics in core MLIR. It's up to +/// each client to make sense out of the list of combiner operations. It's also +/// up to each client to check for additional invariants on the expected +/// reductions not covered by this generic matching. 
+Value matchReduction(ArrayRef iterCarriedArgs, unsigned redPos, + SmallVectorImpl &combinerOps); } // end namespace mlir #endif // MLIR_ANALYSIS_LOOP_ANALYSIS_H diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 6aa2fb7..6977b26 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -597,6 +597,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { >, InterfaceMethod< /*desc=*/[{ + Return the output block arguments of the region. + }], + /*retTy=*/"Block::BlockArgListType", + /*methodName=*/"getRegionOutputArgs", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + Block &entryBlock = this->getOperation()->getRegion(0).front(); + return entryBlock.getArguments().take_back(this->getNumOutputs()); + }] + >, + InterfaceMethod< + /*desc=*/[{ Return the `opOperand` shape or an empty vector for scalars. }], /*retTy=*/"ArrayRef", diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h index 89f50ab..7bf35c7 100644 --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -30,6 +30,7 @@ class MLIRContext; class Operation; class OperationName; class Type; +class Value; namespace detail { struct DiagnosticEngineImpl; @@ -218,6 +219,9 @@ public: return *this << *val; } + /// Stream in a Value. + Diagnostic &operator<<(Value val); + /// Stream in a range. 
template > std::enable_if_t::value, diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 01209cf..1966305 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -22,7 +23,6 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -33,29 +33,6 @@ using namespace mlir; using llvm::dbgs; -/// Returns true if `value` (transitively) depends on iteration arguments of the -/// given `forOp`. -static bool dependsOnIterArgs(Value value, AffineForOp forOp) { - // Compute the backward slice of the value. - SetVector slice; - getBackwardSlice(value, &slice, - [&](Operation *op) { return !forOp->isAncestor(op); }); - - // Check that none of the operands of the operations in the backward slice are - // loop iteration arguments, and neither is the value itself. - auto argRange = forOp.getRegionIterArgs(); - llvm::SmallPtrSet iterArgs(argRange.begin(), argRange.end()); - if (iterArgs.contains(value)) - return true; - - for (Operation *op : slice) - for (Value operand : op->getOperands()) - if (iterArgs.contains(operand)) - return true; - - return false; -} - /// Get the value that is being reduced by `pos`-th reduction in the loop if /// such a reduction can be performed by affine parallel loops. This assumes /// floating-point operations are commutative. On success, `kind` will be the @@ -63,18 +40,19 @@ static bool dependsOnIterArgs(Value value, AffineForOp forOp) { /// reduction is not supported, returns null. 
static Value getSupportedReduction(AffineForOp forOp, unsigned pos, AtomicRMWKind &kind) { - auto yieldOp = cast(forOp.getBody()->back()); - Value yielded = yieldOp.operands()[pos]; - Operation *definition = yielded.getDefiningOp(); - if (!definition) + SmallVector combinerOps; + Value reducedVal = + matchReduction(forOp.getRegionIterArgs(), pos, combinerOps); + if (!reducedVal) return nullptr; - if (!forOp.getRegionIterArgs()[pos].hasOneUse()) - return nullptr; - if (!yielded.hasOneUse()) + + // Expected only one combiner operation. + if (combinerOps.size() > 1) return nullptr; + Operation *combinerOp = combinerOps.back(); Optional maybeKind = - TypeSwitch>(definition) + TypeSwitch>(combinerOp) .Case([](Operation *) { return AtomicRMWKind::addf; }) .Case([](Operation *) { return AtomicRMWKind::mulf; }) .Case([](Operation *) { return AtomicRMWKind::addi; }) @@ -88,14 +66,7 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos, return nullptr; kind = *maybeKind; - if (definition->getOperand(0) == forOp.getRegionIterArgs()[pos] && - !dependsOnIterArgs(definition->getOperand(1), forOp)) - return definition->getOperand(1); - if (definition->getOperand(1) == forOp.getRegionIterArgs()[pos] && - !dependsOnIterArgs(definition->getOperand(0), forOp)) - return definition->getOperand(0); - - return nullptr; + return reducedVal; } /// Returns true if `forOp' is a parallel loop. 
If `parallelReductions` is diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 932559c..f26875b 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -15,11 +15,13 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallString.h" #include @@ -392,3 +394,105 @@ bool mlir::isOpwiseShiftValid(AffineForOp forOp, ArrayRef shifts) { } return true; } + +/// Returns true if `value` (transitively) depends on iteration-carried values +/// of the given `ancestorOp`. +static bool dependsOnCarriedVals(Value value, + ArrayRef iterCarriedArgs, + Operation *ancestorOp) { + // Compute the backward slice of the value. + SetVector slice; + getBackwardSlice(value, &slice, + [&](Operation *op) { return !ancestorOp->isAncestor(op); }); + + // Check that none of the operands of the operations in the backward slice are + // loop iteration arguments, and neither is the value itself. + SmallPtrSet iterCarriedValSet(iterCarriedArgs.begin(), + iterCarriedArgs.end()); + if (iterCarriedValSet.contains(value)) + return true; + + for (Operation *op : slice) + for (Value operand : op->getOperands()) + if (iterCarriedValSet.contains(operand)) + return true; + + return false; +} + +/// Utility to match a generic reduction given a list of iteration-carried +/// arguments, `iterCarriedArgs` and the position of the potential reduction +/// argument within the list, `redPos`. If a reduction is matched, returns the +/// reduced value and the topologically-sorted list of combiner operations +/// involved in the reduction. Otherwise, returns a null value. 
+/// +/// The matching algorithm relies on the following invariants, which are subject +/// to change: +/// 1. The first combiner operation must be a binary operation with the +/// iteration-carried value and the reduced value as operands. +/// 2. The iteration-carried value and combiner operations must be side +/// effect-free, have single result and a single use. +/// 3. Combiner operations must be immediately nested in the region op +/// performing the reduction. +/// 4. Reduction def-use chain must end in a terminator op that yields the +/// next iteration/output values in the same order as the iteration-carried +/// values in `iterCarriedArgs`. +/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values +/// of the region op performing the reduction. +/// +/// This utility is generic enough to detect reductions involving multiple +/// combiner operations (disabled for now) across multiple dialects, including +/// Linalg, Affine and SCF. For the sake of genericity, it does not return +/// specific enum values for the combiner operations since its goal is also +/// matching reductions without pre-defined semantics in core MLIR. It's up to +/// each client to make sense out of the list of combiner operations. It's also +/// up to each client to check for additional invariants on the expected +/// reductions not covered by this generic matching. +Value mlir::matchReduction(ArrayRef iterCarriedArgs, + unsigned redPos, + SmallVectorImpl &combinerOps) { + assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds"); + + BlockArgument redCarriedVal = iterCarriedArgs[redPos]; + if (!redCarriedVal.hasOneUse()) + return nullptr; + + // For now, the first combiner op must be a binary op. + Operation *combinerOp = *redCarriedVal.getUsers().begin(); + if (combinerOp->getNumOperands() != 2) + return nullptr; + Value reducedVal = combinerOp->getOperand(0) == redCarriedVal + ? 
combinerOp->getOperand(1) + : combinerOp->getOperand(0); + + Operation *redRegionOp = + iterCarriedArgs.front().getOwner()->getParent()->getParentOp(); + if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp)) + return nullptr; + + // Traverse the def-use chain starting from the first combiner op until a + // terminator is found. Gather all the combiner ops along the way in + // topological order. + while (!combinerOp->mightHaveTrait()) { + if (!MemoryEffectOpInterface::hasNoEffect(combinerOp) || + combinerOp->getNumResults() != 1 || !combinerOp->hasOneUse() || + combinerOp->getParentOp() != redRegionOp) + return nullptr; + + combinerOps.push_back(combinerOp); + combinerOp = *combinerOp->getUsers().begin(); + } + + // Limit matching to single combiner op until we can properly test reductions + // involving multiple combiners. + if (combinerOps.size() != 1) + return nullptr; + + // Check that the yielded value is in the same position as in + // `iterCarriedArgs`. + Operation *terminatorOp = combinerOp; + if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0]) + return nullptr; + + return reducedVal; +} diff --git a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt index 1a75a35..ce0fd9a 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRSCFToOpenMP Core LINK_LIBS PUBLIC + MLIRAnalysis MLIRLLVMIR MLIROpenMP MLIRSCF diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 9c6fc6f..ddd92b1 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" #include "../PassDetail.h" +#include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include 
"mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/SCF.h" @@ -34,10 +35,21 @@ static bool matchSimpleReduction(Block &block) { if (block.empty() || llvm::hasSingleElement(block) || std::next(block.begin(), 2) != block.end()) return false; - return isa(block.front()) && + + if (block.getNumArguments() != 2) + return false; + + SmallVector combinerOps; + Value reducedVal = matchReduction({block.getArguments()[1]}, + /*redPos=*/0, combinerOps); + + if (!reducedVal || !reducedVal.isa() || + combinerOps.size() != 1) + return false; + + return isa(combinerOps[0]) && isa(block.back()) && - block.front().getOperands() == block.getArguments() && - block.back().getOperand(0) == block.front().getResult(0); + block.front().getOperands() == block.getArguments(); } /// Matches a block containing a select-based min/max reduction. The types of diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 913fa72..921f8ef 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" @@ -110,46 +111,24 @@ static VectorType extractVectorTypeFromShapedValue(Value v) { return VectorType::get(st.getShape(), st.getElementType()); } -/// Given an `outputOperand` of a LinalgOp, compute the intersection of the -/// forward slice starting from `outputOperand` and the backward slice -/// starting from the corresponding linalg.yield operand. -/// This intersection is assumed to have a single binary operation that is -/// the reduction operation. 
Multiple reduction operations would impose an +/// Check whether `outputOperand` is a reduction with a single combiner +/// operation. Return the combiner operation of the reduction, which is assumed +/// to be a binary operation. Multiple reduction operations would impose an /// ordering between reduction dimensions and is currently unsupported in -/// Linalg. This limitation is motivated by the fact that e.g. -/// min(max(X)) != max(min(X)) +/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) != +/// max(min(X)) // TODO: use in LinalgOp verification, there is a circular dependency atm. static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) { auto linalgOp = cast(outputOperand->getOwner()); - auto yieldOp = cast(linalgOp->getRegion(0).front().getTerminator()); - unsigned yieldNum = + unsigned outputPos = outputOperand->getOperandNumber() - linalgOp.getNumInputs(); - llvm::SetVector backwardSlice, forwardSlice; - BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument( - outputOperand->getOperandNumber()); - Value yieldVal = yieldOp->getOperand(yieldNum); - getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) { - return op->getParentOp() == linalgOp; - }); - backwardSlice.insert(yieldVal.getDefiningOp()); - getForwardSlice(bbArg, &forwardSlice, - [&](Operation *op) { return op->getParentOp() == linalgOp; }); - // Search for the (assumed unique) elementwiseMappable op at the intersection - // of forward and backward slices. - Operation *reductionOp = nullptr; - for (Operation *op : llvm::reverse(backwardSlice)) { - if (!forwardSlice.contains(op)) - continue; - if (OpTrait::hasElementwiseMappableTraits(op)) { - if (reductionOp) { - // Reduction detection fails: found more than 1 elementwise-mappable op. 
- return nullptr; - } - reductionOp = op; - } - } + SmallVector combinerOps; + if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) || + combinerOps.size() != 1) + return nullptr; + // TODO: also assert no other subsequent ops break the reduction. - return reductionOp; + return combinerOps[0]; } /// If `value` of assumed VectorType has a shape different than `shape`, try to diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index f657d08..622fe11 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -131,6 +131,14 @@ Diagnostic &Diagnostic::operator<<(Operation &val) { return *this << os.str(); } +/// Stream in a Value. +Diagnostic &Diagnostic::operator<<(Value val) { + std::string str; + llvm::raw_string_ostream os(str); + val.print(os); + return *this << os.str(); +} + /// Outputs this diagnostic to a stream. void Diagnostic::print(raw_ostream &os) const { for (auto &arg : getArguments()) diff --git a/mlir/test/Analysis/test-match-reduction.mlir b/mlir/test/Analysis/test-match-reduction.mlir new file mode 100644 index 0000000..0b80e09 --- /dev/null +++ b/mlir/test/Analysis/test-match-reduction.mlir @@ -0,0 +1,114 @@ +// RUN: mlir-opt %s -test-match-reduction -verify-diagnostics -split-input-file + +// Verify that the generic reduction detection utility works on different +// dialects. 
+ +// expected-remark@below {{Testing function}} +func @linalg_red_add(%in0t : tensor, %out0t : tensor<1xf32>) { + // expected-remark@below {{Reduction found in output #0!}} + // expected-remark@below {{Reduced Value: of type 'f32' at index: 0}} + // expected-remark@below {{Combiner Op: %1 = addf %arg2, %arg3 : f32}} + %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (0)>], + iterator_types = ["reduction"]} + ins(%in0t : tensor) + outs(%out0t : tensor<1xf32>) { + ^bb0(%in0: f32, %out0: f32): + %add = addf %in0, %out0 : f32 + linalg.yield %add : f32 + } -> tensor<1xf32> + return +} + +// ----- + +// expected-remark@below {{Testing function}} +func @affine_red_add(%in: memref<256x512xf32>, %out: memref<256xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + // expected-remark@below {{Reduction found in output #0!}} + // expected-remark@below {{Reduced Value: %1 = affine.load %arg0[%arg2, %arg3] : memref<256x512xf32>}} + // expected-remark@below {{Combiner Op: %2 = addf %arg4, %1 : f32}} + %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %red_iter, %ld : f32 + affine.yield %add : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// ----- + +// TODO: Iteration-carried values with multiple uses are not supported yet. 
+// expected-remark@below {{Testing function}} +func @linalg_red_max(%in0t: tensor<4x4xf32>, %out0t: tensor<4xf32>) { + // expected-remark@below {{Reduction NOT found in output #0!}} + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%in0t : tensor<4x4xf32>) + outs(%out0t : tensor<4xf32>) { + ^bb0(%in0: f32, %out0: f32): + %cmp = cmpf ogt, %in0, %out0 : f32 + %sel = select %cmp, %in0, %out0 : f32 + linalg.yield %sel : f32 + } -> tensor<4xf32> + return +} + +// ----- + +// expected-remark@below {{Testing function}} +func @linalg_fused_red_add(%in0t: tensor<4x4xf32>, %out0t: tensor<4xf32>) { + // expected-remark@below {{Reduction found in output #0!}} + // expected-remark@below {{Reduced Value: %2 = subf %1, %arg2 : f32}} + // expected-remark@below {{Combiner Op: %3 = addf %2, %arg3 : f32}} + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%in0t : tensor<4x4xf32>) + outs(%out0t : tensor<4xf32>) { + ^bb0(%in0: f32, %out0: f32): + %mul = mulf %in0, %in0 : f32 + %sub = subf %mul, %in0 : f32 + %add = addf %sub, %out0 : f32 + linalg.yield %add : f32 + } -> tensor<4xf32> + return +} + +// ----- + +// expected-remark@below {{Testing function}} +func @affine_no_red_rec(%in: memref<512xf32>) { + %cst = constant 0.000000e+00 : f32 + // %rec is the value loaded in the previous iteration. 
+ // expected-remark@below {{Reduction NOT found in output #0!}} + %final_val = affine.for %j = 0 to 512 iter_args(%rec = %cst) -> (f32) { + %ld = affine.load %in[%j] : memref<512xf32> + %add = addf %ld, %rec : f32 + affine.yield %ld : f32 + } + return +} + +// ----- + +// expected-remark@below {{Testing function}} +func @affine_output_dep(%in: memref<512xf32>) { + %cst = constant 0.000000e+00 : f32 + // Reduction %red is not supported because it depends on another + // loop-carried dependence. + // expected-remark@below {{Reduction NOT found in output #0!}} + // expected-remark@below {{Reduction NOT found in output #1!}} + %final_red, %final_dep = affine.for %j = 0 to 512 + iter_args(%red = %cst, %dep = %cst) -> (f32, f32) { + %ld = affine.load %in[%j] : memref<512xf32> + %add = addf %dep, %red : f32 + affine.yield %add, %ld : f32, f32 + } + return +} + diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt index 013d4f3..aa9eadb 100644 --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_library(MLIRTestAnalysis TestAliasAnalysis.cpp TestCallGraph.cpp TestLiveness.cpp + TestMatchReduction.cpp TestMemRefBoundCheck.cpp TestMemRefDependenceCheck.cpp TestMemRefStrideCalculation.cpp diff --git a/mlir/test/lib/Analysis/TestMatchReduction.cpp b/mlir/test/lib/Analysis/TestMatchReduction.cpp new file mode 100644 index 0000000..bbc15c6 --- /dev/null +++ b/mlir/test/lib/Analysis/TestMatchReduction.cpp @@ -0,0 +1,86 @@ +//===- TestMatchReduction.cpp - Test the match reduction utility ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a test pass for the match reduction utility. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +void printReductionResult(Operation *redRegionOp, unsigned numOutput, + Value reducedValue, + ArrayRef combinerOps) { + if (reducedValue) { + redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!"; + redRegionOp->emitRemark("Reduced Value: ") << reducedValue; + for (Operation *combOp : combinerOps) + redRegionOp->emitRemark("Combiner Op: ") << *combOp; + + return; + } + + redRegionOp->emitRemark("Reduction NOT found in output #") + << numOutput << "!"; +} + +struct TestMatchReductionPass + : public PassWrapper { + StringRef getArgument() const final { return "test-match-reduction"; } + StringRef getDescription() const final { + return "Test the match reduction utility."; + } + + void runOnFunction() override { + FuncOp func = getFunction(); + func->emitRemark("Testing function"); + + func.walk([](Operation *op) { + if (isa(op)) + return; + + // Limit testing to ops with only one region. + if (op->getNumRegions() != 1) + return; + + Region ®ion = op->getRegion(0); + if (!region.hasOneBlock()) + return; + + // We expect all the tested region ops to have 1 input by default. The + // remaining arguments are assumed to be outputs/reductions and there must + // be at least one. + // TODO: Extend it to support more generic cases. 
+ Block ®ionEntry = region.front(); + auto args = regionEntry.getArguments(); + if (args.size() < 2) + return; + + auto outputs = args.drop_front(); + for (int i = 0, size = outputs.size(); i < size; ++i) { + SmallVector combinerOps; + Value reducedValue = matchReduction(outputs, i, combinerOps); + printReductionResult(op, i, reducedValue, combinerOps); + } + }); + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestMatchReductionPass() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index b3d9e54..ec8d002 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -94,6 +94,7 @@ void registerTestLivenessPass(); void registerTestLoopFusion(); void registerTestLoopMappingPass(); void registerTestLoopUnrollingPass(); +void registerTestMatchReductionPass(); void registerTestMathAlgebraicSimplificationPass(); void registerTestMathPolynomialApproximationPass(); void registerTestMemRefDependenceCheck(); @@ -183,6 +184,7 @@ void registerTestPasses() { mlir::test::registerTestLoopFusion(); mlir::test::registerTestLoopMappingPass(); mlir::test::registerTestLoopUnrollingPass(); + mlir::test::registerTestMatchReductionPass(); mlir::test::registerTestMathAlgebraicSimplificationPass(); mlir::test::registerTestMathPolynomialApproximationPass(); mlir::test::registerTestMemRefDependenceCheck(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index f4dfb37..ec5264c 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4336,6 +4336,7 @@ cc_library( hdrs = ["include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"], includes = ["include"], deps = [ + ":Analysis", ":ConversionPassIncGen", ":IR", ":LLVMDialect",