External library name mangling support for linalg.
authorNicolas Vasilache <ntv@google.com>
Fri, 9 Aug 2019 14:33:34 +0000 (07:33 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 9 Aug 2019 14:33:58 +0000 (07:33 -0700)
This CL introduces the ability to generate the external library name for Linalg operations.
The problem is that neither mlir or C support overloading and we want a simplified form of name mangling that is still reasonable to read.
This CL creates the name of the external call that Linalg expects from the operation name and the type of its arguments.

The interface library names are updated and use new cases are added for FillOp.

PiperOrigin-RevId: 262556833

mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td
mlir/include/mlir/Linalg/IR/LinalgOps.h
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/test/Linalg/llvm.mlir
mlir/test/mlir-cpu-runner/cblas_interface.cpp
mlir/test/mlir-cpu-runner/linalg_integration_test.mlir

index 547a2c4..998d68b 100644 (file)
@@ -80,10 +80,9 @@ class LinalgLibraryBase_Op<string mnemonic, list<OpTrait> props>
 
 class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
   : LinalgLibraryBase_Op<mnemonic, props> {
-
-  code classDeclaration = [{
-    StringRef getLibraryCallName() {
-      return "linalg_}] # mnemonic # [{";
+  code libraryCallName = [{
+    std::string getLibraryCallName() {
+      return generateLibraryCallName(getOperation());
     }
   }];
 }
@@ -138,7 +137,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
     return build(
       builder, result, input, output, AffineMapAttr(), AffineMapAttr());
   }]>];
-  let extraClassDeclaration = classDeclaration # [{
+  let extraClassDeclaration = libraryCallName # [{
     unsigned getNumParallelLoops() {
       auto *view = *(getOperands().begin());
       return view->getType().cast<ViewType>().getRank();
@@ -151,7 +150,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
 
 def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {
   let arguments = (ins View, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
-  let extraClassDeclaration = classDeclaration # [{
+  let extraClassDeclaration = libraryCallName # [{
     unsigned getNumParallelLoops() {
       auto *view = *(getOperands().begin());
       return view->getType().cast<ViewType>().getRank();
@@ -170,7 +169,7 @@ def DotOp : LinalgLibrary_Op<"dot",
                              NLoopTypes<0, 1, 0>,
                              ViewRanks<[1, 1, 0]>]> {
   let arguments = (ins View, View, View);
-  let extraClassDeclaration = classDeclaration;
+  let extraClassDeclaration = libraryCallName;
 }
 
 def MatvecOp : LinalgLibrary_Op<"matvec",
@@ -178,7 +177,7 @@ def MatvecOp : LinalgLibrary_Op<"matvec",
                                    NLoopTypes<1, 1, 0>,
                                    ViewRanks<[2, 1, 1]>]> {
   let arguments = (ins View, View, View);
-  let extraClassDeclaration = classDeclaration;
+  let extraClassDeclaration = libraryCallName;
 }
 
 def MatmulOp : LinalgLibrary_Op<"matmul",
@@ -186,7 +185,7 @@ def MatmulOp : LinalgLibrary_Op<"matmul",
                                    NLoopTypes<2, 1, 0>,
                                    ViewRanks<[2, 2, 2]>]> {
   let arguments = (ins View, View, View);
-  let extraClassDeclaration = classDeclaration;
+  let extraClassDeclaration = libraryCallName;
 }
 
 def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
@@ -211,7 +210,7 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
   let arguments = (ins View:$filter, View:$input, View:$output,
                    OptionalAttr<I64ArrayAttr>:$strides,
                    OptionalAttr<I64ArrayAttr>:$dilations);
-  let extraClassDeclaration = classDeclaration # [{
+  let extraClassDeclaration = libraryCallName # [{
     // TODO(ntv) extend to support more than 1 dimensions and potentially
     // grouping too.
     unsigned getNumBatchDimensions() { return 1; }
index 4085d06..3187f4f 100644 (file)
@@ -186,6 +186,28 @@ public:
   }
 };
 
