Add a definition of the library function to use when Linalg ops are
authorMahesh Ravishankar <ravishankarm@google.com>
Thu, 13 Jun 2019 20:47:08 +0000 (13:47 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 20 Jun 2019 06:01:12 +0000 (23:01 -0700)
lowered to LLVM, instead of expecting one to exist in the Module

PiperOrigin-RevId: 253097382

mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/test/Linalg/llvm.mlir
mlir/test/mlir-cpu-runner/linalg_integration_test.mlir

index 1d785554ac8d8ea74fcdfb9a98d79febe08a81db..e9348b6d283ddbb1cd7965891b60043b7e883f44 100644 (file)
@@ -65,6 +65,8 @@ using llvm_select = ValueBuilder<LLVM::SelectOp>;
 using mul = ValueBuilder<mlir::LLVM::MulOp>;
 using sub = ValueBuilder<mlir::LLVM::SubOp>;
 using undef = ValueBuilder<mlir::LLVM::UndefOp>;
+using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
+using llvm_return = OperationBuilder<LLVM::ReturnOp>;
 
 template <typename T>
 static LLVMType getPtrToElementType(T containerType,
@@ -561,26 +563,130 @@ public:
   }
 };
 
+// Create a function definition which takes as argument pointers to the input
+// types and returns pointers to the output types.
+static Function *getLLVMLibraryCallImplDefinition(Function *libFn) {
+  auto implFnName = (libFn->getName().str() + "_impl");
+  auto module = libFn->getModule();
+  if (auto *f = module->getNamedFunction(implFnName)) {
+    return f;
+  }
+  SmallVector<Type, 4> fnArgTypes;
+  for (auto t : libFn->getType().getInputs()) {
+    assert(t.isa<LLVMType>() &&
+           "Expected LLVM Type for argument while generating library Call "
+           "Implementation Definition");
+    fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo());
+  }
+  auto implFnType = FunctionType::get(fnArgTypes, {}, libFn->getContext());
+
+  // Insert the implementation function definition.
+  auto implFnDefn = new Function(libFn->getLoc(), implFnName, implFnType);
+  module->getFunctions().push_back(implFnDefn);
+  return implFnDefn;
+}
+
+// Get function definition for the LinalgOp. If it doesn't exist, insert a
+// definition.
+template <typename LinalgOp>
+static Function *getLLVMLibraryCallDeclaration(Operation *op,
+                                               LLVMTypeConverter &lowering,
+                                               PatternRewriter &rewriter) {
+  assert(isa<LinalgOp>(op));
+  auto fnName = LinalgOp::getLibraryCallName();
+  auto module = op->getFunction()->getModule();
+  if (auto *f = module->getNamedFunction(fnName)) {
+    return f;
+  }
+
+  // Get the Function type consistent with LLVM Lowering.
+  SmallVector<Type, 4> inputTypes;
+  for (auto operand : op->getOperands()) {
+    // TODO(ravishankarm): convertLinalgType handles only a subset of Linalg
+    // types. Handle other types (as well as non-Linalg types) either here or in
+    // convertLinalgType.
+    inputTypes.push_back(convertLinalgType(operand->getType(), lowering));
+  }
+  assert(op->getNumResults() == 0 &&
+         "Library call for linalg operation can be generated only for ops that "
+         "have void return types");
+  auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
+  auto libFn = new Function(op->getLoc(), fnName, libFnType);
+  module->getFunctions().push_back(libFn);
+  // Return after creating the function definition. The body will be created
+  // later.
+  return libFn;
+}
+
+static void getLLVMLibraryCallDefinition(Function *fn,
+                                         LLVMTypeConverter &lowering) {
+  // Generate the implementation function definition.
+  auto implFn = getLLVMLibraryCallImplDefinition(fn);
+
+  // Generate the function body.
+  fn->addEntryBlock();
+
+  OpBuilder builder(fn->getBody());
+  edsc::ScopedContext scope(builder, fn->getLoc());
+  SmallVector<Value *, 4> implFnArgs;
+
+  // Create a constant 1.
+  auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()),
+                      IntegerAttr::get(IndexType::get(fn->getContext()), 1));
+  for (auto arg : fn->getArguments()) {
+    // Allocate a stack for storing the argument value. The stack is passed to
+    // the implementation function.
+    auto alloca =
+        llvm_alloca(arg->getType().cast<LLVMType>().getPointerTo(), one)
+            .getValue();
+    implFnArgs.push_back(alloca);
+    llvm_store(arg, alloca);
+  }
+  call(ArrayRef<Type>(), builder.getFunctionAttr(implFn), implFnArgs);
+  llvm_return(ArrayRef<Value *>());
+}
+
+namespace {
+// The conversion class from Linalg to LLVMIR.
+class LinalgTypeConverter : public LLVMTypeConverter {
+  using LLVMTypeConverter::LLVMTypeConverter;
+
+public:
+  Type convertType(Type t) override {
+    if (auto result = LLVMTypeConverter::convertType(t))
+      return result;
+    return convertLinalgType(t, *this);
+  }
+
+  void addLibraryFnDeclaration(Function *fn) {
+    libraryFnDeclarations.push_back(fn);
+  }
+
+  ArrayRef<Function *> getLibraryFnDeclarations() {
+    return libraryFnDeclarations;
+  }
+
+private:
+  /// List of library functions declarations needed during dialect conversion
+  SmallVector<Function *, 2> libraryFnDeclarations;
+};
+} // end anonymous namespace
+
 // LinalgOpConversion<LinalgOp> creates a new call to the
