From 59b473c231f6295bd4aa7199da5816caae5a5e3a Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 9 Aug 2019 07:33:34 -0700 Subject: [PATCH] External library name mangling support for linalg. This CL introduces the ability to generate the external library name for Linalg operations. The problem is that neither mlir or C support overloading and we want a simplified form of name mangling that is still reasonable to read. This CL creates the name of the external call that Linalg expects from the operation name and the type of its arguments. The interface library names are updated and use new cases are added for FillOp. PiperOrigin-RevId: 262556833 --- mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td | 19 ++++++------ mlir/include/mlir/Linalg/IR/LinalgOps.h | 22 ++++++++++++++ mlir/lib/Linalg/IR/LinalgOps.cpp | 33 +++++++++++++++++++++ mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 34 +++++++++++----------- mlir/test/Linalg/llvm.mlir | 2 +- mlir/test/mlir-cpu-runner/cblas_interface.cpp | 22 +++++++++++--- .../mlir-cpu-runner/linalg_integration_test.mlir | 22 +++++--------- 7 files changed, 108 insertions(+), 46 deletions(-) diff --git a/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td index 547a2c4..998d68b 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td +++ b/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td @@ -80,10 +80,9 @@ class LinalgLibraryBase_Op props> class LinalgLibrary_Op props> : LinalgLibraryBase_Op { - - code classDeclaration = [{ - StringRef getLibraryCallName() { - return "linalg_}] # mnemonic # [{"; + code libraryCallName = [{ + std::string getLibraryCallName() { + return generateLibraryCallName(getOperation()); } }]; } @@ -138,7 +137,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> { return build( builder, result, input, output, AffineMapAttr(), AffineMapAttr()); }]>]; - let extraClassDeclaration = classDeclaration # [{ + let extraClassDeclaration = libraryCallName # [{ unsigned getNumParallelLoops() { auto *view = *(getOperands().begin()); return view->getType().cast().getRank(); @@ -151,7 +150,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> { def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> { let arguments = (ins View, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>); - let extraClassDeclaration = classDeclaration # [{ + let extraClassDeclaration = libraryCallName # [{ unsigned getNumParallelLoops() { auto *view = *(getOperands().begin()); return view->getType().cast().getRank(); @@ -170,7 +169,7 @@ def DotOp : LinalgLibrary_Op<"dot", NLoopTypes<0, 1, 0>, ViewRanks<[1, 1, 0]>]> { let arguments = (ins View, View, View); - let extraClassDeclaration = classDeclaration; + let extraClassDeclaration = libraryCallName; } def MatvecOp : LinalgLibrary_Op<"matvec", @@ -178,7 +177,7 @@ def MatvecOp : LinalgLibrary_Op<"matvec", NLoopTypes<1, 1, 0>, ViewRanks<[2, 1, 1]>]> { let arguments = (ins View, View, View); - let extraClassDeclaration = classDeclaration; + let extraClassDeclaration = libraryCallName; } def MatmulOp : LinalgLibrary_Op<"matmul", @@ -186,7 +185,7 @@ def MatmulOp : LinalgLibrary_Op<"matmul", NLoopTypes<2, 1, 0>, ViewRanks<[2, 2, 2]>]> { let arguments = (ins View, View, View); - let extraClassDeclaration = classDeclaration; + let extraClassDeclaration = libraryCallName; } def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> { @@ -211,7 +210,7 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> { let arguments = (ins View:$filter, View:$input, View:$output, OptionalAttr:$strides, OptionalAttr:$dilations); - let extraClassDeclaration = classDeclaration # [{ + let extraClassDeclaration = libraryCallName # [{ // TODO(ntv) extend to support more than 1 dimensions and potentially // grouping too. unsigned getNumBatchDimensions() { return 1; } diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Linalg/IR/LinalgOps.h index 4085d06..3187f4f 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.h @@ -186,6 +186,28 @@ public: } }; +/// Returns the name mangled library call name to disambiguate between different +/// overloads at the C level. The name mangling scheme is basic and uses MLIR +/// type names: +/// 1. form a string which is the concatenation of the linalg op name with all +/// the operand type names, separate by underscores; +/// 2. drop the `linalg.` prefix, and the `<`, `>`, `?` symbols from the type. +/// Assumes `op` is a LinalgOp. +/// +/// Examples: +/// +/// 1. linalg.fill(%A, %f) : !linalg.view, f32 +/// name mangles into `linalg_fill_viewf32_f32_impl` +/// +/// 2. linalg.dot(%A, %B, %C) : +/// !linalg.view, !linalg.view, !linalg.view +/// name mangles into `linalg_dot_viewxf32_viewxf32_viewf32_impl` +/// +/// 3. linalg.matmul(...) : +/// !linalg.view, !linalg.view, !linalg.view +/// name mangles into `linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl` +std::string generateLibraryCallName(Operation *op); + #define GET_OP_CLASSES #include "mlir/Linalg/IR/LinalgOps.h.inc" diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index 6549508..bce2b32 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -37,6 +37,7 @@ #include "mlir/Transforms/FoldUtils.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::edsc; @@ -1085,3 +1086,35 @@ SmallVector mlir::linalg::loopToOperandRangesMaps(Operation *op) { } llvm_unreachable("Missing loopToOperandRangesMaps for op"); } + +static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { + if (auto view = t.dyn_cast()) { + ss << "view"; + for (unsigned i = 0, e = view.getRank(); i < e; ++i) + ss << "x"; + appendMangledType(ss, view.getElementType()); + } else if (auto vec = t.dyn_cast()) { + ss << "vector"; + interleave( + vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); + appendMangledType(ss, vec.getElementType()); + } else if (t.isIntOrIndexOrFloat()) { + ss << t; + } else { + llvm_unreachable("Invalid type for linalg library name mangling"); + } +} + +std::string mlir::linalg::generateLibraryCallName(Operation *op) { + assert(isa(op)); + std::string name(op->getName().getStringRef().str()); + name.reserve(128); + std::replace(name.begin(), name.end(), '.', '_'); + llvm::raw_string_ostream ss(name); + ss << "_"; + auto types = op->getOperandTypes(); + interleave( + types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); }, + [&]() { ss << "_"; }); + return ss.str(); +} diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 6967a9d..a45f943 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -39,6 +39,8 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/LowerAffine.h" #include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/SetVector.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" @@ -545,7 +547,7 @@ static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) { } SmallVector fnArgTypes; for (auto t : libFn.getType().getInputs()) { - assert(t.isa() && + assert(t && t.isa() && "Expected LLVM Type for argument while generating library Call " "Implementation Definition"); fnArgTypes.push_back(t.cast().getPointerTo()); @@ -577,12 +579,8 @@ getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering, // Get the Function type consistent with LLVM Lowering. SmallVector inputTypes; - for (auto operand : op->getOperands()) { - // TODO(ravishankarm): convertLinalgType handles only a subset of Linalg - // types. Handle other types (as well as non-Linalg types) either here or in - // convertLinalgType. - inputTypes.push_back(convertLinalgType(operand->getType(), lowering)); - } + for (auto operand : op->getOperands()) + inputTypes.push_back(lowering.convertType(operand->getType())); assert(op->getNumResults() == 0 && "Library call for linalg operation can be generated only for ops that " "have void return types"); @@ -632,15 +630,15 @@ public: return convertLinalgType(t, *this); } - void addLibraryFnDeclaration(FuncOp fn) { - libraryFnDeclarations.push_back(fn); - } + void addLibraryFnDeclaration(FuncOp fn) { libraryFnDeclarations.insert(fn); } - ArrayRef getLibraryFnDeclarations() { return libraryFnDeclarations; } + ArrayRef getLibraryFnDeclarations() { + return libraryFnDeclarations.getArrayRef(); + } private: /// List of library functions declarations needed during dialect conversion - SmallVector libraryFnDeclarations; + llvm::SetVector libraryFnDeclarations; }; } // end anonymous namespace @@ -676,11 +674,13 @@ static void populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter, OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert, LinalgOpConversion, - LoadOpConversion, RangeOpConversion, SliceOpConversion, - StoreOpConversion, ViewOpConversion>(ctx, converter); + patterns + .insert, LinalgOpConversion, + LinalgOpConversion, LoadOpConversion, RangeOpConversion, + SliceOpConversion, StoreOpConversion, ViewOpConversion>( + ctx, converter); } namespace { diff --git a/mlir/test/Linalg/llvm.mlir b/mlir/test/Linalg/llvm.mlir index a1739d4..e82274a 100644 --- a/mlir/test/Linalg/llvm.mlir +++ b/mlir/test/Linalg/llvm.mlir @@ -87,7 +87,7 @@ func @dot(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !linalg return } // CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %{{.*}}: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) { -// CHECK: llvm.call @linalg_dot(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) -> () +// CHECK: llvm.call @linalg_dot_viewxf32_viewxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) -> () func @dim(%arg0: !linalg.view) { %0 = linalg.dim %arg0, 1 : !linalg.view diff --git a/mlir/test/mlir-cpu-runner/cblas_interface.cpp b/mlir/test/mlir-cpu-runner/cblas_interface.cpp index 1a63237..973c7f2 100644 --- a/mlir/test/mlir-cpu-runner/cblas_interface.cpp +++ b/mlir/test/mlir-cpu-runner/cblas_interface.cpp @@ -36,16 +36,30 @@ template struct ViewType { unsigned long offset; }; -extern "C" void linalg_dot_impl(ViewType *X, ViewType *Y, - ViewType *Z) { +extern "C" void linalg_fill_viewf32_f32_impl(ViewType *X, float *pF) { + *(X->data + X->offset) = *pF; +} + +extern "C" void linalg_fill_viewxf32_f32_impl(ViewType *X, + float *pF) { + float f = *pF; + for (unsigned i = 0; i < X->sizes[0]; ++i) { + *(X->data + X->offset + i * X->strides[0]) = f; + } +} + +extern "C" void linalg_dot_viewxf32_viewxf32_viewf32_impl( + ViewType *X, ViewType *Y, ViewType *Z) { + assert(X->strides[0] == 1); + assert(Y->strides[0] == 1); assert(X->sizes[0] == Y->sizes[0] && "Expected X and Y of same size"); *(Z->data + Z->offset) += cblas_sdot(X->sizes[0], X->data + X->offset, X->strides[0], Y->data + Y->offset, Y->strides[0]); } -extern "C" void linalg_matmul_impl(ViewType *A, ViewType *B, - ViewType *C) { +extern "C" void linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl( + ViewType *A, ViewType *B, ViewType *C) { assert(A->strides[1] == B->strides[1]); assert(A->strides[1] == C->strides[1]); assert(A->strides[1] == 1); diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir index 2dc9748..76b28ab 100644 --- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir +++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir @@ -3,24 +3,18 @@ // RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s // RUN: mlir-opt %s -linalg-lower-to-loops -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s -func @fill_f32(%arg0 : !linalg.buffer, %f : f32) { +// Creates and returns a 1-D buffer of size %s filled with the value %f +func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer { %c0 = constant 0 : index %c1 = constant 1 : index - %s = linalg.buffer_size %arg0 : !linalg.buffer + %buf = linalg.buffer_alloc %s : !linalg.buffer %R = linalg.range %c0:%s:%c1 : !linalg.range - %V = linalg.view %arg0[%R] : !linalg.buffer -> !linalg.view - loop.for %i0 = %c0 to %s step %c1 { - linalg.store %f, %V[%i0] : !linalg.view - } - return -} - -func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer { - %A = linalg.buffer_alloc %s : !linalg.buffer - call @fill_f32(%A, %f) : (!linalg.buffer, f32) -> () - return %A : !linalg.buffer + %V = linalg.view %buf[%R] : !linalg.buffer -> !linalg.view + linalg.fill(%V, %f) : !linalg.view, f32 + return %buf : !linalg.buffer } +// Test for linalg.dot. func @dot() -> f32 { %c0 = constant 0 : index %c1 = constant 1 : index @@ -48,6 +42,7 @@ func @dot() -> f32 { return %res : f32 } +// Test for linalg.matmul. func @matmul() -> f32 { %c0 = constant 0 : index %c1 = constant 1 : index @@ -82,6 +77,5 @@ func @matmul() -> f32 { return %res : f32 } - // All tests return this value // CHECK: 4.2{{0+}}e+01 -- 2.7.4