+/// Returns the name mangled library call name to disambiguate between different
+/// overloads at the C level. The name mangling scheme is basic and uses MLIR
+/// type names:
+///   1. form a string which is the concatenation of the linalg op name with all
+///      the operand type names, separate by underscores;
+///   2. drop the `linalg.` prefix, and the `<`, `>`, `?` symbols from the type.
+/// Assumes `op` is a LinalgOp.
+///
+/// Examples:
+///
+/// 1. linalg.fill(%A, %f) : !linalg.view<f32>, f32
+///   name mangles into `linalg_fill_viewf32_f32_impl`
+///
+/// 2. linalg.dot(%A, %B, %C) :
+///      !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+///   name mangles into `linalg_dot_viewxf32_viewxf32_viewf32_impl`
+///
+/// 3. linalg.matmul(...) :
+///      !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+///   name mangles into `linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl`
+std::string generateLibraryCallName(Operation *op);
+
 #define GET_OP_CLASSES
 #include "mlir/Linalg/IR/LinalgOps.h.inc"
 
index 6549508..bce2b32 100644 (file)
@@ -37,6 +37,7 @@
 #include "mlir/Transforms/FoldUtils.h"
 
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 using namespace mlir::edsc;
@@ -1085,3 +1086,35 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
   }
   llvm_unreachable("Missing loopToOperandRangesMaps for op");
 }
+
+static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
+  if (auto view = t.dyn_cast<ViewType>()) {
+    ss << "view";
+    for (unsigned i = 0, e = view.getRank(); i < e; ++i)
+      ss << "x";
+    appendMangledType(ss, view.getElementType());
+  } else if (auto vec = t.dyn_cast<VectorType>()) {
+    ss << "vector";
+    interleave(
+        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
+    appendMangledType(ss, vec.getElementType());
+  } else if (t.isIntOrIndexOrFloat()) {
+    ss << t;
+  } else {
+    llvm_unreachable("Invalid type for linalg library name mangling");
+  }
+}
+
+std::string mlir::linalg::generateLibraryCallName(Operation *op) {
+  assert(isa<LinalgOp>(op));
+  std::string name(op->getName().getStringRef().str());
+  name.reserve(128);
+  std::replace(name.begin(), name.end(), '.', '_');
+  llvm::raw_string_ostream ss(name);
+  ss << "_";
+  auto types = op->getOperandTypes();
+  interleave(
+      types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
+      [&]() { ss << "_"; });
+  return ss.str();
+}
index 6967a9d..a45f943 100644 (file)
@@ -39,6 +39,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/LowerAffine.h"
 #include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/SetVector.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
@@ -545,7 +547,7 @@ static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) {
   }
   SmallVector<Type, 4> fnArgTypes;
   for (auto t : libFn.getType().getInputs()) {
-    assert(t.isa<LLVMType>() &&
+    assert(t && t.isa<LLVMType>() &&
            "Expected LLVM Type for argument while generating library Call "
            "Implementation Definition");
     fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo());
@@ -577,12 +579,8 @@ getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering,
 
   // Get the Function type consistent with LLVM Lowering.
   SmallVector<Type, 4> inputTypes;
-  for (auto operand : op->getOperands()) {
-    // TODO(ravishankarm): convertLinalgType handles only a subset of Linalg
-    // types. Handle other types (as well as non-Linalg types) either here or in
-    // convertLinalgType.
-    inputTypes.push_back(convertLinalgType(operand->getType(), lowering));
-  }
+  for (auto operand : op->getOperands())
+    inputTypes.push_back(lowering.convertType(operand->getType()));
   assert(op->getNumResults() == 0 &&
          "Library call for linalg operation can be generated only for ops that "
          "have void return types");
@@ -632,15 +630,15 @@ public:
     return convertLinalgType(t, *this);
   }
 
-  void addLibraryFnDeclaration(FuncOp fn) {
-    libraryFnDeclarations.push_back(fn);
-  }
+  void addLibraryFnDeclaration(FuncOp fn) { libraryFnDeclarations.insert(fn); }
 
-  ArrayRef<FuncOp> getLibraryFnDeclarations() { return libraryFnDeclarations; }
+  ArrayRef<FuncOp> getLibraryFnDeclarations() {
+    return libraryFnDeclarations.getArrayRef();
+  }
 
 private:
   /// List of library functions declarations needed during dialect conversion
-  SmallVector<FuncOp, 2> libraryFnDeclarations;
+  llvm::SetVector<FuncOp> libraryFnDeclarations;
 };
 } // end anonymous namespace
 
