From 27ad213680eae7aa75dd6dd72608957cac9198f2 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Mon, 19 Apr 2021 12:23:11 +0000 Subject: [PATCH] [mlir][linalg] enable library call rewrites for linalg operations with index semantics. The patch enables the library call lowering for linalg operations that contain index operations. Differential Revision: https://reviews.llvm.org/D100537 --- mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h | 2 +- mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h index eeb20c4..d317253 100644 --- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h +++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h @@ -29,7 +29,7 @@ namespace linalg { // function. The implementation of the function can be either in the same module // or in an externally linked library. // This is a generic entry point for all LinalgOp, except for CopyOp and -// IndexedGenericOp, for which omre specialized patterns are provided. +// IndexedGenericOp, for which more specialized patterns are provided. class LinalgOpToLibraryCallRewrite : public OpInterfaceRewritePattern { public: diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp index d385b46..dd4fafb 100644 --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -106,14 +106,12 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite( if (isa(op) || isa(op)) return failure(); - // TODO: remove once index ops are supported. - if (op.hasIndexSemantics()) - return failure(); - auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); if (!libraryCallName) return failure(); + // TODO: Add support for more complex library call signatures that include + // indices or captured values. rewriter.replaceOpWithNewOp( op, libraryCallName.getValue(), TypeRange(), createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(), -- 2.7.4