From 80e287108771685d1eb20ad3f27f4459068604f0 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 9 Jul 2019 05:26:18 -0700 Subject: [PATCH] Extend AffineToGPU to support Linalg loops Extend the utility that converts affine loop nests to support other types of loops by abstracting away common behavior through templates. This also slightly simplifies the existing Affine to GPU conversion by always passing in the loop step as an additional kernel argument even though it is a known constant. If it is used, it will be propagated into the loop body by the existing canonicalization pattern and can be further constant-folded, otherwise it will be dropped by canonicalization. This prepares for the common loop abstraction that will be used for converting to GPU kernels, which is conceptually close to Linalg loops, while maintaining the existing conversion operational. PiperOrigin-RevId: 257172216 --- .../AffineToGPU.h => LoopsToGPU/LoopsToGPU.h} | 26 +- .../LoopsToGPUPass.h} | 16 +- mlir/include/mlir/Transforms/RegionUtils.h | 2 +- mlir/lib/Conversion/AffineToGPU/AffineToGPU.cpp | 220 -------------- mlir/lib/Conversion/AffineToGPU/CMakeLists.txt | 20 -- mlir/lib/Conversion/CMakeLists.txt | 2 +- mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt | 21 ++ mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp | 338 +++++++++++++++++++++ .../LoopsToGPUPass.cpp} | 25 +- mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir | 30 ++ .../{AffineToGPU => LoopsToGPU}/step_one.mlir | 30 +- .../{AffineToGPU => LoopsToGPU}/step_positive.mlir | 18 +- mlir/tools/mlir-opt/CMakeLists.txt | 2 +- 13 files changed, 462 insertions(+), 288 deletions(-) rename mlir/include/mlir/Conversion/{AffineToGPU/AffineToGPU.h => LoopsToGPU/LoopsToGPU.h} (60%) rename mlir/include/mlir/Conversion/{AffineToGPU/AffineToGPUPass.h => LoopsToGPU/LoopsToGPUPass.h} (68%) delete mode 100644 mlir/lib/Conversion/AffineToGPU/AffineToGPU.cpp delete mode 100644 mlir/lib/Conversion/AffineToGPU/CMakeLists.txt create mode 100644 mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt create mode 100644 mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp rename mlir/lib/Conversion/{AffineToGPU/AffineToGPUPass.cpp => LoopsToGPU/LoopsToGPUPass.cpp} (73%) create mode 100644 mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir rename mlir/test/Conversion/{AffineToGPU => LoopsToGPU}/step_one.mlir (76%) rename mlir/test/Conversion/{AffineToGPU => LoopsToGPU}/step_positive.mlir (54%) diff --git a/mlir/include/mlir/Conversion/AffineToGPU/AffineToGPU.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h similarity index 60% rename from mlir/include/mlir/Conversion/AffineToGPU/AffineToGPU.h rename to mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h index 36baad8..a1c94f1 100644 --- a/mlir/include/mlir/Conversion/AffineToGPU/AffineToGPU.h +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h @@ -1,4 +1,4 @@ -//===- AffineToGPU.h - Convert an affine loops to GPU kernels ---*- C++ -*-===// +//===- LoopsToGPU.h - Convert loop nests to GPU kernels ---------*- C++ -*-===// // // Copyright 2019 The MLIR Authors. // @@ -14,13 +14,17 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef MLIR_CONVERSION_AFFINETOGPU_AFFINETOGPU_H_ -#define MLIR_CONVERSION_AFFINETOGPU_AFFINETOGPU_H_ +#ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_ +#define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_ namespace mlir { class AffineForOp; struct LogicalResult; +namespace linalg { +class ForOp; +} + /// Convert a perfect affine loop nest with the outermost loop identified by /// `forOp` into a gpu::Launch operation. Map `numBlockDims` outer loops to /// GPU blocks and `numThreadDims` to GPU threads. The bounds of the loops that @@ -34,6 +38,20 @@ struct LogicalResult; LogicalResult convertAffineLoopNestToGPULaunch(AffineForOp forOp, unsigned numBlockDims, unsigned numThreadDims); + +/// Convert a perfect linalg loop nest with the outermost loop identified by +/// `forOp` into a gpu::Launch operation. Map `numBlockDims` outer loops to +/// GPU blocks and `numThreadDims` to GPU threads. The bounds of the loops that +/// are mapped should be independent of the induction variables of the other +/// mapped loops. +/// +/// No check on the size of the block or grid, or on the validity of +/// parallelization is performed, it is under the responsibility of the caller +/// to strip-mine the loops and to perform the dependence analysis before +/// calling the conversion. +LogicalResult convertLinalgLoopNestToGPULaunch(linalg::ForOp forOp, + unsigned numBlockDims, + unsigned numThreadDims); } // namespace mlir -#endif // MLIR_CONVERSION_AFFINETOGPU_AFFINETOGPU_H_ +#endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_ diff --git a/mlir/include/mlir/Conversion/AffineToGPU/AffineToGPUPass.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h similarity index 68% rename from mlir/include/mlir/Conversion/AffineToGPU/AffineToGPUPass.h rename to mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h index 1ed03da..52f0dd4 100644 --- a/mlir/include/mlir/Conversion/AffineToGPU/AffineToGPUPass.h +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h @@ -1,4 +1,4 @@ -//===- AffineToGPUPass.h - Pass converting loops to GPU kernels -*- C++ -*-===// +//===- LoopsToGPUPass.h - Pass converting loops to GPU kernels --*- C++ -*-===// // // Copyright 2019 The MLIR Authors. // @@ -14,22 +14,22 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef MLIR_CONVERSION_AFFINETOGPU_AFFINETOGPUPASS_H_ -#define MLIR_CONVERSION_AFFINETOGPU_AFFINETOGPUPASS_H_ +#ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_ +#define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_ namespace mlir { class FunctionPassBase; /// Create a pass that converts loop nests into GPU kernels. It considers -/// top-level affine.for operations as roots of loop nests and converts them -/// to the gpu.launch operations if possible. +/// top-level affine.for and linalg.for operations as roots of loop nests and +/// converts them to the gpu.launch operations if possible. /// /// No check on the size of the block or grid, or on the validity of /// parallelization is performed, it is under the responsibility of the caller /// to strip-mine the loops and to perform the dependence analysis before /// calling the conversion. 
-FunctionPassBase *createSimpleAffineToGPUPass(unsigned numBlockDims, - unsigned numThreadDims); +FunctionPassBase *createSimpleLoopsToGPUPass(unsigned numBlockDims, + unsigned numThreadDims); } // namespace mlir -#endif // MLIR_CONVERSION_AFFINETOGPU_AFFINETOGPUPASS_H_ +#endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_ diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h index 6b6380f..5ea79de 100644 --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -31,7 +31,7 @@ namespace mlir { template bool areValuesDefinedAbove(Range values, Region &limit) { for (Value *v : values) - if (v->getContainingRegion()->isProperAncestor(&limit)) + if (!v->getContainingRegion()->isProperAncestor(&limit)) return false; return true; } diff --git a/mlir/lib/Conversion/AffineToGPU/AffineToGPU.cpp b/mlir/lib/Conversion/AffineToGPU/AffineToGPU.cpp deleted file mode 100644 index aebce54..0000000 --- a/mlir/lib/Conversion/AffineToGPU/AffineToGPU.cpp +++ /dev/null @@ -1,220 +0,0 @@ -//===- AffineToGPU.cpp - Convert an affine loop nest to a GPU kernel ------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This implements a straightforward conversion of an affine loop nest into a -// GPU kernel. The caller is expected to guarantee that the conversion is -// correct or to further transform the kernel to ensure correctness. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Conversion/AffineToGPU/AffineToGPU.h" -#include "mlir/AffineOps/AffineOps.h" -#include "mlir/GPU/GPUDialect.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Builders.h" -#include "mlir/StandardOps/Ops.h" -#include "mlir/Transforms/LowerAffine.h" -#include "mlir/Transforms/RegionUtils.h" - -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "affine-to-gpu" - -using namespace mlir; - -// Extract an indexed value from KernelDim3. 
-static Value *getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { - switch (pos) { - case 0: - return dim3.x; - case 1: - return dim3.y; - case 2: - return dim3.z; - default: - llvm_unreachable("dim3 position out of bounds"); - } - return nullptr; -} - -LogicalResult mlir::convertAffineLoopNestToGPULaunch(AffineForOp forOp, - unsigned numBlockDims, - unsigned numThreadDims) { - if (numBlockDims < 1 || numThreadDims < 1) { - LLVM_DEBUG(llvm::dbgs() << "nothing to map"); - return success(); - } - - OpBuilder builder(forOp.getOperation()); - if (numBlockDims > 3) { - return emitError(builder.getUnknownLoc(), - "cannot map to more than 3 block dimensions"); - } - if (numThreadDims > 3) { - return emitError(builder.getUnknownLoc(), - "cannot map to more than 3 thread dimensions"); - } - - // Check the structure of the loop nest: - // - there is enough loops to map to numBlockDims + numThreadDims; - // - the loops are perfectly nested; - // - the loop bounds can be computed above the outermost loop. - // This roughly corresponds to the "matcher" part of the pattern-based - // rewriting infrastructure. - AffineForOp currentLoop = forOp; - Region &limit = forOp.getRegion(); - for (unsigned i = 0, e = numBlockDims + numThreadDims; i < e; ++i) { - Operation *nested = ¤tLoop.getBody()->front(); - if (currentLoop.getStep() <= 0) - return currentLoop.emitError("only positive loop steps are supported"); - if (!areValuesDefinedAbove(currentLoop.getLowerBoundOperands(), limit) || - !areValuesDefinedAbove(currentLoop.getUpperBoundOperands(), limit)) - return currentLoop.emitError( - "loops with bounds depending on other mapped loops " - "are not supported"); - - // The innermost loop can have an arbitrary body, skip the perfect nesting - // check for it. - if (i == e - 1) - break; - - auto begin = currentLoop.getBody()->begin(), - end = currentLoop.getBody()->end(); - if (currentLoop.getBody()->empty() || std::next(begin, 2) != end) - return currentLoop.emitError( - "expected perfectly nested loops in the body"); - - if (!(currentLoop = dyn_cast(nested))) - return nested->emitError("expected a nested loop"); - } - - // Compute the ranges of the loops and collect lower bounds and induction - // variables. - SmallVector dims; - SmallVector lbs; - SmallVector ivs; - SmallVector steps; - dims.reserve(numBlockDims + numThreadDims); - lbs.reserve(numBlockDims + numThreadDims); - ivs.reserve(numBlockDims + numThreadDims); - steps.reserve(numBlockDims + numThreadDims); - currentLoop = forOp; - for (unsigned i = 0, e = numBlockDims + numThreadDims; i < e; ++i) { - Value *lowerBound = lowerAffineLowerBound(currentLoop, builder); - Value *upperBound = lowerAffineUpperBound(currentLoop, builder); - if (!lowerBound || !upperBound) - return failure(); - - Value *range = - builder.create(currentLoop.getLoc(), upperBound, lowerBound); - int64_t step = currentLoop.getStep(); - if (step > 1) { - auto divExpr = - getAffineSymbolExpr(0, currentLoop.getContext()).floorDiv(step); - range = expandAffineExpr(builder, currentLoop.getLoc(), divExpr, - llvm::None, range); - } - dims.push_back(range); - - lbs.push_back(lowerBound); - ivs.push_back(currentLoop.getInductionVar()); - steps.push_back(step); - - if (i != e - 1) - currentLoop = cast(¤tLoop.getBody()->front()); - } - // At this point, currentLoop points to the innermost loop we are mapping. - - // Prepare the grid and block sizes for the launch operation. If there is - // no loop mapped to a specific dimension, use constant "1" as its size. 
- Value *constOne = (numBlockDims < 3 || numThreadDims < 3) - ? builder.create(forOp.getLoc(), 1) - : nullptr; - Value *gridSizeX = dims[0]; - Value *gridSizeY = numBlockDims > 1 ? dims[1] : constOne; - Value *gridSizeZ = numBlockDims > 2 ? dims[2] : constOne; - Value *blockSizeX = dims[numBlockDims]; - Value *blockSizeY = numThreadDims > 1 ? dims[numBlockDims + 1] : constOne; - Value *blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne; - - // Create a launch op and move the body region of the innermost loop to the - // launch op. Pass the values defined outside the outermost loop and used - // inside the innermost loop and loop lower bounds as kernel data arguments. - // Still assuming perfect nesting so there are no values other than induction - // variables that are defined in one loop and used in deeper loops. - llvm::SetVector valuesToForwardSet; - getUsedValuesDefinedAbove(forOp.getRegion(), forOp.getRegion(), - valuesToForwardSet); - auto valuesToForward = valuesToForwardSet.takeVector(); - auto originallyForwardedValues = valuesToForward.size(); - valuesToForward.insert(valuesToForward.end(), lbs.begin(), lbs.end()); - auto launchOp = builder.create( - forOp.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, - blockSizeZ, valuesToForward); - valuesToForward.resize(originallyForwardedValues); - - // Replace the affine terminator (loops contain only a single block) with the - // gpu return and move the operations from the loop body block to the gpu - // launch body block. Do not move the entire block because of the difference - // in block arguments. - Operation &terminator = currentLoop.getBody()->back(); - Location terminatorLoc = terminator.getLoc(); - terminator.erase(); - builder.setInsertionPointToEnd(currentLoop.getBody()); - builder.create(terminatorLoc); - launchOp.getBody().front().getOperations().splice( - launchOp.getBody().front().begin(), - currentLoop.getBody()->getOperations()); - - // Remap the loop iterators to use block/thread identifiers instead. Loops - // may iterate from LB with step S whereas GPU thread/block ids always iterate - // from 0 to N with step 1. Therefore, loop induction variables are replaced - // with (gpu-thread/block-id * S) + LB. - builder.setInsertionPointToStart(&launchOp.getBody().front()); - auto lbArgumentIt = std::next(launchOp.getKernelArguments().begin(), - originallyForwardedValues); - for (auto en : llvm::enumerate(ivs)) { - Value *id = - en.index() < numBlockDims - ? getDim3Value(launchOp.getBlockIds(), en.index()) - : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims); - if (steps[en.index()] > 1) { - Value *factor = - builder.create(forOp.getLoc(), steps[en.index()]); - id = builder.create(forOp.getLoc(), factor, id); - } - Value *ivReplacement = - builder.create(forOp.getLoc(), *lbArgumentIt, id); - en.value()->replaceAllUsesWith(ivReplacement); - std::advance(lbArgumentIt, 1); - } - - // Remap the values defined outside the body to use kernel arguments instead. - // The list of kernel arguments also contains the lower bounds for loops at - // trailing positions, make sure we don't touch those. - for (const auto &pair : - llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) { - Value *from = std::get<0>(pair); - Value *to = std::get<1>(pair); - replaceAllUsesInRegionWith(from, to, launchOp.getBody()); - } - - // We are done and can erase the original outermost loop. 
- forOp.erase(); - - return success(); -} diff --git a/mlir/lib/Conversion/AffineToGPU/CMakeLists.txt b/mlir/lib/Conversion/AffineToGPU/CMakeLists.txt deleted file mode 100644 index 1782677..0000000 --- a/mlir/lib/Conversion/AffineToGPU/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -set(LIBS - MLIRAffineOps - MLIRGPU - MLIRIR - MLIRPass - MLIRStandardOps - MLIRSupport - MLIRTransforms - LLVMSupport -) - -add_llvm_library(MLIRAffineToGPU - AffineToGPU.cpp - AffineToGPUPass.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AffineToGPU -) -add_dependencies(MLIRAffineToGPU ${LIBS}) -target_link_libraries(MLIRAffineToGPU ${LIBS}) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 7ab64e2..35c0a32 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -1,4 +1,4 @@ -add_subdirectory(AffineToGPU) +add_subdirectory(LoopsToGPU) add_subdirectory(GPUToCUDA) add_subdirectory(GPUToNVVM) add_subdirectory(StandardToLLVM) diff --git a/mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt b/mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt new file mode 100644 index 0000000..2dacc80 --- /dev/null +++ b/mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt @@ -0,0 +1,21 @@ +set(LIBS + MLIRAffineOps + MLIRGPU + MLIRIR + MLIRLinalg + MLIRPass + MLIRStandardOps + MLIRSupport + MLIRTransforms + LLVMSupport +) + +add_llvm_library(MLIRLoopsToGPU + LoopsToGPU.cpp + LoopsToGPUPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LoopsToGPU +) +add_dependencies(MLIRLoopsToGPU ${LIBS}) +target_link_libraries(MLIRLoopsToGPU ${LIBS}) diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp new file mode 100644 index 0000000..96ac947 --- /dev/null +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp @@ -0,0 +1,338 @@ +//===- LoopsToGPU.cpp - Convert an affine loop nest to a GPU kernel -------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This implements a straightforward conversion of an loop nest into a GPU +// kernel. The caller is expected to guarantee that the conversion is correct +// or to further transform the kernel to ensure correctness. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h" +#include "mlir/AffineOps/AffineOps.h" +#include "mlir/GPU/GPUDialect.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/Linalg/IR/LinalgOps.h" +#include "mlir/StandardOps/Ops.h" +#include "mlir/Transforms/LowerAffine.h" +#include "mlir/Transforms/RegionUtils.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "loops-to-gpu" + +using namespace mlir; + +// Extract an indexed value from KernelDim3. 
+static Value *getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { + switch (pos) { + case 0: + return dim3.x; + case 1: + return dim3.y; + case 2: + return dim3.z; + default: + llvm_unreachable("dim3 position out of bounds"); + } + return nullptr; +} + +// Get the lower bound-related operands of a loop operation. +static Operation::operand_range getLowerBoundOperands(AffineForOp forOp) { + return forOp.getLowerBoundOperands(); +} +static SmallVector getLowerBoundOperands(linalg::ForOp forOp) { + SmallVector bounds(1, forOp.getLowerBound()); + return bounds; +} + +// Get the upper bound-related operands of a loop operation. +static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) { + return forOp.getUpperBoundOperands(); +} +static SmallVector getUpperBoundOperands(linalg::ForOp forOp) { + SmallVector bounds(1, forOp.getUpperBound()); + return bounds; +} + +// Get a Value that corresponds to the loop step. If the step is an attribute, +// materialize a corresponding constant using builder. +static Value *getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { + return builder.create(forOp.getLoc(), forOp.getStep()); +} +static Value *getOrCreateStep(linalg::ForOp forOp, OpBuilder &) { + return forOp.getStep(); +} + +// Get a Value for the loop lower bound. If the value requires computation, +// materialize the instructions using builder. +static Value *getOrEmitLowerBound(AffineForOp forOp, OpBuilder &builder) { + return lowerAffineLowerBound(forOp, builder); +} +static Value *getOrEmitLowerBound(linalg::ForOp forOp, OpBuilder &) { + return forOp.getLowerBound(); +} + +// Get a Value for the loop upper bound. If the value requires computation, +// materialize the instructions using builder. +static Value *getOrEmitUpperBound(AffineForOp forOp, OpBuilder &builder) { + return lowerAffineUpperBound(forOp, builder); +} +static Value *getOrEmitUpperBound(linalg::ForOp forOp, OpBuilder &) { + return forOp.getUpperBound(); +} + +// Check the structure of the loop nest: +// - there are enough loops to map to numBlockDims + numThreadDims; +// - the loops are perfectly nested; +// - the loop bounds can be computed above the outermost loop. +// This roughly corresponds to the "matcher" part of the pattern-based +// rewriting infrastructure. +template +LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims, + unsigned numThreadDims) { + if (numBlockDims < 1 || numThreadDims < 1) { + LLVM_DEBUG(llvm::dbgs() << "nothing to map"); + return success(); + } + + OpBuilder builder(forOp.getOperation()); + if (numBlockDims > 3) { + return emitError(builder.getUnknownLoc(), + "cannot map to more than 3 block dimensions"); + } + if (numThreadDims > 3) { + return emitError(builder.getUnknownLoc(), + "cannot map to more than 3 thread dimensions"); + } + + OpTy currentLoop = forOp; + Region &limit = forOp.getRegion(); + for (unsigned i = 0, e = numBlockDims + numThreadDims; i < e; ++i) { + Operation *nested = ¤tLoop.getBody()->front(); + if (!areValuesDefinedAbove(getLowerBoundOperands(currentLoop), limit) || + !areValuesDefinedAbove(getUpperBoundOperands(currentLoop), limit)) + return currentLoop.emitError( + "loops with bounds depending on other mapped loops " + "are not supported"); + + // The innermost loop can have an arbitrary body, skip the perfect nesting + // check for it. 
+ if (i == e - 1) + break; + + auto begin = currentLoop.getBody()->begin(), + end = currentLoop.getBody()->end(); + if (currentLoop.getBody()->empty() || std::next(begin, 2) != end) + return currentLoop.emitError( + "expected perfectly nested loops in the body"); + + if (!(currentLoop = dyn_cast(nested))) + return nested->emitError("expected a nested loop"); + } + + return success(); +} + +namespace { +// Helper structure that holds common state of the loop to GPU kernel +// conversion. +struct LoopToGpuConverter { + template + Optional collectBounds(OpTy forOp, unsigned numLoops); + + template + void createLaunch(OpTy rootForOp, OpTy innermostForOp, unsigned numBlockDims, + unsigned numThreadDims); + + // Ranges of the loops mapped to blocks or threads. + SmallVector dims; + // Lower bounds of the loops mapped to blocks or threads. + SmallVector lbs; + // Induction variables of the loops mapped to blocks or threads. + SmallVector ivs; + // Steps of the loops mapped to blocks or threads. + SmallVector steps; +}; +} // namespace + +// Return true if the value is obviously a constant "one". +static bool isConstantOne(Value *value) { + if (auto def = dyn_cast_or_null(value->getDefiningOp())) + return def.getValue() == 1; + return false; +} + +// Collect ranges, bounds, steps and induction variables in preparation for +// mapping a loop nest of depth "numLoops" rooted at "forOp" to a GPU kernel. +// This may fail if the IR for computing loop bounds cannot be constructed, for +// example if an affine loop uses semi-affine maps. Return the last loop to be +// mapped on success, llvm::None on failure. +template +Optional LoopToGpuConverter::collectBounds(OpTy forOp, + unsigned numLoops) { + OpBuilder builder(forOp.getOperation()); + dims.reserve(numLoops); + lbs.reserve(numLoops); + ivs.reserve(numLoops); + steps.reserve(numLoops); + OpTy currentLoop = forOp; + for (unsigned i = 0; i < numLoops; ++i) { + Value *lowerBound = getOrEmitLowerBound(currentLoop, builder); + Value *upperBound = getOrEmitUpperBound(currentLoop, builder); + if (!lowerBound || !upperBound) { + return llvm::None; + } + + Value *range = + builder.create(currentLoop.getLoc(), upperBound, lowerBound); + Value *step = getOrCreateStep(currentLoop, builder); + if (!isConstantOne(step)) + range = builder.create(currentLoop.getLoc(), range, step); + dims.push_back(range); + + lbs.push_back(lowerBound); + ivs.push_back(currentLoop.getInductionVar()); + steps.push_back(step); + + if (i != numLoops - 1) + currentLoop = cast(¤tLoop.getBody()->front()); + } + return currentLoop; +} + +// Replace the rooted at "rootForOp" with a GPU launch operation. This expects +// "innermostForOp" to point to the last loop to be transformed to the kernel, +// and to have (numBlockDims + numThreadDims) perfectly nested loops between +// "rootForOp" and "innermostForOp". +template +void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp, + unsigned numBlockDims, + unsigned numThreadDims) { + OpBuilder builder(rootForOp.getOperation()); + // Prepare the grid and block sizes for the launch operation. If there is + // no loop mapped to a specific dimension, use constant "1" as its size. + Value *constOne = (numBlockDims < 3 || numThreadDims < 3) + ? builder.create(rootForOp.getLoc(), 1) + : nullptr; + Value *gridSizeX = dims[0]; + Value *gridSizeY = numBlockDims > 1 ? dims[1] : constOne; + Value *gridSizeZ = numBlockDims > 2 ? dims[2] : constOne; + Value *blockSizeX = dims[numBlockDims]; + Value *blockSizeY = numThreadDims > 1 ? 
dims[numBlockDims + 1] : constOne; + Value *blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne; + + // Create a launch op and move the body region of the innermost loop to the + // launch op. Pass the values defined outside the outermost loop and used + // inside the innermost loop and loop lower bounds as kernel data arguments. + // Still assuming perfect nesting so there are no values other than induction + // variables that are defined in one loop and used in deeper loops. + llvm::SetVector valuesToForwardSet; + getUsedValuesDefinedAbove(innermostForOp.getRegion(), rootForOp.getRegion(), + valuesToForwardSet); + auto valuesToForward = valuesToForwardSet.takeVector(); + auto originallyForwardedValues = valuesToForward.size(); + valuesToForward.insert(valuesToForward.end(), lbs.begin(), lbs.end()); + valuesToForward.insert(valuesToForward.end(), steps.begin(), steps.end()); + auto launchOp = builder.create( + rootForOp.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX, + blockSizeY, blockSizeZ, valuesToForward); + valuesToForward.resize(originallyForwardedValues); + + // Replace the loop terminator (loops contain only a single block) with the + // gpu return and move the operations from the loop body block to the gpu + // launch body block. Do not move the entire block because of the difference + // in block arguments. + Operation &terminator = innermostForOp.getBody()->back(); + Location terminatorLoc = terminator.getLoc(); + terminator.erase(); + builder.setInsertionPointToEnd(innermostForOp.getBody()); + builder.create(terminatorLoc); + launchOp.getBody().front().getOperations().splice( + launchOp.getBody().front().begin(), + innermostForOp.getBody()->getOperations()); + + // Remap the loop iterators to use block/thread identifiers instead. Loops + // may iterate from LB with step S whereas GPU thread/block ids always iterate + // from 0 to N with step 1. Therefore, loop induction variables are replaced + // with (gpu-thread/block-id * S) + LB. + builder.setInsertionPointToStart(&launchOp.getBody().front()); + auto lbArgumentIt = std::next(launchOp.getKernelArguments().begin(), + originallyForwardedValues); + auto stepArgumentIt = std::next(lbArgumentIt, lbs.size()); + for (auto en : llvm::enumerate(ivs)) { + Value *id = + en.index() < numBlockDims + ? getDim3Value(launchOp.getBlockIds(), en.index()) + : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims); + Value *step = steps[en.index()]; + if (!isConstantOne(step)) + id = builder.create(rootForOp.getLoc(), step, id); + + Value *ivReplacement = + builder.create(rootForOp.getLoc(), *lbArgumentIt, id); + en.value()->replaceAllUsesWith(ivReplacement); + replaceAllUsesInRegionWith(steps[en.index()], *stepArgumentIt, + launchOp.getBody()); + std::advance(lbArgumentIt, 1); + std::advance(stepArgumentIt, 1); + } + + // Remap the values defined outside the body to use kernel arguments instead. + // The list of kernel arguments also contains the lower bounds for loops at + // trailing positions, make sure we don't touch those. + for (const auto &pair : + llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) { + Value *from = std::get<0>(pair); + Value *to = std::get<1>(pair); + replaceAllUsesInRegionWith(from, to, launchOp.getBody()); + } + + // We are done and can erase the original outermost loop. + rootForOp.erase(); +} + +// Generic loop to GPU kernel conversion function. 
+template +static LogicalResult convertLoopNestToGPULaunch(OpTy forOp, + unsigned numBlockDims, + unsigned numThreadDims) { + if (failed(checkLoopNestMappable(forOp, numBlockDims, numThreadDims))) + return failure(); + + LoopToGpuConverter converter; + auto maybeInnerLoop = + converter.collectBounds(forOp, numBlockDims + numThreadDims); + if (!maybeInnerLoop) + return failure(); + converter.createLaunch(forOp, *maybeInnerLoop, numBlockDims, numThreadDims); + + return success(); +} + +LogicalResult mlir::convertAffineLoopNestToGPULaunch(AffineForOp forOp, + unsigned numBlockDims, + unsigned numThreadDims) { + return convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims); +} + +LogicalResult mlir::convertLinalgLoopNestToGPULaunch(linalg::ForOp forOp, + unsigned numBlockDims, + unsigned numThreadDims) { + return convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims); +} diff --git a/mlir/lib/Conversion/AffineToGPU/AffineToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp similarity index 73% rename from mlir/lib/Conversion/AffineToGPU/AffineToGPUPass.cpp rename to mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp index 601340b..a8ef3d3 100644 --- a/mlir/lib/Conversion/AffineToGPU/AffineToGPUPass.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -1,4 +1,4 @@ -//===- AffineToGPUPass.cpp - Convert an affine loop nest to a GPU kernel --===// +//===- LoopsToGPUPass.cpp - Convert a loop nest to a GPU kernel -----------===// // // Copyright 2019 The MLIR Authors. // @@ -15,14 +15,15 @@ // limitations under the License. // ============================================================================= -#include "mlir/Conversion/AffineToGPU/AffineToGPUPass.h" +#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" #include "mlir/AffineOps/AffineOps.h" -#include "mlir/Conversion/AffineToGPU/AffineToGPU.h" +#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h" +#include "mlir/Linalg/IR/LinalgOps.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/CommandLine.h" -#define PASS_NAME "convert-affine-to-gpu" +#define PASS_NAME "convert-loops-to-gpu" using namespace mlir; @@ -46,11 +47,17 @@ struct AffineForGPUMapper : public FunctionPass { void runOnFunction() override { for (Block &block : getFunction()) - for (Operation &op : llvm::make_early_inc_range(block)) - if (auto forOp = dyn_cast(&op)) + for (Operation &op : llvm::make_early_inc_range(block)) { + if (auto forOp = dyn_cast(&op)) { if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims))) signalPassFailure(); + } else if (auto forOp = dyn_cast(&op)) { + if (failed(convertLinalgLoopNestToGPULaunch(forOp, numBlockDims, + numThreadDims))) + signalPassFailure(); + } + } } unsigned numBlockDims; @@ -64,10 +71,10 @@ struct AffineForGPUMapperCLI : public AffineForGPUMapper { }; } // namespace -FunctionPassBase *mlir::createSimpleAffineToGPUPass(unsigned numBlockDims, - unsigned numThreadDims) { +FunctionPassBase *mlir::createSimpleLoopsToGPUPass(unsigned numBlockDims, + unsigned numThreadDims) { return new AffineForGPUMapper(numBlockDims, numThreadDims); } static PassRegistration - registration(PASS_NAME, "Convert top-level affine loops to GPU kernels"); + registration(PASS_NAME, "Convert top-level loops to GPU kernels"); diff --git a/mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir b/mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir new file mode 100644 index 0000000..14012fd --- /dev/null +++ b/mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-opt 
-convert-loops-to-gpu %s | FileCheck %s + +// CHECK-LABEL: @foo +func @foo(%arg0: !linalg.buffer, %arg1 : index) { + %c0 = constant 0 : index + %c42 = constant 42 : index + %c3 = constant 3 : index + // CHECK: subi %c42, %c0 : index + // CHECK-NEXT: %[[range_i:.*]] = divis {{.*}}, %c3 : index + linalg.for %i0 = %c0 to %c42 step %c3 { + // CHECK: subi %c42, %c3 : index + // CHECK-NEXT: %[[range_j:.*]] = divis {{.*}}, %arg1 : index + linalg.for %i1 = %c3 to %c42 step %arg1 { + // CHECK: gpu.launch + // CHECK-SAME: blocks(%i0, %i1, %i2) in (%i6 = %[[range_i]], %i7 = %c1, %i8 = %c1) + // CHECK-SAME: threads(%i3, %i4, %i5) in (%i9 = %[[range_j]], %i10 = %c1, %i11 = %c1) + // CHECK-SAME: args(%i12 = %c0, %i13 = %c3, %i14 = %c3, %i15 = %arg1) + + // Replacements of loop induction variables. Take a product with the + // step and add the lower bound. + // CHECK: %[[prod_i:.*]] = muli %i14, %i0 : index + // CHECK: addi %i12, %[[prod_i]] : index + // CHECK: %[[prod_j:.*]] = muli %i15, %i3 : index + // CHECK: addi %i13, %[[prod_j]] : index + + // CHECK: gpu.return + } + } + return +} diff --git a/mlir/test/Conversion/AffineToGPU/step_one.mlir b/mlir/test/Conversion/LoopsToGPU/step_one.mlir similarity index 76% rename from mlir/test/Conversion/AffineToGPU/step_one.mlir rename to mlir/test/Conversion/LoopsToGPU/step_one.mlir index 7112298..1b5e236 100644 --- a/mlir/test/Conversion/AffineToGPU/step_one.mlir +++ b/mlir/test/Conversion/LoopsToGPU/step_one.mlir @@ -1,32 +1,36 @@ -// RUN: mlir-opt -convert-affine-to-gpu -gpu-block-dims=1 -gpu-thread-dims=1 %s | FileCheck --check-prefix=CHECK-11 %s -// RUN: mlir-opt -convert-affine-to-gpu -gpu-block-dims=2 -gpu-thread-dims=2 %s | FileCheck --check-prefix=CHECK-22 %s +// RUN: mlir-opt -convert-loops-to-gpu -gpu-block-dims=1 -gpu-thread-dims=1 %s | FileCheck --check-prefix=CHECK-11 %s +// RUN: mlir-opt -convert-loops-to-gpu -gpu-block-dims=2 -gpu-thread-dims=2 %s | FileCheck --check-prefix=CHECK-22 %s // CHECK-11-LABEL: @step_1 // CHECK-22-LABEL: @step_1 func @step_1(%A : memref, %B : memref) { - // Bounds of the loop and its range. + // Bounds of the loop, its range and step. // CHECK-11-NEXT: %c0 = constant 0 : index // CHECK-11-NEXT: %c42 = constant 42 : index // CHECK-11-NEXT: %0 = subi %c42, %c0 : index + // CHECK-11-NEXT: %c1 = constant 1 : index // // CHECK-22-NEXT: %c0 = constant 0 : index // CHECK-22-NEXT: %c42 = constant 42 : index // CHECK-22-NEXT: %0 = subi %c42, %c0 : index + // CHECK-22-NEXT: %c1 = constant 1 : index affine.for %i = 0 to 42 { - // Bounds of the loop and its range. + // Bounds of the loop, its range and step. 
// CHECK-11-NEXT: %c0_0 = constant 0 : index // CHECK-11-NEXT: %c10 = constant 10 : index // CHECK-11-NEXT: %1 = subi %c10, %c0_0 : index + // CHECK-11-NEXT: %c1_1 = constant 1 : index // // CHECK-22-NEXT: %c0_0 = constant 0 : index // CHECK-22-NEXT: %c10 = constant 10 : index // CHECK-22-NEXT: %1 = subi %c10, %c0_0 : index + // CHECK-22-NEXT: %c1_1 = constant 1 : index affine.for %j = 0 to 10 { // CHECK-11: gpu.launch - // CHECK-11-SAME: blocks(%i0, %i1, %i2) in (%i6 = %0, %i7 = %c1, %i8 = %c1) - // CHECK-11-SAME: threads(%i3, %i4, %i5) in (%i9 = %1, %i10 = %c1, %i11 = %c1) - // CHECK-11-SAME: args(%i12 = %arg0, %i13 = %arg1, %i14 = %c0, %i15 = %c0_0) + // CHECK-11-SAME: blocks(%i0, %i1, %i2) in (%i6 = %0, %i7 = %c1_2, %i8 = %c1_2) + // CHECK-11-SAME: threads(%i3, %i4, %i5) in (%i9 = %1, %i10 = %c1_2, %i11 = %c1_2) + // CHECK-11-SAME: args(%i12 = %arg0, %i13 = %arg1, %i14 = %c0, %i15 = %c0_0, %i16 = %c1, %i17 = %c1_1) // Remapping of the loop induction variables. // CHECK-11: %[[i:.*]] = addi %i14, %i0 : index @@ -35,23 +39,25 @@ func @step_1(%A : memref, %B : memref) { // This loop is not converted if mapping to 1, 1 dimensions. // CHECK-11-NEXT: affine.for %[[ii:.*]] = 2 to 16 // - // Bounds of the loop and its range. + // Bounds of the loop, its range and step. // CHECK-22-NEXT: %c2 = constant 2 : index // CHECK-22-NEXT: %c16 = constant 16 : index // CHECK-22-NEXT: %2 = subi %c16, %c2 : index + // CHECK-22-NEXT: %c1_2 = constant 1 : index affine.for %ii = 2 to 16 { // This loop is not converted if mapping to 1, 1 dimensions. // CHECK-11-NEXT: affine.for %[[jj:.*]] = 5 to 17 // - // Bounds of the loop and its range. + // Bounds of the loop, its range and step. // CHECK-22-NEXT: %c5 = constant 5 : index // CHECK-22-NEXT: %c17 = constant 17 : index // CHECK-22-NEXT: %3 = subi %c17, %c5 : index + // CHECK-22-NEXT: %c1_3 = constant 1 : index affine.for %jj = 5 to 17 { // CHECK-22: gpu.launch - // CHECK-22-SAME: blocks(%i0, %i1, %i2) in (%i6 = %0, %i7 = %1, %i8 = %c1) - // CHECK-22-SAME: threads(%i3, %i4, %i5) in (%i9 = %2, %i10 = %3, %i11 = %c1) - // CHECK-22-SAME: args(%i12 = %arg0, %i13 = %arg1, %i14 = %c0, %i15 = %c0_0, %i16 = %c2, %i17 = %c5) + // CHECK-22-SAME: blocks(%i0, %i1, %i2) in (%i6 = %0, %i7 = %1, %i8 = %c1_4) + // CHECK-22-SAME: threads(%i3, %i4, %i5) in (%i9 = %2, %i10 = %3, %i11 = %c1_4) + // CHECK-22-SAME: args(%i12 = %arg0, %i13 = %arg1, %i14 = %c0, %i15 = %c0_0, %i16 = %c2, %i17 = %c5, %i18 = %c1, %i19 = %c1_1, %i20 = %c1_2, %i21 = %c1_3) // Remapping of the loop induction variables in the last mapped loop. // CHECK-22: %[[i:.*]] = addi %i14, %i0 : index diff --git a/mlir/test/Conversion/AffineToGPU/step_positive.mlir b/mlir/test/Conversion/LoopsToGPU/step_positive.mlir similarity index 54% rename from mlir/test/Conversion/AffineToGPU/step_positive.mlir rename to mlir/test/Conversion/LoopsToGPU/step_positive.mlir index b8da5de..b9af8a3 100644 --- a/mlir/test/Conversion/AffineToGPU/step_positive.mlir +++ b/mlir/test/Conversion/LoopsToGPU/step_positive.mlir @@ -1,14 +1,10 @@ -// RUN: mlir-opt -convert-affine-to-gpu -gpu-block-dims=1 -gpu-thread-dims=1 %s | FileCheck %s +// RUN: mlir-opt -convert-loops-to-gpu -gpu-block-dims=1 -gpu-thread-dims=1 %s | FileCheck %s // CHECK-LABEL: @step_var func @step_var(%A : memref, %B : memref) { - // The loop range computation is performed by lowering the affine expression - // floordiv(upper - lower, step). The lowering of affine expressions has its - // own test, here we only check the fact of division by step. 
- // CHECK: divis {{.*}}, %c4 - // CHECK: %[[range_i:.*]] = select - // CHECK: divis {{.*}}, %c7 - // CHECK: %[[range_j:.*]] = select + // Check that we divide by step. + // CHECK: %[[range_i:.*]] = divis {{.*}}, %c4 + // CHECK: %[[range_j:.*]] = divis {{.*}}, %c7 // CHECK: gpu.launch // CHECK-SAME: blocks(%i0, %i1, %i2) in (%i6 = %[[range_i]], %i7 = %c1, %i8 = %c1) @@ -17,11 +13,9 @@ func @step_var(%A : memref<?x?xf32>, %B : memref<?x?xf32>) { affine.for %j = 3 to 19 step 7 { // Loop induction variable remapping: // iv = thread(block)_id * step + lower_bound - // CHECK: %[[c4:.*]] = constant 4 : index - // CHECK-NEXT: %[[prod_i:.*]] = muli %[[c4]], %i0 : index + // CHECK: %[[prod_i:.*]] = muli %i16, %i0 : index // CHECK-NEXT: %[[i:.*]] = addi %i14, %[[prod_i]] : index - // CHECK-NEXT: %[[c7:.*]] = constant 7 : index - // CHECK-NEXT: %[[prod_j:.*]] = muli %[[c7]], %i3 : index + // CHECK-NEXT: %[[prod_j:.*]] = muli %i17, %i3 : index // CHECK-NEXT: %[[j:.*]] = addi %i15, %[[prod_j]] : index // CHECK: {{.*}} = load %i12[%[[i]], %[[j]]] : memref<?x?xf32> diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 1608e36..2c4d434 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -17,7 +17,7 @@ target_link_libraries(MLIRMlirOptLib ${LIB_LIBS}) set(LIBS MLIRAffineOps - MLIRAffineToGPU + MLIRLoopsToGPU MLIRAnalysis MLIREDSC MLIRFxpMathOps -- 2.7.4
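To make the new Linalg path concrete, the following is a minimal input of the shape exercised by linalg_to_gpu.mlir above. It is a hand-written sketch rather than a copy of the test: the buffer argument is dropped, the bounds and steps are taken as function arguments, and every name is invented. Under the default mapping of one block dimension and one thread dimension (the configuration the new test relies on), the outer linalg.for is mapped to GPU blocks and the inner one to GPU threads; as documented in LoopsToGPU.h, the bounds and step of the inner loop must be defined above the outer loop.

// Two perfectly nested linalg.for loops with SSA bounds and steps.
// Illustrative sketch only; these names are not produced by any tool.
func @loop_nest(%lb: index, %ub: index, %step: index) {
  %c1 = constant 1 : index
  // Outer loop: becomes the x dimension of the GPU grid.
  linalg.for %i = %lb to %ub step %step {
    // Inner loop: becomes the x dimension of the GPU block. Its bounds and
    // step are defined above the outer loop, as the conversion requires.
    linalg.for %j = %lb to %ub step %c1 {
    }
  }
  return
}

Running -convert-loops-to-gpu on such a function is expected to produce a gpu.launch whose grid and block sizes are derived from the two loop ranges, as the CHECK lines of the new test describe.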
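The step-forwarding scheme from the commit message reduces to two pieces of index arithmetic, both visible in the updated CHECK lines. The launch dimension for a loop from %lb to %ub with step %step is (%ub - %lb) / %step, now computed with plain subi and divis rather than the old affine floordiv expansion (hence the select checks removed from step_positive.mlir). Inside the kernel, every use of the induction variable is rewritten as %id * %step + %lb, where %id is the corresponding block or thread id and %lb and %step arrive as forwarded kernel arguments. The function below is an illustrative sketch of that arithmetic only, with invented names; it is not pass output.

// range = (ub - lb) / step (signed division), iv = id * step + lb.
// In real output, %id is a gpu.launch block/thread id and %lb/%step are
// kernel arguments forwarded by the conversion.
func @remap_iv(%lb: index, %ub: index, %step: index, %id: index) -> (index, index) {
  %diff = subi %ub, %lb : index
  %range = divis %diff, %step : index
  %prod = muli %step, %id : index
  %iv = addi %lb, %prod : index
  return %range, %iv : index, index
}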