-// `LinalgOp::getLibraryCallName()` function, which is assumed to have been
-// declared in the current MLIR module.
+// `LinalgOp::getLibraryCallName()` function.
 // The implementation of the function can be either in the same module or in an
 // externally linked library.
 template <typename LinalgOp> class LinalgOpConversion : public LLVMOpLowering {
 public:
   explicit LinalgOpConversion(MLIRContext *context,
-                              LLVMTypeConverter &lowering_)
+                              LinalgTypeConverter &lowering_)
       : LLVMOpLowering(LinalgOp::getOperationName(), context, lowering_) {}
 
   PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                                      PatternRewriter &rewriter) const override {
-    auto *f = op->getFunction()->getModule()->getNamedFunction(
-        LinalgOp::getLibraryCallName());
-    if (!f) {
-      op->emitError("Could not find function: ")
-          << LinalgOp::getLibraryCallName() << "in lowering to LLVM ";
-      return matchFailure();
-    }
+    // Only emit library call declaration. Fill in the body later.
+    auto *f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
+    static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
 
     auto fAttr = rewriter.getFunctionAttr(f);
     auto named = rewriter.getNamedAttr("callee", fAttr);
@@ -590,19 +696,6 @@ public:
   }
 };
 
-namespace {
-// The conversion class from Linalg to LLVMIR.
-struct LinalgTypeConverter : LLVMTypeConverter {
-  using LLVMTypeConverter::LLVMTypeConverter;
-
-  Type convertType(Type t) override {
-    if (auto result = LLVMTypeConverter::convertType(t))
-      return result;
-    return convertLinalgType(t, *this);
-  }
-};
-} // end anonymous namespace
-
 /// Populate the given list with patterns that convert from Linalg to LLVM.
 static void
 populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
@@ -694,8 +787,14 @@ void LowerLinalgToLLVMPass::runOnModule() {
   ConversionTarget target(getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
   if (failed(applyConversionPatterns(module, target, converter,
-                                     std::move(patterns))))
+                                     std::move(patterns)))) {
     signalPassFailure();
+  }
+
+  // Emit the function body of any Library function that was declared.
+  for (auto fn : converter.getLibraryFnDeclarations()) {
+    getLLVMLibraryCallDefinition(fn, converter);
+  }
 }
 
 ModulePassBase *mlir::linalg::createLowerLinalgToLLVMPass() {
index 3381d8f042157980bdbdca306f5850755509aba1..f4ea0ef7beef7142c2a31e6cbfd4ff120783f87b 100644 (file)
@@ -78,10 +78,6 @@ func @slice(%arg0: !linalg.buffer<f32>, %arg1: !linalg.range) {
 //  CHECK-NEXT:   %27 = llvm.mul %16, %26 : !llvm.i64
 //  CHECK-NEXT:   %28 = llvm.insertvalue %27, %25[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 
-func @linalg_dot(!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
-                 !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
-                 !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">)
-
 func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {
   linalg.dot(%arg0, %arg1, %arg2) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
   return
index 6e37ebde36bd8981506e42f008537f2882453a9b..e72b49de47642efa3c53653aabdee7f898d3a6a9 100644 (file)
@@ -3,42 +3,6 @@
 // RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
 // RUN: mlir-opt %s -linalg-lower-to-loops -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
 
-func @linalg_dot_impl(%arg0 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">,
-                      %arg1 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">,
-                      %arg2 : !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">)
-
-func @linalg_dot(%arg0 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
-                 %arg1 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
-                 %arg2 : !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
-  %c1 = llvm.constant(1) : !llvm.i64
-  %0 = llvm.alloca %c1 x !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-  %1 = llvm.alloca %c1 x !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-  %2 = llvm.alloca %c1 x !llvm<"{ float*, i64, [0 x i64], [0 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">
-  llvm.store %arg0, %0 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-  llvm.store %arg1, %1 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
-  llvm.store %arg2, %2 : !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">
-  call @linalg_dot_impl(%0, %1, %2) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">) -> ()
-  return
-}
-
-func @linalg_matmul_impl(%arg0 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">,
-                         %arg1 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">,
-                         %arg2 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">)
-
-func @linalg_matmul(%arg0 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">,
-                    %arg1 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">,
-                    %arg2 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
-  %c1 = llvm.constant(1) : !llvm.i64
-  %0 = llvm.alloca %c1 x !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
-  %1 = llvm.alloca %c1 x !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
-  %2 = llvm.alloca %c1 x !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
-  llvm.store %arg0, %0 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
-  llvm.store %arg1, %1 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
-  llvm.store %arg2, %2 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
-  call @linalg_matmul_impl(%0, %1, %2) : (!llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">) -> ()
-  return
-}
-
 func @fill_f32(%arg0 : !linalg.buffer<f32>, %f : f32) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index