Add a primitive linalg-lower-to-llvm-dialect pass
authorNicolas Vasilache <ntv@google.com>
Thu, 2 May 2019 18:36:52 +0000 (11:36 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:24:59 +0000 (08:24 -0700)
    This CL builds upon ftynse@'s Linalg dialect conversion (in examples/Linalg/Linalg1) and updates it to support buffers and the fully composed form of view and slice operations.
    A new BufferSizeOp is introduced for the purpose of extracting the size information from a buffer.
    This will be useful in a followup CL for an end-to-end LLVM execution path where mlir-cpu-runner will allocate a buffer.

--

PiperOrigin-RevId: 246358593

mlir/include/mlir/LLVMIR/Transforms.h
mlir/include/mlir/Linalg/IR/LinalgOps.td
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/CMakeLists.txt
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp [new file with mode: 0644]
mlir/test/Linalg/llvm.mlir [new file with mode: 0644]

index 87a7a0e..b021981 100644 (file)
 
 #include <memory>
 
+namespace llvm {
+class Module;
+} // namespace llvm
+
 namespace mlir {
 class DialectConversion;
 class Module;
 class ModulePassBase;
+class Type;
 
 /// Creates a pass to convert Standard dialects into the LLVMIR dialect.
 ModulePassBase *createConvertToLLVMIRPass();
@@ -39,6 +44,10 @@ namespace LLVM {
 /// another block as a successor more than once with different values, insert
 /// a new dummy block for LLVM PHI nodes to tell the sources apart.
 void ensureDistinctSuccessors(Module *m);
+
+/// Converts a type in either MLIR standard or builtin type into LLVMIR dialect
+/// type.
+Type convertToLLVMDialectType(Type t, llvm::Module &llvmModule);
 } // namespace LLVM
 
 } // namespace mlir
index 958d879..fd07f6c 100644 (file)
@@ -31,6 +31,10 @@ def Linalg_Dialect : Dialect {
   let name = "linalg";
 }
 
+// Whether a type is a BufferType.
+def LinalgIsBufferTypePred : CPred<"$_self.isa<BufferType>()">;
+def Buffer : Type<LinalgIsBufferTypePred, "buffer">;
+
 // Whether a type is a ViewType.
 def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
 def View : Type<LinalgIsViewTypePred, "view">;
@@ -84,6 +88,20 @@ Op<Linalg_Dialect, mnemonic, props> {
   let printer = [{ impl::printLinalgLibraryOp(p, *this); }];
 }
 
+def BufferSizeOp :
+    Op<Linalg_Dialect, "buffer_size", [NoSideEffect]>,
+    Arguments<(ins Buffer)>,
+    Results<(outs Index)>
+{
+  let parser = [{
+    return impl::parseBufferSizeOp(parser, result);
+  }];
+
+  let printer = [{
+    return impl::printBufferSizeOp(p, this->getOperation());
+  }];
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Concrete Linalg ops.
 ////////////////////////////////////////////////////////////////////////////////
index 2771fd4..d103430 100644 (file)
@@ -1169,3 +1169,7 @@ std::unique_ptr<DialectConversion> mlir::createStdToLLVMConverter() {
 
 static PassRegistration<LLVMLoweringPass>
     pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect");
+
+Type mlir::LLVM::convertToLLVMDialectType(Type t, llvm::Module &llvmModule) {
+  return TypeConverter::convert(t, llvmModule);
+}
index e048d50..2488c44 100644 (file)
@@ -2,6 +2,7 @@ add_llvm_library(MLIRLinalg
   LinalgRegistration.cpp
   IR/LinalgOps.cpp
   IR/LinalgTypes.cpp
+  Transforms/LowerToLLVMDialect.cpp
   Transforms/Tiling.cpp
   Utils/Utils.cpp
 
index 84e1d44..d8f9b13 100644 (file)
@@ -355,8 +355,35 @@ namespace mlir {
 namespace impl {
 void printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op);
 bool parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result);
+void printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op);
+bool parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
 } // namespace impl
 
+/// Buffer size prints as:
+///
+/// ``` {.mlir}
+///    %0 = linalg.buffer_size %arg0 : !linalg.buffer<f32>
+/// ```
+void mlir::impl::printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op) {
+  assert(op->getAbstractOperation() && "unregistered operation");
+  *p << op->cast<BufferSizeOp>().getOperationName() << " "
+     << *op->getOperand(0);
+  p->printOptionalAttrDict(op->getAttrs());
+  *p << " : " << op->getOperand(0)->getType();
+}
+
+bool mlir::impl::parseBufferSizeOp(OpAsmParser *parser,
+                                   OperationState *result) {
+  OpAsmParser::OperandType op;
+  Type type;
+  return parser->parseOperand(op) ||
+         parser->parseOptionalAttributeDict(result->attributes) ||
+         parser->parseColonType(type) ||
+         parser->resolveOperand(op, type, result->operands) ||
+         parser->addTypeToList(parser->getBuilder().getIndexType(),
+                               result->types);
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Linalg/IR/LinalgOps.cpp.inc"
 
diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
new file mode 100644 (file)
index 0000000..d399f6f
--- /dev/null
@@ -0,0 +1,466 @@
+//===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/LLVMIR/Transforms.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::LLVM;
+
+using undef = ValueBuilder<mlir::LLVM::UndefOp>;
+using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
+using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>;
+using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
+using add = ValueBuilder<mlir::LLVM::AddOp>;
+using sub = ValueBuilder<mlir::LLVM::SubOp>;
+using mul = ValueBuilder<mlir::LLVM::MulOp>;
+
+static llvm::Module *getLLVMModule(MLIRContext *context) {
+  auto *llvmDialect =
+      static_cast<LLVM::LLVMDialect *>(context->getRegisteredDialect("llvm"));
+  if (!llvmDialect) {
+    context->emitError(UnknownLoc::get(context),
+                       "LLVM IR dialect is not registered");
+    return nullptr;
+  }
+  return &llvmDialect->getLLVMModule();
+}
+
+template <typename T>
+static llvm::Type *getPtrToElementType(T containerType,
+                                       llvm::Module &llvmModule) {
+  return convertToLLVMDialectType(containerType.getElementType(), llvmModule)
+      .template cast<LLVMType>()
+      .getUnderlyingType()
+      ->getPointerTo();
+}
+
+// Convert the given type to the LLVM IR Dialect type.  The following
+// conversions are supported:
+//   - an Index type is converted into an LLVM integer type with pointer
+//     bitwidth (analogous to intptr_t in C);
+//   - an Integer type is converted into an LLVM integer type of the same width;
+//   - an F32 type is converted into an LLVM float type
+//   - a Buffer, Range or View is converted into an LLVM structure type
+//     containing the respective dynamic values.
+static Type convertLinalgType(Type t, llvm::Module &llvmModule) {
+  auto *context = t.getContext();
+  auto *int64Ty = llvm::Type::getInt64Ty(llvmModule.getContext());
+
+  // A buffer descriptor contains the pointer to a flat region of storage and
+  // the size of the region.
+  //
+  // template <typename Elem, size_t Rank>
+  // struct {
+  //   Elem *ptr;
+  //   int64_t size;
+  // };
+  if (auto bufferTy = t.dyn_cast<BufferType>()) {
+    auto *ptrTy = getPtrToElementType(bufferTy, llvmModule);
+    auto *structTy = llvm::StructType::get(ptrTy, int64Ty);
+    return LLVMType::get(context, structTy);
+  }
+
+  // Range descriptor contains the range bounds and the step as 64-bit integers.
+  //
+  // struct {
+  //   int64_t min;
+  //   int64_t max;
+  //   int64_t step;
+  // };
+  if (auto rangeTy = t.dyn_cast<RangeType>()) {
+    auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
+    return LLVMType::get(context, structTy);
+  }
+
+  // View descriptor contains the pointer to the data buffer, followed by a
+  // 64-bit integer containing the distance between the beginning of the buffer
+  // and the first element to be accessed through the view, followed by two
+  // arrays, each containing as many 64-bit integers as the rank of the View.
+  // The first array represents the size, in number of original elements, of the
+  // view along the given dimension.  When taking the view, the size is the
+  // difference between the upper and the lower bound of the range.  The second
+  // array represents the "stride" (in tensor abstraction sense), i.e. the
+  // number of consecutive elements of the underlying buffer that separate two
+  // consecutive elements addressable through the view along the given
+  // dimension.  When taking the view, the strides are constructed as products
+  // of the original sizes along the trailing dimensions, multiplied by the view
+  // step.  For example, a view of a MxN memref with ranges {0:M:1}, {0:N:1},
+  // i.e. the view of a complete memref, will have strides N and 1.  A view with
+  // ranges {0:M:2}, {0:N:3} will have strides 2*N and 3.
+  //
+  // template <typename Elem, size_t Rank>
+  // struct {
+  //   Elem *ptr;
+  //   int64_t offset;
+  //   int64_t sizes[Rank];
+  //   int64_t strides[Rank];
+  // };
+  if (auto viewTy = t.dyn_cast<ViewType>()) {
+    auto *ptrTy = getPtrToElementType(viewTy, llvmModule);
+    auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank());
+    auto *structTy = llvm::StructType::get(ptrTy, int64Ty, arrayTy, arrayTy);
+    return LLVMType::get(context, structTy);
+  }
+
+  return Type();
+}
+
+// Create an array attribute containing integer attributes with values provided
+// in `position`.
+static ArrayAttr makePositionAttr(FuncBuilder &builder,
+                                  ArrayRef<int> position) {
+  SmallVector<Attribute, 4> attrs;
+  attrs.reserve(position.size());
+  for (auto p : position)
+    attrs.push_back(builder.getI64IntegerAttr(p));
+  return builder.getArrayAttr(attrs);
+}
+
+// BufferSizeOp creates a new `index` value.
+class BufferSizeOpConversion : public DialectOpConversion {
+public:
+  explicit BufferSizeOpConversion(MLIRContext *context)
+      : DialectOpConversion(BufferSizeOp::getOperationName(), 1, context),
+        llvmModule(*getLLVMModule(context)) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto bufferSizeType =
+        convertToLLVMDialectType(operands[0]->getType(), llvmModule);
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    return {extractvalue(bufferSizeType, operands[0],
+                         makePositionAttr(rewriter, 1))};
+  }
+
+  llvm::Module &llvmModule;
+};
+
+// RangeOp creates a new range descriptor.
+class RangeOpConversion : public DialectOpConversion {
+public:
+  explicit RangeOpConversion(MLIRContext *context)
+      : DialectOpConversion(RangeOp::getOperationName(), 1, context),
+        llvmModule(*getLLVMModule(context)) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto rangeOp = op->cast<RangeOp>();
+    auto rangeDescriptorType =
+        convertLinalgType(rangeOp.getResult()->getType(), llvmModule);
+
+    edsc::ScopedContext context(rewriter, op->getLoc());
+
+    // Fill in an aggregate value of the descriptor.
+    Value *desc = undef(rangeDescriptorType);
+    desc = insertvalue(rangeDescriptorType, desc, operands[0],
+                       makePositionAttr(rewriter, 0));
+    desc = insertvalue(rangeDescriptorType, desc, operands[1],
+                       makePositionAttr(rewriter, 1));
+    desc = insertvalue(rangeDescriptorType, desc, operands[2],
+                       makePositionAttr(rewriter, 2));
+
+    return {desc};
+  }
+
+  llvm::Module &llvmModule;
+};
+
+class SliceOpConversion : public DialectOpConversion {
+public:
+  explicit SliceOpConversion(MLIRContext *context)
+      : DialectOpConversion(SliceOp::getOperationName(), 1, context),
+        llvmModule(*getLLVMModule(context)) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto sliceOp = op->cast<SliceOp>();
+    auto viewDescriptorType =
+        convertLinalgType(sliceOp.getViewType(), llvmModule);
+    auto viewType = sliceOp.getBaseViewType();
+    auto int64Ty =
+        convertToLLVMDialectType(rewriter.getIntegerType(64), llvmModule);
+
+    // Helper function to create an integer array attribute out of a list of
+    // values.
+    auto pos = [&rewriter](ArrayRef<int> values) {
+      return makePositionAttr(rewriter, values);
+    };
+    // Helper function to obtain the ptr of the given `view`.
+    auto getViewPtr = [pos, &rewriter, this](ViewType type,
+                                             Value *view) -> Value * {
+      auto elementPtrTy =
+          rewriter.getType<LLVMType>(getPtrToElementType(type, llvmModule));
+      return extractvalue(elementPtrTy, view, pos(0));
+    };
+
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    // Declare the view descriptor and insert data ptr.
+    Value *desc = undef(viewDescriptorType);
+    desc = insertvalue(viewDescriptorType, desc,
+                       getViewPtr(viewType, operands[0]), pos(0));
+
+    // TODO(ntv): extract sizes and emit asserts.
+    SmallVector<Value *, 4> strides(viewType.getRank());
+    for (int dim = 0, e = viewType.getRank(); dim < e; ++dim) {
+      strides[dim] = extractvalue(int64Ty, operands[0], pos({3, dim}));
+    }
+
+    // Compute and insert base offset.
+    Value *baseOffset = extractvalue(int64Ty, operands[0], pos(1));
+    for (int j = 0, e = viewType.getRank(); j < e; ++j) {
+      Value *indexing = operands[1 + j];
+      Value *min =
+          sliceOp.getIndexing(j)->getType().isa<RangeType>()
+              ? static_cast<Value *>(extractvalue(int64Ty, indexing, pos(0)))
+              : indexing;
+      Value *product = mul(min, strides[j]);
+      baseOffset = add(baseOffset, product);
+    }
+    desc = insertvalue(viewDescriptorType, desc, baseOffset, pos(1));
+
+    // Compute and insert view sizes (max - min along the range).  Skip the
+    // non-range operands as they will be projected away from the view.
+    int i = 0;
+    for (Value *index : sliceOp.getIndexings()) {
+      if (!index->getType().isa<RangeType>())
+        continue;
+
+      Value *rangeDescriptor = operands[1 + i];
+      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+      Value *size = sub(max, min);
+
+      desc = insertvalue(viewDescriptorType, desc, size, pos({2, i}));
+      ++i;
+    }
+
+    // Compute and insert view strides.  Step over the strides that correspond
+    // to non-range operands as they are projected away from the view.
+    i = 0;
+    for (int j = 0, e = strides.size(); j < e; ++j) {
+      if (!sliceOp.getIndexing(j)->getType().isa<RangeType>())
+        continue;
+      Value *step = extractvalue(int64Ty, operands[1 + j], pos(2));
+      Value *stride = mul(strides[j], step);
+      desc = insertvalue(viewDescriptorType, desc, stride, pos({3, i}));
+      ++i;
+    }
+
+    return {desc};
+  }
+
+  llvm::Module &llvmModule;
+};
+
+class ViewOpConversion : public DialectOpConversion {
+public:
+  explicit ViewOpConversion(MLIRContext *context)
+      : DialectOpConversion(ViewOp::getOperationName(), 1, context),
+        llvmModule(*getLLVMModule(context)) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto viewOp = op->cast<ViewOp>();
+    auto viewDescriptorType =
+        convertLinalgType(viewOp.getViewType(), llvmModule);
+    auto elementType = rewriter.getType<LLVMType>(
+        getPtrToElementType(viewOp.getViewType(), llvmModule));
+    auto int64Ty =
+        convertToLLVMDialectType(rewriter.getIntegerType(64), llvmModule);
+
+    auto pos = [&rewriter](ArrayRef<int> values) {
+      return makePositionAttr(rewriter, values);
+    };
+
+    // First operand to `view` is the buffer descriptor.
+    Value *bufferDescriptor = operands[0];
+
+    // Declare the descriptor of the view.
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    Value *desc = undef(viewDescriptorType);
+
+    // Copy the buffer pointer from the old descriptor to the new one.
+    Value *buffer = extractvalue(elementType, bufferDescriptor, pos(0));
+    desc = insertvalue(viewDescriptorType, desc, buffer, pos(0));
+
+    // Zero base offset.
+    auto indexTy = rewriter.getIndexType();
+    Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
+    desc = insertvalue(viewDescriptorType, desc, baseOffset, pos(1));
+
+    // Compute and insert view sizes (max - min along the range).
+    int numIndexings = llvm::size(viewOp.getIndexings());
+    Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
+    for (int i = 0; i < numIndexings; ++i) {
+      // Update stride.
+      Value *rangeDescriptor = operands[1 + i];
+      Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+      Value *stride = mul(runningStride, step);
+      desc = insertvalue(viewDescriptorType, desc, stride, pos({3, i}));
+      // Update size.
+      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+      Value *size = sub(max, min);
+      desc = insertvalue(viewDescriptorType, desc, size, pos({2, i}));
+      ++i;
+      // Update stride for the next dimension.
+      if (i < numIndexings - 1)
+        runningStride = mul(runningStride, max);
+    }
+
+    return {desc};
+  }
+
+  llvm::Module &llvmModule;
+};
+
+// DotOp creates a new range descriptor.
+class DotOpConversion : public DialectOpConversion {
+public:
+  explicit DotOpConversion(MLIRContext *context)
+      : DialectOpConversion(DotOp::getOperationName(), 1, context) {}
+
+  static StringRef libraryFunctionName() { return "linalg_dot"; }
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto *f =
+        op->getFunction()->getModule()->getNamedFunction(libraryFunctionName());
+    if (!f)
+      op->emitError("Could not find function: " + libraryFunctionName() +
+                    "in lowering to LLVM ");
+    auto fAttr = rewriter.getFunctionAttr(f);
+    auto named = rewriter.getNamedAttr("callee", fAttr);
+    rewriter.create<LLVM::CallOp>(op->getLoc(), operands, ArrayRef<NamedAttribute>{named});
+    return {};
+  }
+};
+
+llvm::DenseSet<mlir::DialectOpConversion *>
+allocateDescriptorConverters(llvm::BumpPtrAllocator *allocator,
+                             mlir::MLIRContext *context) {
+  return ConversionListBuilder<BufferSizeOpConversion, DotOpConversion,
+                               RangeOpConversion, SliceOpConversion,
+                               ViewOpConversion>::build(allocator, context);
+}
+
+namespace {
+// The conversion class from Linalg to LLVMIR.
+class Lowering : public DialectConversion {
+public:
+  explicit Lowering(std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
+                        llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
+                        conversions)
+      : setup(conversions) {}
+
+  Lowering &setLLVMModule(MLIRContext *context) {
+    llvmModule = getLLVMModule(context);
+    return *this;
+  }
+
+protected:
+  // Initialize the list of converters.
+  llvm::DenseSet<DialectOpConversion *>
+  initConverters(MLIRContext *context) override {
+    converterStorage.Reset();
+    return setup(&converterStorage, context);
+  }
+
+  // This gets called for block and region arguments, and attributes.
+  Type convertType(Type t) override {
+    if (auto res = convertLinalgType(t, *llvmModule))
+      return res;
+    return convertToLLVMDialectType(t, *llvmModule);
+  }
+
+private:
+  // Storage for individual converters.
+  llvm::BumpPtrAllocator converterStorage;
+
+  // Conversion setup.
+  std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
+      llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
+      setup;
+
+  llvm::Module *llvmModule;
+};
+} // end anonymous namespace
+
+std::unique_ptr<mlir::DialectConversion> makeLinalgToLLVMLowering(
+    std::function<llvm::DenseSet<mlir::DialectOpConversion *>(
+        llvm::BumpPtrAllocator *, mlir::MLIRContext *context)>
+        initer) {
+  return llvm::make_unique<Lowering>(initer);
+}
+
+namespace {
+struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
+  void runOnModule();
+};
+} // namespace
+
+void LowerLinalgToLLVMPass::runOnModule() {
+  auto &module = getModule();
+
+  // Convert Linalg ops to the LLVM IR dialect using the converter defined
+  // above.
+  auto r = Lowering(allocateDescriptorConverters)
+               .setLLVMModule(module.getContext())
+               .convert(&module);
+  if (failed(r))
+    signalPassFailure();
+
+  // Convert the remaining standard MLIR operations to the LLVM IR dialect using
+  // the default converter.
+  auto converter = createStdToLLVMConverter();
+  r = converter->convert(&module);
+  if (failed(r))
+    signalPassFailure();
+}
+
+ModulePassBase *createLowerLinalgToLLVMPass() {
+  return new LowerLinalgToLLVMPass();
+}
+
+static PassRegistration<LowerLinalgToLLVMPass>
+    pass("linalg-lower-to-llvm-dialect",
+         "Lower the operations from the linalg dialect into the LLVM dialect");
diff --git a/mlir/test/Linalg/llvm.mlir b/mlir/test/Linalg/llvm.mlir
new file mode 100644 (file)
index 0000000..3213342
--- /dev/null
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | FileCheck %s
+
+func @buffer_size(%arg0: !linalg.buffer<f32>) {
+  %s = linalg.buffer_size %arg0 : !linalg.buffer<f32>
+  return
+}
+// CHECK-LABEL: func @buffer_size(%arg0: !llvm<"{ float*, i64 }">) {
+//       CHECK:   %0 = llvm.extractvalue %arg0[1] : !llvm<"{ float*, i64 }">
+
+func @range(%arg0: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %R = linalg.range %c0:%arg0:%c1 : !linalg.range
+  return
+}
+// CHECK-LABEL: func @range(%arg0: !llvm.i64) {
+//       CHECK:   %0 = llvm.constant(0 : index) : !llvm.i64
+//  CHECK-NEXT:   %1 = llvm.constant(1 : index) : !llvm.i64
+//  CHECK-NEXT:   %2 = llvm.undef : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %3 = llvm.insertvalue %0, %2[0] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %4 = llvm.insertvalue %arg0, %3[1] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %5 = llvm.insertvalue %1, %4[2] : !llvm<"{ i64, i64, i64 }">
+
+func @view(%arg0: !linalg.buffer<f32>, %arg1: !linalg.range) {
+  %0 = linalg.view %arg0[%arg1] : !linalg.view<?xf32>
+  return
+}
+// CHECK-LABEL: func @view(%arg0: !llvm<"{ float*, i64 }">, %arg1: !llvm<"{ i64, i64, i64 }">) {
+//       CHECK:   %0 = llvm.undef : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//  CHECK-NEXT:   %1 = llvm.extractvalue %arg0[0] : !llvm<"{ float*, i64 }">
+//  CHECK-NEXT:   %2 = llvm.insertvalue %1, %0[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//  CHECK-NEXT:   %3 = llvm.constant(0 : index) : !llvm.i64
+//  CHECK-NEXT:   %4 = llvm.insertvalue %3, %2[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//  CHECK-NEXT:   %5 = llvm.constant(1 : index) : !llvm.i64
+//  CHECK-NEXT:   %6 = llvm.extractvalue %arg1[2] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %7 = llvm.mul %5, %6 : !llvm.i64
+//  CHECK-NEXT:   %8 = llvm.insertvalue %7, %4[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//  CHECK-NEXT:   %9 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %10 = llvm.extractvalue %arg1[1] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %11 = llvm.sub %10, %9 : !llvm.i64
+//  CHECK-NEXT:   %12 = llvm.insertvalue %11, %8[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+
+func @slice(%arg0: !linalg.buffer<f32>, %arg1: !linalg.range) {
+  %0 = linalg.view %arg0[%arg1] : !linalg.view<?xf32>
+  %1 = linalg.slice %0[%arg1] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
+  return
+}
+// CHECK-LABEL: func @slice(%arg0: !llvm<"{ float*, i64 }">, %arg1: !llvm<"{ i64, i64, i64 }">) {
+//       CHECK:   %13 = llvm.undef : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//  CHECK-NEXT:   %14 = llvm.extractvalue %12[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//  CHECK-NEXT:   %15 = llvm.insertvalue %14, %13[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//  CHECK-NEXT:   %16 = llvm.extractvalue %12[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//  CHECK-NEXT:   %17 = llvm.extractvalue %12[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//  CHECK-NEXT:   %18 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %19 = llvm.mul %18, %16 : !llvm.i64
+//  CHECK-NEXT:   %20 = llvm.add %17, %19 : !llvm.i64
+//  CHECK-NEXT:   %21 = llvm.insertvalue %20, %15[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//  CHECK-NEXT:   %22 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %23 = llvm.extractvalue %arg1[1] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %24 = llvm.sub %23, %22 : !llvm.i64
+//  CHECK-NEXT:   %25 = llvm.insertvalue %24, %21[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//  CHECK-NEXT:   %26 = llvm.extractvalue %arg1[2] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %27 = llvm.mul %16, %26 : !llvm.i64
+//  CHECK-NEXT:   %28 = llvm.insertvalue %27, %25[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+
+func @linalg_dot(!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
+                 !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
+                 !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">)
+
+func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {
+  linalg.dot(%arg0, %arg1, %arg2) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+  return
+}
+// CHECK-LABEL: func @dot(%arg0: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %arg1: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %arg2: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
+//       CHECK:   llvm.call @linalg_dot(%arg0, %arg1, %arg2) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) -> ()