From e33a636e26349bf42557908766911654d60cb6c0 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 3 Apr 2020 00:08:22 -0400 Subject: [PATCH] [mlir][Linalg] Employ finer-grained control of C interface emission Summary: Linalg makes it possible to interface codegen with externally precompiled HPC libraries. The mechanism to allow such interop uses a normalized ABI and the emission of C interface wrappers. The mechanism controlling these C interface emission is too aggressive and makes it very easy to obtained undefined symbols for external function (e.g. the ones coming from libm). This revision uses the newly introduced llvm.emit_c_interface function attribute which allows controlling this behavior at a function granularity. As a consequence LinalgToLLVM does not need to activate the C wrapper emission when adding the StdToLLVM patterns. Differential Revision: https://reviews.llvm.org/D77364 --- mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 12 ++++++++---- mlir/test/mlir-cpu-runner/unranked_memref.mlir | 4 ++-- mlir/test/mlir-cpu-runner/utils.mlir | 4 ++-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index febf4eb..cb66ae9 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -400,8 +400,13 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, // Insert before module terminator. rewriter.setInsertionPoint(module.getBody(), std::prev(module.getBody()->end())); - rewriter.create(op->getLoc(), fnNameAttr.getValue(), libFnType, - ArrayRef{}); + FuncOp funcOp = + rewriter.create(op->getLoc(), fnNameAttr.getValue(), libFnType, + ArrayRef{}); + // Insert a function attribute that will trigger the emission of the + // corresponding `_mlir_ciface_xxx` interface so that external libraries see + // a normalized ABI. This interface is added during std to llvm conversion. + funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext())); return fnNameAttr; } @@ -573,8 +578,7 @@ void ConvertLinalgToLLVMPass::runOnModule() { LLVMTypeConverter converter(&getContext()); populateAffineToStdConversionPatterns(patterns, &getContext()); populateLoopToStdConversionPatterns(patterns, &getContext()); - populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false, - /*emitCWrappers=*/true); + populateStdToLLVMConversionPatterns(converter, patterns); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); populateLinalgToStandardConversionPatterns(patterns, &getContext()); diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir index 7447e9d..13d4db9 100644 --- a/mlir/test/mlir-cpu-runner/unranked_memref.mlir +++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir @@ -55,5 +55,5 @@ func @main() -> () { return } -func @print_memref_i8(memref<*xi8>) -func @print_memref_f32(memref<*xf32>) +func @print_memref_i8(memref<*xi8>) attributes { llvm.emit_c_interface } +func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface } diff --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir index 2a56920..d3ab617 100644 --- a/mlir/test/mlir-cpu-runner/utils.mlir +++ b/mlir/test/mlir-cpu-runner/utils.mlir @@ -50,7 +50,7 @@ func @print_3d() { // PRINT-3D-NEXT: 2, 2, 4, 2, 2 // PRINT-3D-NEXT: 2, 2, 2, 2, 2 -func @print_memref_f32(memref<*xf32>) +func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface } !vector_type_C = type vector<4x4xf32> !matrix_type_CC = type memref<1x1x!vector_type_C> @@ -71,4 +71,4 @@ func @vector_splat_2d() { // PRINT-VECTOR-SPLAT-2D: Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [1, 1] strides = [1, 1] data = // PRINT-VECTOR-SPLAT-2D-NEXT: [((10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10))] -func @print_memref_vector_4x4xf32(memref) +func @print_memref_vector_4x4xf32(memref) attributes { llvm.emit_c_interface } -- 2.7.4