#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
+
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
assert(op.getOperation()->getNumRegions() == 1 &&
"Expected single region op");
auto &b = ScopedContext::getBuilderRef();
- auto &block = op.region().front();
+ auto &block = op.getOperation()->getRegion(0).front();
BlockAndValueMapping map;
map.map(block.getArguments(), indexedValues);
for (auto &op : block.without_terminator()) {
makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}
-namespace {
-
/// Emits the MLIR for the scalar part of the generic op by:
/// 1. Emitting load ops for each input and output view in order. This is
/// achieved by applying the appropriate input or output map to the
/// }
/// }
/// ```
-// TODO: need a LinalgStructuredOpInterface.
-template <typename IndexedValueType, typename LinalgStructuredOpType>
-void emitScalarImplementation(ArrayRef<Value> allIvs,
- LinalgStructuredOpType linalgOp) {
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs,
+ LinalgOp linalgOp) {
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto &b = ScopedContext::getBuilderRef();
auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());
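  // If a "symbol_source" attribute is present, append one value per dimension
  // of the designated operand's shaped type so the indexing maps can bind
  // them as symbols.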
if (attr) {
- auto operand = linalgOp.getOperand(attr.getInt());
+ auto operand = linalgOp.getOperation()->getOperand(attr.getInt());
auto shapedType = operand.getType().template cast<ShapedType>();
allIvsPlusDims.reserve(allIvs.size() + shapedType.getRank());
for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx)
}
template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
assert(copyOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto nPar = copyOp.getNumParallelLoops();
}
template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
assert(fillOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto nPar = fillOp.getNumParallelLoops();
}
template <typename IndexedValueType>
-Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
- MutableArrayRef<Value> imIdx) {
+static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
+ MutableArrayRef<Value> imIdx) {
// TODO: add a level of indirection to linalg.generic.
if (!convOp.padding())
return im(imIdx);
}
}
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
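+/// Emits the scalar form shared by PoolingMinOp and PoolingMaxOp; OpType
+/// selects the comparison (slt for min, sgt for max) at compile time.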
+template <typename IndexedValueType, typename OpType>
+static void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs,
+ OpType op) {
InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
// Emit scalar form.
  IndexedValueType output(op.output());
  IndexedValueType input(op.input());
Value lhs = output(indices.outputs);
Value rhs = input(indices.inputs);
using edsc::op::sgt;
- Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
- output(indices.outputs) = maxValue;
+ using edsc::op::slt;
+ Value value = std::is_same<OpType, PoolingMinOp>()
+ ? std_select(slt(lhs, rhs), lhs, rhs)
+ : std_select(sgt(lhs, rhs), lhs, rhs);
+ output(indices.outputs) = value;
}
template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
- InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
- // Emit scalar form.
- IndexedValueType output(op.output());
- IndexedValueType input(op.input());
- Value lhs = output(indices.outputs);
- Value rhs = input(indices.inputs);
- using edsc::op::slt;
- Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
- output(indices.outputs) = minValue;
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+ emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs,
+ op);
}
+
template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
+ emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs,
+ op);
+}
+
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
auto indices = getInputAndOutputIndices(allIvs, op);
IndexedValueType input(op.input()), output(op.output());
// Emit scalar form.
output(indices.outputs) += input(indices.inputs);
}
+
/// Emits the MLIR for the scalar part of the indexed generic op by:
/// 1. Emitting load ops for each input and output view in order. This is
/// achieved by applying the appropriate input or output map to the
indexing, outputBuffers);
}
-template <typename LoopTy, typename ConcreteOpTy>
-Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
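+/// Emits a loop nest using the LoopTy construct (affine.for, scf.for, or
+/// scf.parallel) for a LinalgOp with buffer semantics and returns the
+/// generated loops.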
+template <typename LoopTy>
+static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
+ OpBuilder &builder) {
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
ScopedContext scope(builder, op->getLoc());
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
- auto linalgOp = cast<ConcreteOpTy>(op);
+ auto linalgOp = cast<LinalgOp>(op);
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto mapsRange =
[&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
assert(iterArgs.empty() && "unexpected iterArgs");
allIvs.append(ivs.begin(), ivs.end());
- emitScalarImplementation<IndexedValueTy>(allIvs, linalgOp);
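+        // Dispatch to the op-specific scalar emitters; the LinalgOp case
+        // acts as the catch-all for the remaining structured ops.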
+ llvm::TypeSwitch<Operation *>(op)
+ .Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
+ PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
+ emitScalarImplementation<IndexedValueTy>(allIvs, op);
+ })
+ .Default([&](Operation *op) { assert(false && "unexpected op"); });
return scf::ValueVector{};
});
// Number of loop ops might be different from the number of ivs since some
return loops;
}
-template <typename LoopType, typename ConcreteOp>
+namespace {
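+/// Rewrites any LinalgOp into a loop nest. The pattern matches every
+/// operation and bails out in matchAndRewrite for non-Linalg ops, so a
+/// single pattern covers all structured ops.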
+template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
- explicit LinalgRewritePattern(MLIRContext *context)
- : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
+ LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
+ if (!isa<LinalgOp>(op))
+ return failure();
+ if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
return failure();
rewriter.eraseOp(op);
return success();
}
};
-template <typename LoopType, typename ConcreteOp>
-void insertOnePattern(OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
-}
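+// Forward-declared so lowerLinalgToLoopsImpl can register the pattern; the
+// definition follows below.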
+struct FoldAffineOp;
+} // namespace
-template <typename LoopType, typename... Args>
-void insertPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
- (void)std::initializer_list<int>{
- 0, (insertOnePattern<LoopType, Args>(patterns, ctx), 0)...};
+template <typename LoopType>
+static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
+ OwningRewritePatternList patterns;
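+  // Canonicalization and folding patterns applied greedily allow cleaning up
+  // the emitted IR on the fly.
+  // TODO: fold view and subview ops?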
+ patterns.insert<LinalgRewritePattern<LoopType>>();
+ DimOp::getCanonicalizationPatterns(patterns, context);
+ AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+ patterns.insert<FoldAffineOp>(context);
+ // Just apply the patterns greedily.
+ applyPatternsAndFoldGreedily(funcOp, patterns);
}
+namespace {
/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
return failure();
}
};
-} // namespace
-
-template <typename LoopType>
-static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
- OwningRewritePatternList patterns;
- // Canonicalization and folding patterns applied greedily allow cleaning up
- // the emitted IR on the fly.
- // TODO: fold view and subview ops?
- insertPatterns<LoopType,
-#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
- >(patterns, context);
- DimOp::getCanonicalizationPatterns(patterns, context);
- AffineApplyOp::getCanonicalizationPatterns(patterns, context);
- patterns.insert<FoldAffineOp>(context);
- // Just apply the patterns greedily.
- applyPatternsAndFoldGreedily(funcOp, patterns);
-}
-
-namespace {
struct LowerToAffineLoops
: public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
void runOnFunction() override {
lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
}
};
+
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
void runOnFunction() override {
lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
}
};
+
struct LowerToParallelLoops
: public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), &getContext());
  }
-// TODO: gradually remove this layer as more ops become "named".
-template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
- OpBuilder &builder) {
- assert(isa<LinalgOp>(op) && "LinalgOp expected");
- if (isa<CopyOp>(op))
- return linalgOpToLoopsImpl<LoopTy, CopyOp>(op, builder);
- if (isa<FillOp>(op))
- return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder);
- if (isa<ConvOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder);
- if (isa<PoolingMaxOp>(op))
- return linalgOpToLoopsImpl<LoopTy, PoolingMaxOp>(op, builder);
- if (isa<PoolingMinOp>(op))
- return linalgOpToLoopsImpl<LoopTy, PoolingMinOp>(op, builder);
- if (isa<PoolingSumOp>(op))
- return linalgOpToLoopsImpl<LoopTy, PoolingSumOp>(op, builder);
- if (isa<IndexedGenericOp>(op))
- return linalgOpToLoopsImpl<LoopTy, IndexedGenericOp>(op, builder);
-
- // TODO: Cases below are generic and need a LinalgStructuredOpInterface.
- if (isa<GenericOp>(op))
- return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder);
- if (isa<MatmulOp>(op))
- return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder);
- if (isa<MatvecOp>(op))
- return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
- if (isa<VecmatOp>(op))
- return linalgOpToLoopsImpl<LoopTy, VecmatOp>(op, builder);
- if (isa<DotOp>(op))
- return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
- if (isa<BatchMatmulOp>(op))
- return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
- if (isa<ConvWOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder);
- if (isa<ConvNWCOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
- if (isa<ConvNCWOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
- if (isa<ConvHWOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder);
- if (isa<ConvNHWCOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
- if (isa<ConvNCHWOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
- if (isa<ConvDHWOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder);
- if (isa<ConvNDHWCOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
- if (isa<ConvNCDHWOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
- llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
-}
-
SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
AffineMap map,
ValueRange viewSizes) {
template <typename LoopTy>
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
Operation *op) {
- return linalgOpToLoopsImplSwitch<LoopTy>(op, builder);
+ return linalgOpToLoopsImpl<LoopTy>(op, builder);
}
template Optional<LinalgLoops>