// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
// function. The implementation of the function can be either in the same module
// or in an externally linked library.
-// This is a generic entry point for all LinalgOp, except for CopyOp and
-// IndexedGenericOp, for which more specialized patterns are provided.
+// This is a generic entry point for all LinalgOp, except for CopyOp, for which
+// more specialized patterns are provided.
class LinalgOpToLibraryCallRewrite
: public OpInterfaceRewritePattern<LinalgOp> {
public:
PatternRewriter &rewriter) const override;
};
-/// Conversion pattern specialization for IndexedGenericOp, has special handling
-/// for the extra index operands.
-class IndexedGenericOpToLibraryCallRewrite
- : public OpRewritePattern<IndexedGenericOp> {
-public:
- using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(IndexedGenericOp op,
- PatternRewriter &rewriter) const override;
-};
-
/// Populate the given list with patterns that convert from Linalg to Standard.
void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns);
static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
SmallVector<Type, 4> result;
result.reserve(op->getNumOperands());
- if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
- auto *ctx = op->getContext();
- auto numLoops = indexedGenericOp.getNumLoops();
- result.reserve(op->getNumOperands() + numLoops);
- result.assign(numLoops, IndexType::get(ctx));
- }
for (auto type : op->getOperandTypes()) {
// The underlying descriptor type (e.g. LLVM) does not have layout
// information. Canonicalizing the type at the level of std when going into
LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
LinalgOp op, PatternRewriter &rewriter) const {
// Only LinalgOp for which there is no specialized pattern go through this.
- if (isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
+ if (isa<CopyOp>(op))
+ return failure();
+
+ // Canonicalize indexed generic operations before library call conversion.
+ if (isa<IndexedGenericOp>(op))
return failure();
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
return success();
}
-LogicalResult
-mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite(
- IndexedGenericOp op, PatternRewriter &rewriter) const {
- auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
- if (!libraryCallName)
- return failure();
-
- // TODO: Use induction variables values instead of zeros, when
- // IndexedGenericOp is tiled.
- auto zero = rewriter.create<mlir::ConstantOp>(
- op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
- auto indexedGenericOp = cast<IndexedGenericOp>(op);
- auto numLoops = indexedGenericOp.getNumLoops();
- SmallVector<Value, 4> operands;
- operands.reserve(numLoops + op.getNumOperands());
- for (unsigned i = 0; i < numLoops; ++i)
- operands.push_back(zero);
- for (auto operand : op.getOperands())
- operands.push_back(operand);
- rewriter.replaceOpWithNewOp<mlir::CallOp>(
- op, libraryCallName.getValue(), TypeRange(),
- createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
- return success();
-}
-
/// Populate the given list with patterns that convert from Linalg to Standard.
void mlir::linalg::populateLinalgToStandardConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<
CopyOpToLibraryCallRewrite,
CopyTransposeRewrite,
- IndexedGenericOpToLibraryCallRewrite,
LinalgOpToLibraryCallRewrite>(patterns.getContext());
// clang-format on
}
}
// CHECK-LABEL: func @matmul_vec_impl(
// CHECK: call @external_outerproduct_matmul(%{{.*}}) :
-
-#indexed_matmul_trait = {
- iterator_types = ["parallel", "parallel", "reduction"],
- indexing_maps = #matmul_accesses,
- library_call = "external_indexed_outerproduct_matmul"
-}
-func @matmul_vec_indexed(%A: !matrix_type_A,
- %B: !matrix_type_B,
- %C: !matrix_type_C) {
- linalg.indexed_generic #indexed_matmul_trait
- ins(%A, %B : !matrix_type_A, !matrix_type_B)
- outs(%C : !matrix_type_C) {
- ^bb0(%i: index, %j: index, %k: index,
- %a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
- %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
- linalg.yield %d: !vector_type_C
- }
- return
-}
-// CHECK-LABEL: func @matmul_vec_indexed(
-// CHECK: %[[ZERO:.*]] = constant 0 : index
-// CHECK: call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}})