using mul = ValueBuilder<mlir::LLVM::MulOp>;
using sub = ValueBuilder<mlir::LLVM::SubOp>;
using undef = ValueBuilder<mlir::LLVM::UndefOp>;
+using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
+using llvm_return = OperationBuilder<LLVM::ReturnOp>;
template <typename T>
static LLVMType getPtrToElementType(T containerType,
}
};
+// Create a function definition which takes as argument pointers to the input
+// types and returns pointers to the output types.
+static Function *getLLVMLibraryCallImplDefinition(Function *libFn) {
+ auto implFnName = (libFn->getName().str() + "_impl");
+ auto module = libFn->getModule();
+ if (auto *f = module->getNamedFunction(implFnName)) {
+ return f;
+ }
+ SmallVector<Type, 4> fnArgTypes;
+ for (auto t : libFn->getType().getInputs()) {
+ assert(t.isa<LLVMType>() &&
+ "Expected LLVM Type for argument while generating library Call "
+ "Implementation Definition");
+ fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo());
+ }
+ auto implFnType = FunctionType::get(fnArgTypes, {}, libFn->getContext());
+
+ // Insert the implementation function definition.
+ auto implFnDefn = new Function(libFn->getLoc(), implFnName, implFnType);
+ module->getFunctions().push_back(implFnDefn);
+ return implFnDefn;
+}
+
+// Get function definition for the LinalgOp. If it doesn't exist, insert a
+// definition.
+template <typename LinalgOp>
+static Function *getLLVMLibraryCallDeclaration(Operation *op,
+ LLVMTypeConverter &lowering,
+ PatternRewriter &rewriter) {
+ assert(isa<LinalgOp>(op));
+ auto fnName = LinalgOp::getLibraryCallName();
+ auto module = op->getFunction()->getModule();
+ if (auto *f = module->getNamedFunction(fnName)) {
+ return f;
+ }
+
+ // Get the Function type consistent with LLVM Lowering.
+ SmallVector<Type, 4> inputTypes;
+ for (auto operand : op->getOperands()) {
+ // TODO(ravishankarm): convertLinalgType handles only a subset of Linalg
+ // types. Handle other types (as well as non-Linalg types) either here or in
+ // convertLinalgType.
+ inputTypes.push_back(convertLinalgType(operand->getType(), lowering));
+ }
+ assert(op->getNumResults() == 0 &&
+ "Library call for linalg operation can be generated only for ops that "
+ "have void return types");
+ auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
+ auto libFn = new Function(op->getLoc(), fnName, libFnType);
+ module->getFunctions().push_back(libFn);
+ // Return after creating the function definition. The body will be created
+ // later.
+ return libFn;
+}
+
+static void getLLVMLibraryCallDefinition(Function *fn,
+ LLVMTypeConverter &lowering) {
+ // Generate the implementation function definition.
+ auto implFn = getLLVMLibraryCallImplDefinition(fn);
+
+ // Generate the function body.
+ fn->addEntryBlock();
+
+ OpBuilder builder(fn->getBody());
+ edsc::ScopedContext scope(builder, fn->getLoc());
+ SmallVector<Value *, 4> implFnArgs;
+
+ // Create a constant 1.
+ auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()),
+ IntegerAttr::get(IndexType::get(fn->getContext()), 1));
+ for (auto arg : fn->getArguments()) {
+ // Allocate a stack for storing the argument value. The stack is passed to
+ // the implementation function.
+ auto alloca =
+ llvm_alloca(arg->getType().cast<LLVMType>().getPointerTo(), one)
+ .getValue();
+ implFnArgs.push_back(alloca);
+ llvm_store(arg, alloca);
+ }
+ call(ArrayRef<Type>(), builder.getFunctionAttr(implFn), implFnArgs);
+ llvm_return(ArrayRef<Value *>());
+}
+
+namespace {
+// The conversion class from Linalg to LLVMIR.
+class LinalgTypeConverter : public LLVMTypeConverter {
+ using LLVMTypeConverter::LLVMTypeConverter;
+
+public:
+ Type convertType(Type t) override {
+ if (auto result = LLVMTypeConverter::convertType(t))
+ return result;
+ return convertLinalgType(t, *this);
+ }
+
+ void addLibraryFnDeclaration(Function *fn) {
+ libraryFnDeclarations.push_back(fn);
+ }
+
+ ArrayRef<Function *> getLibraryFnDeclarations() {
+ return libraryFnDeclarations;
+ }
+
+private:
+ /// List of library functions declarations needed during dialect conversion
+ SmallVector<Function *, 2> libraryFnDeclarations;
+};
+} // end anonymous namespace
+
// LinalgOpConversion<LinalgOp> creates a new call to the
-// `LinalgOp::getLibraryCallName()` function, which is assumed to have been
-// declared in the current MLIR module.
+// `LinalgOp::getLibraryCallName()` function.
// The implementation of the function can be either in the same module or in an
// externally linked library.
template <typename LinalgOp> class LinalgOpConversion : public LLVMOpLowering {
public:
explicit LinalgOpConversion(MLIRContext *context,
- LLVMTypeConverter &lowering_)
+ LinalgTypeConverter &lowering_)
: LLVMOpLowering(LinalgOp::getOperationName(), context, lowering_) {}
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
- auto *f = op->getFunction()->getModule()->getNamedFunction(
- LinalgOp::getLibraryCallName());
- if (!f) {
- op->emitError("Could not find function: ")
- << LinalgOp::getLibraryCallName() << "in lowering to LLVM ";
- return matchFailure();
- }
+ // Only emit library call declaration. Fill in the body later.
+ auto *f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
+ static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
auto fAttr = rewriter.getFunctionAttr(f);
auto named = rewriter.getNamedAttr("callee", fAttr);
}
};
-namespace {
-// The conversion class from Linalg to LLVMIR.
-struct LinalgTypeConverter : LLVMTypeConverter {
- using LLVMTypeConverter::LLVMTypeConverter;
-
- Type convertType(Type t) override {
- if (auto result = LLVMTypeConverter::convertType(t))
- return result;
- return convertLinalgType(t, *this);
- }
-};
-} // end anonymous namespace
-
/// Populate the given list with patterns that convert from Linalg to LLVM.
static void
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
if (failed(applyConversionPatterns(module, target, converter,
- std::move(patterns))))
+ std::move(patterns)))) {
signalPassFailure();
+ }
+
+ // Emit the function body of any Library function that was declared.
+ for (auto fn : converter.getLibraryFnDeclarations()) {
+ getLLVMLibraryCallDefinition(fn, converter);
+ }
}
ModulePassBase *mlir::linalg::createLowerLinalgToLLVMPass() {
// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-func @linalg_dot_impl(%arg0 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">,
- %arg1 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">,
- %arg2 : !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">)
-
-func @linalg_dot(%arg0 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
- %arg1 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
- %arg2 : !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
- %c1 = llvm.constant(1) : !llvm.i64
- %0 = llvm.alloca %c1 x !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
- %1 = llvm.alloca %c1 x !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
- %2 = llvm.alloca %c1 x !llvm<"{ float*, i64, [0 x i64], [0 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">
- llvm.store %arg0, %0 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
- llvm.store %arg1, %1 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
- llvm.store %arg2, %2 : !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">
- call @linalg_dot_impl(%0, %1, %2) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">) -> ()
- return
-}
-
-func @linalg_matmul_impl(%arg0 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">,
- %arg1 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">,
- %arg2 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">)
-
-func @linalg_matmul(%arg0 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">,
- %arg1 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">,
- %arg2 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
- %c1 = llvm.constant(1) : !llvm.i64
- %0 = llvm.alloca %c1 x !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
- %1 = llvm.alloca %c1 x !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
- %2 = llvm.alloca %c1 x !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
- llvm.store %arg0, %0 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
- llvm.store %arg1, %1 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
- llvm.store %arg2, %2 : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
- call @linalg_matmul_impl(%0, %1, %2) : (!llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">) -> ()
- return
-}
-
func @fill_f32(%arg0 : !linalg.buffer<f32>, %f : f32) {
%c0 = constant 0 : index
%c1 = constant 1 : index