--- /dev/null
+//===- AffineToGPU.h - Convert affine loops to GPU kernels ------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_CONVERSION_AFFINETOGPU_AFFINETOGPU_H_
+#define MLIR_CONVERSION_AFFINETOGPU_AFFINETOGPU_H_
+
+namespace mlir {
+class AffineForOp;
+struct LogicalResult;
+
+/// Convert a perfect affine loop nest with the outermost loop identified by
+/// `forOp` into a gpu::LaunchOp. Map the `numBlockDims` outermost loops to GPU
+/// blocks and the next `numThreadDims` loops to GPU threads. The bounds of the
+/// mapped loops must be independent of the induction variables of the other
+/// mapped loops.
+///
+/// No check is performed on the sizes of the block or grid, or on the validity
+/// of the parallelization: it is the caller's responsibility to strip-mine the
+/// loops and to perform the dependence analysis before calling the conversion.
+LogicalResult convertAffineLoopNestToGPULaunch(AffineForOp forOp,
+ unsigned numBlockDims,
+ unsigned numThreadDims);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_AFFINETOGPU_AFFINETOGPU_H_
/// Get the SSA values passed as operands to specify the block size.
KernelDim3 getBlockSizeOperandValues();
+ /// Get the SSA values of the kernel arguments.
+ llvm::iterator_range<Block::args_iterator> getKernelArguments();
+
LogicalResult verify();
/// Custom syntax support.
blocks.splice(blocks.end(), other.getBlocks());
}
+ /// Get a list of values used by operations in this region (including its
+ /// nested regions) but defined outside of it.
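+ /// For example (illustrative), the affine-to-GPU conversion in this patch
+ /// uses it to collect the values forwarded as kernel arguments:
+ ///   auto valuesToForward = forOp.getRegion().getUsedValuesDefinedAbove();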
+ SmallVector<Value *, 8> getUsedValuesDefinedAbove();
+
/// Check that this does not use any value defined outside it.
/// Emit errors if `noteLoc` is provided; this location is used to point
/// to the operation containing the region, the actual error is reported at
#ifndef MLIR_TRANSFORMS_LOWERAFFINE_H
#define MLIR_TRANSFORMS_LOWERAFFINE_H
+#include "mlir/Support/LLVM.h"
+
namespace mlir {
+class AffineExpr;
+class AffineForOp;
class Function;
+class Location;
struct LogicalResult;
+class OpBuilder;
+class Value;
+
+/// Emit code that computes the given affine expression using standard
+/// arithmetic operations applied to the provided dimension and symbol values.
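+///
+/// For example (an illustrative sketch following the AffineToGPU conversion in
+/// this patch, where `ctx` is the MLIRContext), a loop range divided by a
+/// constant step can be computed by expanding a symbolic floordiv with the
+/// range passed as the only symbol value:
+///
+///   auto divExpr = getAffineSymbolExpr(0, ctx).floorDiv(step);
+///   range = expandAffineExpr(builder, loc, divExpr, /*dimValues=*/llvm::None,
+///                            /*symbolValues=*/range);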
+Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
+ ArrayRef<Value *> dimValues,
+ ArrayRef<Value *> symbolValues);
+/// Convert from the Affine dialect to the Standard dialect, in particular
+/// convert structured affine control flow into CFG branch-based control flow.
LogicalResult lowerAffineConstructs(Function &function);
+
+/// Emit code that computes the lower bound of the given affine loop using
+/// standard arithmetic operations.
+Value *lowerAffineLowerBound(AffineForOp op, OpBuilder &builder);
+
+/// Emit code that computes the upper bound of the given affine loop using
+/// standard arithmetic operations.
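+///
+/// Sketch of typical use (as in the AffineToGPU conversion from this patch);
+/// both bound helpers may fail and return nullptr:
+///
+///   Value *lowerBound = lowerAffineLowerBound(currentLoop, builder);
+///   Value *upperBound = lowerAffineUpperBound(currentLoop, builder);
+///   if (!lowerBound || !upperBound)
+///     return failure();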
+Value *lowerAffineUpperBound(AffineForOp op, OpBuilder &builder);
} // namespace mlir
#endif // MLIR_TRANSFORMS_LOWERAFFINE_H
--- /dev/null
+//===- RegionUtils.h - Region-related transformation utilities --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_TRANSFORMS_REGIONUTILS_H_
+#define MLIR_TRANSFORMS_REGIONUTILS_H_
+
+#include "mlir/IR/Block.h"
+
+namespace mlir {
+
+/// Check if all values in the provided range are defined above the `limit`
+/// region. That is, if they are defined in a region that is a proper ancestor
+/// of `limit`.
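+///
+/// For example (illustrative, matching the use in AffineToGPU.cpp from this
+/// patch), loop bounds depending on other mapped loops can be rejected with:
+///
+///   if (!areValuesDefinedAbove(currentLoop.getLowerBoundOperands(), limit) ||
+///       !areValuesDefinedAbove(currentLoop.getUpperBoundOperands(), limit))
+///     return currentLoop.emitError("loops with bounds depending on other "
+///                                  "mapped loops are not supported");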
+template <typename Range>
+bool areValuesDefinedAbove(Range values, Region &limit) {
+ for (Value *v : values)
+ if (!v->getContainingRegion()->isProperAncestor(&limit))
+ return false;
+ return true;
+}
+
+/// Replace all uses of `orig` within the given region with `replacement`.
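+/// For instance (illustrative), the AffineToGPU conversion in this patch remaps
+/// each value forwarded into the kernel to the corresponding kernel argument:
+///   replaceAllUsesInRegionWith(from, to, launchOp.getBody());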
+void replaceAllUsesInRegionWith(Value *orig, Value *replacement,
+ Region &region);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_REGIONUTILS_H_
add_subdirectory(AffineOps)
add_subdirectory(Analysis)
+add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(EDSC)
add_subdirectory(ExecutionEngine)
--- /dev/null
+//===- AffineToGPU.cpp - Convert an affine loop nest to a GPU kernel ------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This implements a straightforward conversion of an affine loop nest into a
+// GPU kernel. The caller is expected to guarantee that the conversion is
+// correct or to further transform the kernel to ensure correctness.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AffineToGPU/AffineToGPU.h"
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/GPU/GPUDialect.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/LowerAffine.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "affine-to-gpu"
+
+using namespace mlir;
+
+// Extract an indexed value from KernelDim3.
+static Value *getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) {
+ switch (pos) {
+ case 0:
+ return dim3.x;
+ case 1:
+ return dim3.y;
+ case 2:
+ return dim3.z;
+ default:
+ llvm_unreachable("dim3 position out of bounds");
+ }
+ return nullptr;
+}
+
+LogicalResult mlir::convertAffineLoopNestToGPULaunch(AffineForOp forOp,
+ unsigned numBlockDims,
+ unsigned numThreadDims) {
+ if (numBlockDims < 1 || numThreadDims < 1) {
+ LLVM_DEBUG(llvm::dbgs() << "nothing to map");
+ return success();
+ }
+
+ OpBuilder builder(forOp.getOperation());
+
+ if (numBlockDims > 3) {
+ forOp.getContext()->emitError(builder.getUnknownLoc(),
+ "cannot map to more than 3 block dimensions");
+ return failure();
+ }
+ if (numThreadDims > 3) {
+ forOp.getContext()->emitError(
+ builder.getUnknownLoc(), "cannot map to more than 3 thread dimensions");
+ return failure();
+ }
+
+ // Check the structure of the loop nest:
+ // - there are enough loops to map to numBlockDims + numThreadDims;
+ // - the loops are perfectly nested;
+ // - the loop bounds can be computed above the outermost loop.
+ // This roughly corresponds to the "matcher" part of the pattern-based
+ // rewriting infrastructure.
+ AffineForOp currentLoop = forOp;
+ Region &limit = forOp.getRegion();
+ for (unsigned i = 0, e = numBlockDims + numThreadDims; i < e; ++i) {
+ Operation *nested = &currentLoop.getBody()->front();
+ if (currentLoop.getStep() <= 0)
+ return currentLoop.emitError("only positive loop steps are supported");
+ if (!areValuesDefinedAbove(currentLoop.getLowerBoundOperands(), limit) ||
+ !areValuesDefinedAbove(currentLoop.getUpperBoundOperands(), limit))
+ return currentLoop.emitError(
+ "loops with bounds depending on other mapped loops "
+ "are not supported");
+
+ // The innermost loop can have an arbitrary body; skip the perfect nesting
+ // check for it.
+ if (i == e - 1)
+ break;
+
+ auto begin = currentLoop.getBody()->begin(),
+ end = currentLoop.getBody()->end();
+ if (currentLoop.getBody()->empty() || std::next(begin, 2) != end)
+ return currentLoop.emitError(
+ "expected perfectly nested loops in the body");
+
+ if (!(currentLoop = dyn_cast<AffineForOp>(nested)))
+ return nested->emitError("expected a nested loop");
+ }
+
+ // Compute the ranges of the loops and collect lower bounds and induction
+ // variables.
+ SmallVector<Value *, 6> dims;
+ SmallVector<Value *, 6> lbs;
+ SmallVector<Value *, 6> ivs;
+ SmallVector<int64_t, 6> steps;
+ dims.reserve(numBlockDims + numThreadDims);
+ lbs.reserve(numBlockDims + numThreadDims);
+ ivs.reserve(numBlockDims + numThreadDims);
+ steps.reserve(numBlockDims + numThreadDims);
+ currentLoop = forOp;
+ for (unsigned i = 0, e = numBlockDims + numThreadDims; i < e; ++i) {
+ Value *lowerBound = lowerAffineLowerBound(currentLoop, builder);
+ Value *upperBound = lowerAffineUpperBound(currentLoop, builder);
+ if (!lowerBound || !upperBound)
+ return failure();
+
+ Value *range =
+ builder.create<SubIOp>(currentLoop.getLoc(), upperBound, lowerBound);
+ int64_t step = currentLoop.getStep();
+ if (step > 1) {
+ auto divExpr =
+ getAffineSymbolExpr(0, currentLoop.getContext()).floorDiv(step);
+ range = expandAffineExpr(builder, currentLoop.getLoc(), divExpr,
+ llvm::None, range);
+ }
+ dims.push_back(range);
+
+ lbs.push_back(lowerBound);
+ ivs.push_back(currentLoop.getInductionVar());
+ steps.push_back(step);
+
+ if (i != e - 1)
+ currentLoop = cast<AffineForOp>(&currentLoop.getBody()->front());
+ }
+ // At this point, currentLoop points to the innermost loop we are mapping.
+
+ // Prepare the grid and block sizes for the launch operation. If there is
+ // no loop mapped to a specific dimension, use constant "1" as its size.
+ Value *constOne = (numBlockDims < 3 || numThreadDims < 3)
+ ? builder.create<ConstantIndexOp>(forOp.getLoc(), 1)
+ : nullptr;
+ Value *gridSizeX = dims[0];
+ Value *gridSizeY = numBlockDims > 1 ? dims[1] : constOne;
+ Value *gridSizeZ = numBlockDims > 2 ? dims[2] : constOne;
+ Value *blockSizeX = dims[numBlockDims];
+ Value *blockSizeY = numThreadDims > 1 ? dims[numBlockDims + 1] : constOne;
+ Value *blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne;
+
+ // Create a launch op and move the body region of the innermost loop to the
+ // launch op. Pass the values defined outside the outermost loop and used
+ // inside the innermost loop, as well as the loop lower bounds, as kernel data
+ // arguments. We still assume perfect nesting, so the only values defined in
+ // one loop and used in a deeper loop are the induction variables.
+ auto valuesToForward = forOp.getRegion().getUsedValuesDefinedAbove();
+ auto originallyForwardedValues = valuesToForward.size();
+ valuesToForward.append(lbs.begin(), lbs.end());
+ auto launchOp = builder.create<gpu::LaunchOp>(
+ forOp.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY,
+ blockSizeZ, valuesToForward);
+ valuesToForward.resize(originallyForwardedValues);
+
+ // Replace the affine terminator (loops contain only a single block) with the
+ // gpu return and move the operations from the loop body block to the gpu
+ // launch body block. Do not move the entire block because of the difference
+ // in block arguments.
+ Operation &terminator = currentLoop.getBody()->back();
+ Location terminatorLoc = terminator.getLoc();
+ terminator.erase();
+ builder.setInsertionPointToEnd(currentLoop.getBody());
+ builder.create<gpu::Return>(terminatorLoc);
+ launchOp.getBody().front().getOperations().splice(
+ launchOp.getBody().front().begin(),
+ currentLoop.getBody()->getOperations());
+
+ // Remap the loop iterators to use block/thread identifiers instead. Loops
+ // may iterate from LB with step S whereas GPU thread/block ids always iterate
+ // from 0 to N with step 1. Therefore, loop induction variables are replaced
+ // with (gpu-thread/block-id * S) + LB.
+ builder.setInsertionPointToStart(&launchOp.getBody().front());
+ auto lbArgumentIt = std::next(launchOp.getKernelArguments().begin(),
+ originallyForwardedValues);
+ for (auto en : llvm::enumerate(ivs)) {
+ Value *id =
+ en.index() < numBlockDims
+ ? getDim3Value(launchOp.getBlockIds(), en.index())
+ : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims);
+ if (steps[en.index()] > 1) {
+ Value *factor =
+ builder.create<ConstantIndexOp>(forOp.getLoc(), steps[en.index()]);
+ id = builder.create<MulIOp>(forOp.getLoc(), factor, id);
+ }
+ Value *ivReplacement =
+ builder.create<AddIOp>(forOp.getLoc(), *lbArgumentIt, id);
+ en.value()->replaceAllUsesWith(ivReplacement);
+ std::advance(lbArgumentIt, 1);
+ }
+
+ // Remap the values defined outside the body to use kernel arguments instead.
+ // The list of kernel arguments also contains the lower bounds for loops at
+ // trailing positions; make sure we don't touch those.
+ for (const auto &pair :
+ llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) {
+ Value *from = std::get<0>(pair);
+ Value *to = std::get<1>(pair);
+ replaceAllUsesInRegionWith(from, to, launchOp.getBody());
+ }
+
+ // We are done and can erase the original outermost loop.
+ forOp.erase();
+
+ return success();
+}
--- /dev/null
+//===- AffineToGPUPass.cpp - Convert an affine loop nest to a GPU kernel --===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Conversion/AffineToGPU/AffineToGPU.h"
+#include "mlir/Pass/Pass.h"
+
+#include "llvm/Support/CommandLine.h"
+
+#define PASS_NAME "convert-affine-to-gpu"
+
+using namespace mlir;
+
+static llvm::cl::OptionCategory clOptionsCategory(PASS_NAME " options");
+static llvm::cl::opt<unsigned>
+ clNumBlockDims("gpu-block-dims",
+ llvm::cl::desc("Number of GPU block dimensions for mapping"),
+ llvm::cl::cat(clOptionsCategory), llvm::cl::init(1u));
+static llvm::cl::opt<unsigned> clNumThreadDims(
+ "gpu-thread-dims",
+ llvm::cl::desc("Number of GPU thread dimensions for mapping"),
+ llvm::cl::cat(clOptionsCategory), llvm::cl::init(1u));
+
+namespace {
+// A pass that traverses top-level loops in the function and converts them to
+// GPU launch operations. Nested launches are not allowed, so this pass does
+// not walk the function recursively and thus never considers nested loops.
+struct AffineForGPUMapper : public FunctionPass<AffineForGPUMapper> {
+ void runOnFunction() override {
+ for (Block &block : getFunction())
+ for (Operation &op : llvm::make_early_inc_range(block))
+ if (auto forOp = dyn_cast<AffineForOp>(&op))
+ if (failed(convertAffineLoopNestToGPULaunch(
+ forOp, clNumBlockDims.getValue(),
+ clNumThreadDims.getValue())))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+static PassRegistration<AffineForGPUMapper>
+ registration(PASS_NAME, "Convert top-level affine loops to GPU kernels");
--- /dev/null
+set(LIBS
+ MLIRAffineOps
+ MLIRGPU
+ MLIRIR
+ MLIRPass
+ MLIRStandardOps
+ MLIRSupport
+ MLIRTransforms
+ LLVMSupport
+)
+
+add_llvm_library(MLIRAffineToGPU
+ AffineToGPU.cpp
+ AffineToGPUPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AffineToGPU
+)
+add_dependencies(MLIRAffineToGPU ${LIBS})
+target_link_libraries(MLIRAffineToGPU ${LIBS})
--- /dev/null
+add_subdirectory(AffineToGPU)
return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}
+llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
+ auto args = getBody().getBlocks().front().getArguments();
+ return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
+}
+
LogicalResult LaunchOp::verify() {
// Kernel launch takes kNumConfigOperands leading operands for grid/block
// sizes and transforms them into kNumConfigRegionAttributes region arguments
it->walk(remapOperands);
}
-/// Check that the given `region` does not use any value defined outside its
-/// ancestor region `limit`. That is, given `A{B{C{}}}` with limit `B`, `C` is
-/// allowed to use values defined in `B` but not those defined in `A`.
-/// Emit errors if `noteLoc` is provided; this location is used to point to
-/// the operation containing the region, the actual error is reported at the
-/// operation with an offending use.
-static bool isRegionIsolatedAbove(Region &region, Region &limit,
- llvm::Optional<Location> noteLoc) {
+/// Find the values used by operations in `region` that are defined outside its
+/// ancestor region `limit`. That is, given `A{B{C{}}}` with region `C` and
+/// limit `B`, the values defined in `A` will be found while the values defined
+/// in `B` will not. Append these values to `values`. If `stopAfterOne` is
+/// set, return immediately after the first such value is found (used for
+/// isolation checks). Additionally, emit errors if `noteLoc` is provided; this
+/// location is used to point to the operation containing the region, while the
+/// actual error is reported at the operation with an offending use.
+static void findValuesDefinedAbove(Region &region, Region &limit,
+ SmallVectorImpl<Value *> &values,
+ llvm::Optional<Location> noteLoc,
+ bool stopAfterOne = false) {
assert(limit.isAncestor(®ion) &&
"expected isolation limit to be an ancestor of the given region");
.attachNote(noteLoc)
<< "required by region isolation constraints";
}
- return false;
+ values.push_back(operand);
+ if (stopAfterOne)
+ return;
}
}
// Schedule any regions the operations contain for further checking.
}
}
}
+}
- return true;
+SmallVector<Value *, 8> Region::getUsedValuesDefinedAbove() {
+ SmallVector<Value *, 8> values;
+ findValuesDefinedAbove(*this, *this, values, llvm::None,
+ /*stopAfterOne=*/false);
+ return values;
}
bool Region::isIsolatedFromAbove(llvm::Optional<Location> noteLoc) {
- return isRegionIsolatedAbove(*this, *this, noteLoc);
+ SmallVector<Value *, 1> values;
+ findValuesDefinedAbove(*this, *this, values, noteLoc, /*stopAfterOne=*/true);
+ return values.empty();
}
Region *llvm::ilist_traits<::mlir::Block>::getContainingRegion() {
Utils/GreedyPatternRewriteDriver.cpp
Utils/LoopFusionUtils.cpp
Utils/LoopUtils.cpp
+ Utils/RegionUtils.cpp
Utils/Utils.cpp
Vectorization
Vectorize.cpp
// Create a sequence of operations that implement the `expr` applied to the
// given dimension and symbol values.
-static mlir::Value *expandAffineExpr(OpBuilder *builder, Location loc,
- AffineExpr expr,
- ArrayRef<Value *> dimValues,
- ArrayRef<Value *> symbolValues) {
- return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
+mlir::Value *mlir::expandAffineExpr(OpBuilder &builder, Location loc,
+ AffineExpr expr,
+ ArrayRef<Value *> dimValues,
+ ArrayRef<Value *> symbolValues) {
+ return AffineApplyExpander(&builder, dimValues, symbolValues, loc)
+ .visit(expr);
}
// Create a sequence of operations that implement the `affineMap` applied to
auto numDims = affineMap.getNumDims();
auto expanded = functional::map(
[numDims, builder, loc, operands](AffineExpr expr) {
- return expandAffineExpr(builder, loc, expr,
+ return expandAffineExpr(*builder, loc, expr,
operands.take_front(numDims),
operands.drop_front(numDims));
},
return value;
}
+// Emit instructions that correspond to the affine map in the lower bound
+// applied to the respective operands, and compute the maximum value across
+// the results.
+Value *mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
+ SmallVector<Value *, 8> boundOperands(op.getLowerBoundOperands());
+ auto lbValues = expandAffineMap(&builder, op.getLoc(), op.getLowerBoundMap(),
+ boundOperands);
+ if (!lbValues)
+ return nullptr;
+ return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SGT, *lbValues,
+ builder);
+}
+
+// Emit instructions that correspond to the affine map in the upper bound
+// applied to the respective operands, and compute the minimum value across
+// the results.
+Value *mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
+ SmallVector<Value *, 8> boundOperands(op.getUpperBoundOperands());
+ auto ubValues = expandAffineMap(&builder, op.getLoc(), op.getUpperBoundMap(),
+ boundOperands);
+ if (!ubValues)
+ return nullptr;
+ return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SLT, *ubValues,
+ builder);
+}
+
namespace {
// Affine terminators are removed.
class AffineTerminatorLowering : public ConversionPattern {
rewriter.setInsertionPointToEnd(lastBodyBlock);
auto affStep = rewriter.getAffineConstantExpr(forOp.getStep());
auto affDim = rewriter.getAffineDimExpr(0);
- auto stepped = expandAffineExpr(&rewriter, loc, affDim + affStep, iv, {});
+ auto stepped = expandAffineExpr(rewriter, loc, affDim + affStep, iv, {});
if (!stepped)
return matchFailure();
rewriter.create<BranchOp>(loc, conditionBlock, stepped);
// Compute loop bounds before branching to the condition.
rewriter.setInsertionPointToEnd(initBlock);
- SmallVector<Value *, 8> boundOperands(forOp.getLowerBoundOperands());
- auto lbValues = expandAffineMap(&rewriter, loc, forOp.getLowerBoundMap(),
- boundOperands);
- if (!lbValues)
- return matchFailure();
- Value *lowerBound =
- buildMinMaxReductionSeq(loc, CmpIPredicate::SGT, *lbValues, rewriter);
-
- boundOperands.assign(forOp.getUpperBoundOperands().begin(),
- forOp.getUpperBoundOperands().end());
- auto ubValues = expandAffineMap(&rewriter, loc, forOp.getUpperBoundMap(),
- boundOperands);
- if (!ubValues)
+ Value *lowerBound = lowerAffineLowerBound(forOp, rewriter);
+ Value *upperBound = lowerAffineUpperBound(forOp, rewriter);
+ if (!lowerBound || !upperBound)
return matchFailure();
- Value *upperBound =
- buildMinMaxReductionSeq(loc, CmpIPredicate::SLT, *ubValues, rewriter);
rewriter.create<BranchOp>(loc, conditionBlock, lowerBound);
// With the body block done, we can fill in the condition block.
// Build and apply an affine expression
auto numDims = integerSet.getNumDims();
- Value *affResult = expandAffineExpr(&rewriter, loc, constraintExpr,
+ Value *affResult = expandAffineExpr(rewriter, loc, constraintExpr,
operands.take_front(numDims),
operands.drop_front(numDims));
if (!affResult)
--- /dev/null
+//===- RegionUtils.cpp - Region-related transformation utilities ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Transforms/RegionUtils.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+
+using namespace mlir;
+
+void mlir::replaceAllUsesInRegionWith(Value *orig, Value *replacement,
+ Region &region) {
+ for (IROperand &use : llvm::make_early_inc_range(orig->getUses())) {
+ if (region.isAncestor(use.getOwner()->getContainingRegion()))
+ use.set(replacement);
+ }
+}
--- /dev/null
+// RUN: mlir-opt -convert-affine-to-gpu -gpu-block-dims=1 -gpu-thread-dims=1 %s | FileCheck --check-prefix=CHECK-11 %s
+// RUN: mlir-opt -convert-affine-to-gpu -gpu-block-dims=2 -gpu-thread-dims=2 %s | FileCheck --check-prefix=CHECK-22 %s
+
+// CHECK-11-LABEL: @step_1
+// CHECK-22-LABEL: @step_1
+func @step_1(%A : memref<?x?x?x?xf32>, %B : memref<?x?x?x?xf32>) {
+ // Bounds of the loop and its range.
+ // CHECK-11-NEXT: %c0 = constant 0 : index
+ // CHECK-11-NEXT: %c42 = constant 42 : index
+ // CHECK-11-NEXT: %0 = subi %c42, %c0 : index
+ //
+ // CHECK-22-NEXT: %c0 = constant 0 : index
+ // CHECK-22-NEXT: %c42 = constant 42 : index
+ // CHECK-22-NEXT: %0 = subi %c42, %c0 : index
+ affine.for %i = 0 to 42 {
+
+ // Bounds of the loop and its range.
+ // CHECK-11-NEXT: %c0_0 = constant 0 : index
+ // CHECK-11-NEXT: %c10 = constant 10 : index
+ // CHECK-11-NEXT: %1 = subi %c10, %c0_0 : index
+ //
+ // CHECK-22-NEXT: %c0_0 = constant 0 : index
+ // CHECK-22-NEXT: %c10 = constant 10 : index
+ // CHECK-22-NEXT: %1 = subi %c10, %c0_0 : index
+ affine.for %j = 0 to 10 {
+ // CHECK-11: gpu.launch
+ // CHECK-11-SAME: blocks(%i0, %i1, %i2) in (%i6 = %0, %i7 = %c1, %i8 = %c1)
+ // CHECK-11-SAME: threads(%i3, %i4, %i5) in (%i9 = %1, %i10 = %c1, %i11 = %c1)
+ // CHECK-11-SAME: args(%i12 = %arg0, %i13 = %arg1, %i14 = %c0, %i15 = %c0_0)
+
+ // Remapping of the loop induction variables.
+ // CHECK-11: %[[i:.*]] = addi %i14, %i0 : index
+ // CHECK-11-NEXT: %[[j:.*]] = addi %i15, %i3 : index
+
+ // This loop is not converted if mapping to 1, 1 dimensions.
+ // CHECK-11-NEXT: affine.for %[[ii:.*]] = 2 to 16
+ //
+ // Bounds of the loop and its range.
+ // CHECK-22-NEXT: %c2 = constant 2 : index
+ // CHECK-22-NEXT: %c16 = constant 16 : index
+ // CHECK-22-NEXT: %2 = subi %c16, %c2 : index
+ affine.for %ii = 2 to 16 {
+ // This loop is not converted if mapping to 1, 1 dimensions.
+ // CHECK-11-NEXT: affine.for %[[jj:.*]] = 5 to 17
+ //
+ // Bounds of the loop and its range.
+ // CHECK-22-NEXT: %c5 = constant 5 : index
+ // CHECK-22-NEXT: %c17 = constant 17 : index
+ // CHECK-22-NEXT: %3 = subi %c17, %c5 : index
+ affine.for %jj = 5 to 17 {
+ // CHECK-22: gpu.launch
+ // CHECK-22-SAME: blocks(%i0, %i1, %i2) in (%i6 = %0, %i7 = %1, %i8 = %c1)
+ // CHECK-22-SAME: threads(%i3, %i4, %i5) in (%i9 = %2, %i10 = %3, %i11 = %c1)
+ // CHECK-22-SAME: args(%i12 = %arg0, %i13 = %arg1, %i14 = %c0, %i15 = %c0_0, %i16 = %c2, %i17 = %c5)
+
+ // Remapping of the loop induction variables in the last mapped loop.
+ // CHECK-22: %[[i:.*]] = addi %i14, %i0 : index
+ // CHECK-22-NEXT: %[[j:.*]] = addi %i15, %i1 : index
+ // CHECK-22-NEXT: %[[ii:.*]] = addi %i16, %i3 : index
+ // CHECK-22-NEXT: %[[jj:.*]] = addi %i17, %i4 : index
+
+ // Using remapped values instead of loop iterators.
+ // CHECK-11: {{.*}} = load %i12[%[[i]], %[[j]], %[[ii]], %[[jj]]] : memref<?x?x?x?xf32>
+ // CHECK-22: {{.*}} = load %i12[%[[i]], %[[j]], %[[ii]], %[[jj]]] : memref<?x?x?x?xf32>
+ %0 = load %A[%i, %j, %ii, %jj] : memref<?x?x?x?xf32>
+ // CHECK-11-NEXT: store {{.*}}, %i13[%[[i]], %[[j]], %[[ii]], %[[jj]]] : memref<?x?x?x?xf32>
+ // CHECK-22-NEXT: store {{.*}}, %i13[%[[i]], %[[j]], %[[ii]], %[[jj]]] : memref<?x?x?x?xf32>
+ store %0, %B[%i, %j, %ii, %jj] : memref<?x?x?x?xf32>
+
+ // CHECK-11: gpu.return
+ // CHECK-22: gpu.return
+ }
+ }
+ }
+ }
+ return
+}
+
--- /dev/null
+// RUN: mlir-opt -convert-affine-to-gpu -gpu-block-dims=1 -gpu-thread-dims=1 %s | FileCheck %s
+
+// CHECK-LABEL: @step_var
+func @step_var(%A : memref<?x?xf32>, %B : memref<?x?xf32>) {
+ // The loop range computation is performed by lowering the affine expression
+ // floordiv(upper - lower, step). The lowering of affine expressions has its
+ // own test; here we only check that the range is divided by the step.
+ // CHECK: divis {{.*}}, %c4
+ // CHECK: %[[range_i:.*]] = select
+ // CHECK: divis {{.*}}, %c7
+ // CHECK: %[[range_j:.*]] = select
+
+ // CHECK: gpu.launch
+ // CHECK-SAME: blocks(%i0, %i1, %i2) in (%i6 = %[[range_i]], %i7 = %c1, %i8 = %c1)
+ // CHECK-SAME: threads(%i3, %i4, %i5) in (%i9 = %[[range_j]], %i10 = %c1, %i11 = %c1)
+ affine.for %i = 5 to 15 step 4 {
+ affine.for %j = 3 to 19 step 7 {
+ // Loop induction variable remapping:
+ // iv = thread(block)_id * step + lower_bound
+ // CHECK: %[[c4:.*]] = constant 4 : index
+ // CHECK-NEXT: %[[prod_i:.*]] = muli %[[c4]], %i0 : index
+ // CHECK-NEXT: %[[i:.*]] = addi %i14, %[[prod_i]] : index
+ // CHECK-NEXT: %[[c7:.*]] = constant 7 : index
+ // CHECK-NEXT: %[[prod_j:.*]] = muli %[[c7]], %i3 : index
+ // CHECK-NEXT: %[[j:.*]] = addi %i15, %[[prod_j]] : index
+
+ // CHECK: {{.*}} = load %i12[%[[i]], %[[j]]] : memref<?x?xf32>
+ %0 = load %A[%i, %j] : memref<?x?xf32>
+ // CHECK: store {{.*}}, %i13[%[[i]], %[[j]]] : memref<?x?xf32>
+ store %0, %B[%i, %j] : memref<?x?xf32>
+ }
+ }
+ return
+}
set(LIBS
MLIRAffineOps
+ MLIRAffineToGPU
MLIRAnalysis
MLIREDSC
MLIRFxpMathOps