class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
: LinalgLibraryBase_Op<mnemonic, props> {
-
- code classDeclaration = [{
- StringRef getLibraryCallName() {
- return "linalg_}] # mnemonic # [{";
+ code libraryCallName = [{
+ std::string getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
}
}];
}
return build(
builder, result, input, output, AffineMapAttr(), AffineMapAttr());
}]>];
- let extraClassDeclaration = classDeclaration # [{
+ let extraClassDeclaration = libraryCallName # [{
unsigned getNumParallelLoops() {
auto *view = *(getOperands().begin());
return view->getType().cast<ViewType>().getRank();
def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {
let arguments = (ins View, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
- let extraClassDeclaration = classDeclaration # [{
+ let extraClassDeclaration = libraryCallName # [{
unsigned getNumParallelLoops() {
auto *view = *(getOperands().begin());
return view->getType().cast<ViewType>().getRank();
NLoopTypes<0, 1, 0>,
ViewRanks<[1, 1, 0]>]> {
let arguments = (ins View, View, View);
- let extraClassDeclaration = classDeclaration;
+ let extraClassDeclaration = libraryCallName;
}
def MatvecOp : LinalgLibrary_Op<"matvec",
NLoopTypes<1, 1, 0>,
ViewRanks<[2, 1, 1]>]> {
let arguments = (ins View, View, View);
- let extraClassDeclaration = classDeclaration;
+ let extraClassDeclaration = libraryCallName;
}
def MatmulOp : LinalgLibrary_Op<"matmul",
NLoopTypes<2, 1, 0>,
ViewRanks<[2, 2, 2]>]> {
let arguments = (ins View, View, View);
- let extraClassDeclaration = classDeclaration;
+ let extraClassDeclaration = libraryCallName;
}
def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
let arguments = (ins View:$filter, View:$input, View:$output,
OptionalAttr<I64ArrayAttr>:$strides,
OptionalAttr<I64ArrayAttr>:$dilations);
- let extraClassDeclaration = classDeclaration # [{
+ let extraClassDeclaration = libraryCallName # [{
// TODO(ntv) extend to support more than 1 dimensions and potentially
// grouping too.
unsigned getNumBatchDimensions() { return 1; }
}
};
+/// Returns the name mangled library call name to disambiguate between different
+/// overloads at the C level. The name mangling scheme is basic and uses MLIR
+/// type names:
+/// 1. form a string which is the concatenation of the linalg op name with all
+/// the operand type names, separate by underscores;
+/// 2. drop the `linalg.` prefix, and the `<`, `>`, `?` symbols from the type.
+/// Assumes `op` is a LinalgOp.
+///
+/// Examples:
+///
+/// 1. linalg.fill(%A, %f) : !linalg.view<f32>, f32
+/// name mangles into `linalg_fill_viewf32_f32_impl`
+///
+/// 2. linalg.dot(%A, %B, %C) :
+/// !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+/// name mangles into `linalg_dot_viewxf32_viewxf32_viewf32_impl`
+///
+/// 3. linalg.matmul(...) :
+/// !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+/// name mangles into `linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl`
+std::string generateLibraryCallName(Operation *op);
+
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.h.inc"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::edsc;
}
llvm_unreachable("Missing loopToOperandRangesMaps for op");
}
+
+static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
+ if (auto view = t.dyn_cast<ViewType>()) {
+ ss << "view";
+ for (unsigned i = 0, e = view.getRank(); i < e; ++i)
+ ss << "x";
+ appendMangledType(ss, view.getElementType());
+ } else if (auto vec = t.dyn_cast<VectorType>()) {
+ ss << "vector";
+ interleave(
+ vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
+ appendMangledType(ss, vec.getElementType());
+ } else if (t.isIntOrIndexOrFloat()) {
+ ss << t;
+ } else {
+ llvm_unreachable("Invalid type for linalg library name mangling");
+ }
+}
+
+std::string mlir::linalg::generateLibraryCallName(Operation *op) {
+ assert(isa<LinalgOp>(op));
+ std::string name(op->getName().getStringRef().str());
+ name.reserve(128);
+ std::replace(name.begin(), name.end(), '.', '_');
+ llvm::raw_string_ostream ss(name);
+ ss << "_";
+ auto types = op->getOperandTypes();
+ interleave(
+ types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
+ [&]() { ss << "_"; });
+ return ss.str();
+}
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/LowerAffine.h"
#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/SetVector.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
}
SmallVector<Type, 4> fnArgTypes;
for (auto t : libFn.getType().getInputs()) {
- assert(t.isa<LLVMType>() &&
+ assert(t && t.isa<LLVMType>() &&
"Expected LLVM Type for argument while generating library Call "
"Implementation Definition");
fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo());
// Get the Function type consistent with LLVM Lowering.
SmallVector<Type, 4> inputTypes;
- for (auto operand : op->getOperands()) {
- // TODO(ravishankarm): convertLinalgType handles only a subset of Linalg
- // types. Handle other types (as well as non-Linalg types) either here or in
- // convertLinalgType.
- inputTypes.push_back(convertLinalgType(operand->getType(), lowering));
- }
+ for (auto operand : op->getOperands())
+ inputTypes.push_back(lowering.convertType(operand->getType()));
assert(op->getNumResults() == 0 &&
"Library call for linalg operation can be generated only for ops that "
"have void return types");
return convertLinalgType(t, *this);
}
- void addLibraryFnDeclaration(FuncOp fn) {
- libraryFnDeclarations.push_back(fn);
- }
+ void addLibraryFnDeclaration(FuncOp fn) { libraryFnDeclarations.insert(fn); }
- ArrayRef<FuncOp> getLibraryFnDeclarations() { return libraryFnDeclarations; }
+ ArrayRef<FuncOp> getLibraryFnDeclarations() {
+ return libraryFnDeclarations.getArrayRef();
+ }
private:
/// List of library functions declarations needed during dialect conversion
- SmallVector<FuncOp, 2> libraryFnDeclarations;
+ llvm::SetVector<FuncOp> libraryFnDeclarations;
};
} // end anonymous namespace
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
OwningRewritePatternList &patterns,
MLIRContext *ctx) {
- patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
- BufferSizeOpConversion, DimOpConversion,
- LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
- LoadOpConversion, RangeOpConversion, SliceOpConversion,
- StoreOpConversion, ViewOpConversion>(ctx, converter);
+ patterns
+ .insert<BufferAllocOpConversion, BufferDeallocOpConversion,
+ BufferSizeOpConversion, DimOpConversion,
+ LinalgOpConversion<DotOp>, LinalgOpConversion<FillOp>,
+ LinalgOpConversion<MatmulOp>, LoadOpConversion, RangeOpConversion,
+ SliceOpConversion, StoreOpConversion, ViewOpConversion>(
+ ctx, converter);
}
namespace {
return
}
// CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %{{.*}}: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
-// CHECK: llvm.call @linalg_dot(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) -> ()
+// CHECK: llvm.call @linalg_dot_viewxf32_viewxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) -> ()
func @dim(%arg0: !linalg.view<?x?xf32>) {
%0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
unsigned long offset;
};
-extern "C" void linalg_dot_impl(ViewType<float, 1> *X, ViewType<float, 1> *Y,
- ViewType<float, 0> *Z) {
+extern "C" void linalg_fill_viewf32_f32_impl(ViewType<float, 0> *X, float *pF) {
+ *(X->data + X->offset) = *pF;
+}
+
+extern "C" void linalg_fill_viewxf32_f32_impl(ViewType<float, 1> *X,
+ float *pF) {
+ float f = *pF;
+ for (unsigned i = 0; i < X->sizes[0]; ++i) {
+ *(X->data + X->offset + i * X->strides[0]) = f;
+ }
+}
+
+extern "C" void linalg_dot_viewxf32_viewxf32_viewf32_impl(
+ ViewType<float, 1> *X, ViewType<float, 1> *Y, ViewType<float, 0> *Z) {
+ assert(X->strides[0] == 1);
+ assert(Y->strides[0] == 1);
assert(X->sizes[0] == Y->sizes[0] && "Expected X and Y of same size");
*(Z->data + Z->offset) +=
cblas_sdot(X->sizes[0], X->data + X->offset, X->strides[0],
Y->data + Y->offset, Y->strides[0]);
}
-extern "C" void linalg_matmul_impl(ViewType<float, 2> *A, ViewType<float, 2> *B,
- ViewType<float, 2> *C) {
+extern "C" void linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl(
+ ViewType<float, 2> *A, ViewType<float, 2> *B, ViewType<float, 2> *C) {
assert(A->strides[1] == B->strides[1]);
assert(A->strides[1] == C->strides[1]);
assert(A->strides[1] == 1);
// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-func @fill_f32(%arg0 : !linalg.buffer<?xf32>, %f : f32) {
+// Creates and returns a 1-D buffer of size %s filled with the value %f
+func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<?xf32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
- %s = linalg.buffer_size %arg0 : !linalg.buffer<?xf32>
+ %buf = linalg.buffer_alloc %s : !linalg.buffer<?xf32>
%R = linalg.range %c0:%s:%c1 : !linalg.range
- %V = linalg.view %arg0[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
- loop.for %i0 = %c0 to %s step %c1 {
- linalg.store %f, %V[%i0] : !linalg.view<?xf32>
- }
- return
-}
-
-func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<?xf32> {
- %A = linalg.buffer_alloc %s : !linalg.buffer<?xf32>
- call @fill_f32(%A, %f) : (!linalg.buffer<?xf32>, f32) -> ()
- return %A : !linalg.buffer<?xf32>
+ %V = linalg.view %buf[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
+ linalg.fill(%V, %f) : !linalg.view<?xf32>, f32
+ return %buf : !linalg.buffer<?xf32>
}
+// Test for linalg.dot.
func @dot() -> f32 {
%c0 = constant 0 : index
%c1 = constant 1 : index
return %res : f32
}
+// Test for linalg.matmul.
func @matmul() -> f32 {
%c0 = constant 0 : index
%c1 = constant 1 : index
return %res : f32
}
-
// All tests return this value
// CHECK: 4.2{{0+}}e+01