@@ -676,11 +674,13 @@ static void
 populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
                                        OwningRewritePatternList &patterns,
                                        MLIRContext *ctx) {
-  patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
-                  BufferSizeOpConversion, DimOpConversion,
-                  LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
-                  LoadOpConversion, RangeOpConversion, SliceOpConversion,
-                  StoreOpConversion, ViewOpConversion>(ctx, converter);
+  patterns
+      .insert<BufferAllocOpConversion, BufferDeallocOpConversion,
+              BufferSizeOpConversion, DimOpConversion,
+              LinalgOpConversion<DotOp>, LinalgOpConversion<FillOp>,
+              LinalgOpConversion<MatmulOp>, LoadOpConversion, RangeOpConversion,
+              SliceOpConversion, StoreOpConversion, ViewOpConversion>(
+          ctx, converter);
 }
 
 namespace {
index a1739d4..e82274a 100644 (file)
@@ -87,7 +87,7 @@ func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg
   return
 }
 // CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %{{.*}}: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
-//       CHECK:   llvm.call @linalg_dot(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) -> ()
+//       CHECK:   llvm.call @linalg_dot_viewxf32_viewxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) -> ()
 
 func @dim(%arg0: !linalg.view<?x?xf32>) {
   %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
index 1a63237..973c7f2 100644 (file)
@@ -36,16 +36,30 @@ template <typename T> struct ViewType<T, 0> {
   unsigned long offset;
 };
 
-extern "C" void linalg_dot_impl(ViewType<float, 1> *X, ViewType<float, 1> *Y,
-                                ViewType<float, 0> *Z) {
+extern "C" void linalg_fill_viewf32_f32_impl(ViewType<float, 0> *X, float *pF) {
+  *(X->data + X->offset) = *pF;
+}
+
+extern "C" void linalg_fill_viewxf32_f32_impl(ViewType<float, 1> *X,
+                                              float *pF) {
+  float f = *pF;
+  for (unsigned i = 0; i < X->sizes[0]; ++i) {
+    *(X->data + X->offset + i * X->strides[0]) = f;
+  }
+}
+
+extern "C" void linalg_dot_viewxf32_viewxf32_viewf32_impl(
+    ViewType<float, 1> *X, ViewType<float, 1> *Y, ViewType<float, 0> *Z) {
+  assert(X->strides[0] == 1);
+  assert(Y->strides[0] == 1);
   assert(X->sizes[0] == Y->sizes[0] && "Expected X and Y of same size");
   *(Z->data + Z->offset) +=
       cblas_sdot(X->sizes[0], X->data + X->offset, X->strides[0],
                  Y->data + Y->offset, Y->strides[0]);
 }
 
-extern "C" void linalg_matmul_impl(ViewType<float, 2> *A, ViewType<float, 2> *B,
-                                   ViewType<float, 2> *C) {
+extern "C" void linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl(
+    ViewType<float, 2> *A, ViewType<float, 2> *B, ViewType<float, 2> *C) {
   assert(A->strides[1] == B->strides[1]);
   assert(A->strides[1] == C->strides[1]);
   assert(A->strides[1] == 1);
index 2dc9748..76b28ab 100644 (file)
@@ -3,24 +3,18 @@
 // RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
 // RUN: mlir-opt %s -linalg-lower-to-loops -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
 
-func @fill_f32(%arg0 : !linalg.buffer<?xf32>, %f : f32) {
+// Creates and returns a 1-D buffer of size %s filled with the value %f
+func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<?xf32> {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %s = linalg.buffer_size %arg0 : !linalg.buffer<?xf32>
+  %buf = linalg.buffer_alloc %s : !linalg.buffer<?xf32>
   %R = linalg.range %c0:%s:%c1 : !linalg.range
-  %V = linalg.view %arg0[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-  loop.for %i0 = %c0 to %s step %c1 {
-    linalg.store %f, %V[%i0] : !linalg.view<?xf32>
-  }
-  return
-}
-
-func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<?xf32> {
-  %A = linalg.buffer_alloc %s : !linalg.buffer<?xf32>
-  call @fill_f32(%A, %f) : (!linalg.buffer<?xf32>, f32) -> ()
-  return %A : !linalg.buffer<?xf32>
+  %V = linalg.view %buf[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
+  linalg.fill(%V, %f) : !linalg.view<?xf32>, f32
+  return %buf : !linalg.buffer<?xf32>
 }
 
+// Test for linalg.dot.
 func @dot() -> f32 {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -48,6 +42,7 @@ func @dot() -> f32 {
   return %res : f32
 }
 
+// Test for linalg.matmul.
 func @matmul() -> f32 {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -82,6 +77,5 @@ func @matmul() -> f32 {
   return %res : f32
 }
 
-
 // All tests return this value
 // CHECK: 4.2{{0+}}e+01