Make it possible to override the lowering of MemRef to the LLVM dialect. NFC.
authorAlex Zinenko <zinenko@google.com>
Tue, 17 Dec 2019 20:09:33 +0000 (12:09 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 17 Dec 2019 20:10:04 +0000 (12:10 -0800)
The lowering of MemRef types to the LLVM dialect is connected to the underlying
runtime representation of structured memory buffers. It has changed several
times in the past and reached the current state of a LLVM structured-typed
descriptor containing two pointers and all sizes. In several reported use
cases, a different, often simpler, lowering scheme is required. For example,
lowering statically-shaped memrefs to bare LLVM pointers to simplify aliasing
annotation. Split the pattern population functions into those include
memref-related operations and the remaining ones. Users are expected to extend
TypeConverter::convertType to handle the memref types differently.
PiperOrigin-RevId: 286030610

mlir/g3doc/ConversionToLLVMDialect.md
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

index 3881ee0..19403e2 100644 (file)
@@ -302,11 +302,11 @@ llvm.func @bar(%arg0: !llvm.i64) {
   llvm.call @foo(%16) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> ()
   llvm.return
 }
-
-
-
 ```
 
+*This convention may or may not apply if the conversion of MemRef types is
+overridden by the user.*
+
 ## Repeated Successor Removal
 
 Since the goal of the LLVM IR dialect is to reflect LLVM IR in MLIR, the dialect
@@ -349,7 +349,7 @@ before the conversion to the LLVM IR dialect:
   llvm.br ^bb1(%2 : !llvm.i32)
 ```
 
-## Memref Model
+## Default Memref Model
 
 ### Memref Descriptor
 
index c5c17b3..d49c1c2 100644 (file)
@@ -52,6 +52,17 @@ using LLVMPatternListFiller =
 using LLVMTypeConverterMaker =
     std::function<std::unique_ptr<LLVMTypeConverter>(MLIRContext *)>;
 
+/// Collect a set of patterns to convert memory-related operations from the
+/// Standard dialect to the LLVM dialect, excluding the memory-related
+/// operations.
+void populateStdToLLVMMemoryConversionPatters(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Collect a set of patterns to convert from the Standard dialect to the LLVM
+/// dialect, excluding the memory-related operations.
+void populateStdToLLVMNonMemoryConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
 /// Collect a set of patterns to convert from the Standard dialect to LLVM.
 void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns);
index 9b2113a..51cdd72 100644 (file)
@@ -2020,7 +2020,7 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
 }
 
 /// Collect a set of patterns to convert from the Standard dialect to LLVM.
-void mlir::populateStdToLLVMConversionPatterns(
+void mlir::populateStdToLLVMNonMemoryConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   // FIXME: this should be tablegen'ed
   // clang-format off
@@ -2035,7 +2035,6 @@ void mlir::populateStdToLLVMConversionPatterns(
       CmpIOpLowering,
       CondBranchOpLowering,
       ConstLLVMOpLowering,
-      DimOpLowering,
       DivFOpLowering,
       DivISOpLowering,
       DivIUOpLowering,
@@ -2045,10 +2044,7 @@ void mlir::populateStdToLLVMConversionPatterns(
       Log2OpLowering,
       FPExtLowering,
       FPTruncLowering,
-      FuncOpConversion,
       IndexCastOpLowering,
-      LoadOpLowering,
-      MemRefCastOpLowering,
       MulFOpLowering,
       MulIOpLowering,
       OrOpLowering,
@@ -2061,22 +2057,39 @@ void mlir::populateStdToLLVMConversionPatterns(
       SignExtendIOpLowering,
       SplatOpLowering,
       SplatNdOpLowering,
-      StoreOpLowering,
       SubFOpLowering,
       SubIOpLowering,
-      SubViewOpLowering,
       TanhOpLowering,
       TruncateIOpLowering,
-      ViewOpLowering,
       XOrOpLowering,
       ZeroExtendIOpLowering>(*converter.getDialect(), converter);
-   patterns.insert<
-       AllocOpLowering,
-       DeallocOpLowering>(
-           *converter.getDialect(), converter, clUseAlloca.getValue());
   // clang-format on
 }
 
+void mlir::populateStdToLLVMMemoryConversionPatters(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  // clang-format off
+  patterns.insert<
+      DimOpLowering,
+      FuncOpConversion,
+      LoadOpLowering,
+      MemRefCastOpLowering,
+      StoreOpLowering,
+      SubViewOpLowering,
+      ViewOpLowering>(*converter.getDialect(), converter);
+  patterns.insert<
+      AllocOpLowering,
+      DeallocOpLowering>(
+        *converter.getDialect(), converter, clUseAlloca.getValue());
+  // clang-format on
+}
+
+void mlir::populateStdToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
+  populateStdToLLVMMemoryConversionPatters(converter, patterns);
+}
+
 // Convert types using the stored LLVM IR module.
 Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); }