Pipe Linalg to LLVM via mlir-cpu-runner
authorNicolas Vasilache <ntv@google.com>
Thu, 9 May 2019 19:34:04 +0000 (12:34 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sat, 11 May 2019 02:26:18 +0000 (19:26 -0700)
    This CL adds support for functions in the Linalg dialect to run with mlir-cpu-runner.
    For this purpose, this CL adds BufferAllocOp, BufferDeallocOp, LoadOp and StoreOp to the Linalg dialect as well as their lowering to LLVM. To avoid collisions with mlir::LoadOp/StoreOp (which should really become mlir::affine::LoadOp/StoreOp), the mlir::linalg namespace is added.

    The execution uses a dummy linalg_dot function that just returns for now. In the future a proper library call will be used.

--

PiperOrigin-RevId: 247476061

14 files changed:
mlir/include/mlir/Linalg/IR/LinalgOps.h
mlir/include/mlir/Linalg/IR/LinalgOps.td
mlir/include/mlir/Linalg/IR/LinalgTraits.h
mlir/include/mlir/Linalg/IR/LinalgTypes.h
mlir/include/mlir/Linalg/Passes.h
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Linalg/IR/LinalgTypes.cpp
mlir/lib/Linalg/LinalgRegistration.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/Tiling.cpp
mlir/lib/Linalg/Utils/Utils.cpp
mlir/test/mlir-cpu-runner/simple_linalg.mlir [new file with mode: 0644]
mlir/tools/mlir-cpu-runner/CMakeLists.txt
mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp

index 9472c71..f468b96 100644 (file)
@@ -24,6 +24,7 @@
 #include "mlir/Support/LLVM.h"
 
 namespace mlir {
+namespace linalg {
 
 /// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
 /// upon which a base view can be laid out to give it indexing semantics.
@@ -77,6 +78,35 @@ public:
   }
 };
 
+/// A linalg.LoadOp is the counterpart of load but operating on ViewType
+/// instead of MemRefType.
+///
+/// ```{.mlir}
+///    %0 = linalg.load %V[%c0] : !linalg.view<?xf32>
+/// ```
+class LoadOp
+    : public Op<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
+public:
+  friend Operation;
+  using Op::Op;
+
+  // Hooks to customize the behavior of this op.
+  static llvm::StringRef getOperationName() { return "linalg.load"; }
+  static void build(Builder *b, OperationState *result, Value *view,
+                    ArrayRef<Value *> indices = {});
+  LogicalResult verify();
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+
+  // Op-specific functionality.
+  unsigned getRank() { return getViewType().getRank(); }
+  ViewType getViewType() { return getView()->getType().cast<ViewType>(); }
+  Value *getView() { return getOperand(0); }
+  Operation::operand_range getIndices() {
+    return {operand_begin() + 1, operand_end()};
+  }
+};
+
 /// The "linalg.range" op creates a linalg.range from 3 values of type `index`
 /// that represent the min, max and step values of the range.
 ///
@@ -142,9 +172,8 @@ public:
 ///    !linalg.view<f32>
 /// ```
 class ViewOp;
-class SliceOp : public mlir::Op<SliceOp, mlir::OpTrait::VariadicOperands,
-                                mlir::OpTrait::OneResult,
-                                mlir::OpTrait::HasNoSideEffect> {
+class SliceOp : public Op<SliceOp, OpTrait::VariadicOperands,
+                          OpTrait::OneResult, OpTrait::HasNoSideEffect> {
   enum { FirstIndexingOperand = 1 };
 
 public:
@@ -153,33 +182,60 @@ public:
 
   // Hooks to customize the behavior of this op.
   static llvm::StringRef getOperationName() { return "linalg.slice"; }
-  static void build(mlir::Builder *b, mlir::OperationState *result,
-                    mlir::Value *base, llvm::ArrayRef<mlir::Value *> indexings);
-  mlir::LogicalResult verify();
-  static ParseResult parse(mlir::OpAsmParser *parser,
-                           mlir::OperationState *result);
-  void print(mlir::OpAsmPrinter *p);
+  static void build(Builder *b, OperationState *result, Value *base,
+                    llvm::ArrayRef<Value *> indexings);
+  LogicalResult verify();
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
 
   // Op-specific functionality.
   unsigned getRank() { return getViewType().getRank(); }
-  mlir::Type getElementType() { return getViewType().getElementType(); }
+  Type getElementType() { return getViewType().getElementType(); }
   ViewType getViewType() { return getType().cast<ViewType>(); }
   Value *getBaseView() { return getOperand(0); }
   ViewOp getBaseViewOp();
   ViewType getBaseViewType();
   unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
   // Get the underlying indexing at a given rank.
-  mlir::Value *getIndexing(unsigned rank) {
-    return *(getIndexings().begin() + rank);
-  }
+  Value *getIndexing(unsigned rank) { return *(getIndexings().begin() + rank); }
   // Get all the indexings in this view.
-  mlir::Operation::operand_range getIndexings() {
+  Operation::operand_range getIndexings() {
     return {operand_begin() + SliceOp::FirstIndexingOperand, operand_end()};
   }
   // Get the subset of indexings that are of RangeType.
   SmallVector<Value *, 8> getRanges();
 };
 
+/// A linalg.StoreOp is the counterpart of affine.store but operating on
+/// ViewType instead of MemRefType.
+///
+/// ```{.mlir}
+///    linalg.store %f, %V[%c0] : !linalg.view<?xf32>
+/// ```
+class StoreOp
+    : public Op<StoreOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+public:
+  friend Operation;
+  using Op::Op;
+
+  // Hooks to customize the behavior of this op.
+  static llvm::StringRef getOperationName() { return "linalg.store"; }
+  static void build(Builder *b, OperationState *result, Value *valueToStore,
+                    Value *view, ArrayRef<Value *> indices = {});
+  LogicalResult verify();
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+
+  // Op-specific functionality.
+  unsigned getRank() { return getViewType().getRank(); }
+  ViewType getViewType() { return getView()->getType().cast<ViewType>(); }
+  Value *getValueToStore() { return getOperand(0); }
+  Value *getView() { return getOperand(1); }
+  Operation::operand_range getIndices() {
+    return {operand_begin() + 2, operand_end()};
+  }
+};
+
 /// The "linalg.view" op produces a linalg.view which is a multi-dimensional
 /// range abstraction on top of an underlying linalg.buffer. This gives an
 /// indexing structure to an otherwise non-indexable linalg.buffer.
@@ -193,9 +249,8 @@ public:
 ///    %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
 ///    %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
 /// ```
-class ViewOp : public mlir::Op<ViewOp, mlir::OpTrait::VariadicOperands,
-                               mlir::OpTrait::OneResult,
-                               mlir::OpTrait::HasNoSideEffect> {
+class ViewOp : public Op<ViewOp, OpTrait::VariadicOperands, OpTrait::OneResult,
+                         OpTrait::HasNoSideEffect> {
   enum { FirstIndexingOperand = 1 };
 
 public:
@@ -204,25 +259,21 @@ public:
 
   // Hooks to customize the behavior of this op.
   static llvm::StringRef getOperationName() { return "linalg.view"; }
-  static void build(mlir::Builder *b, mlir::OperationState *result,
-                    mlir::Value *buffer,
-                    llvm::ArrayRef<mlir::Value *> indexings);
-  mlir::LogicalResult verify();
-  static ParseResult parse(mlir::OpAsmParser *parser,
-                           mlir::OperationState *result);
-  void print(mlir::OpAsmPrinter *p);
+  static void build(Builder *b, OperationState *result, Value *buffer,
+                    llvm::ArrayRef<Value *> indexings);
+  LogicalResult verify();
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
 
   // Op-specific functionality.
   unsigned getRank() { return getViewType().getRank(); }
-  mlir::Type getElementType() { return getViewType().getElementType(); }
+  Type getElementType() { return getViewType().getElementType(); }
   ViewType getViewType() { return getType().cast<ViewType>(); }
-  mlir::Value *getSupportingBuffer() { return getOperand(0); }
+  Value *getSupportingBuffer() { return getOperand(0); }
   // Get the underlying indexing at a given rank.
-  mlir::Value *getIndexing(unsigned rank) {
-    return *(getIndexings().begin() + rank);
-  }
+  Value *getIndexing(unsigned rank) { return *(getIndexings().begin() + rank); }
   // Get all the indexings in this view.
-  mlir::Operation::operand_range getIndexings() {
+  Operation::operand_range getIndexings() {
     return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
   }
 };
@@ -245,9 +296,10 @@ public:
 ///    )
 /// ```
 ///
-/// Only permutation maps are currently supported. 
+/// Only permutation maps are currently supported.
 SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
 
+} // namespace linalg
 } // namespace mlir
 
 #endif // MLIR_LINALG_LINALGOPS_H_
index fd07f6c..2aa1e43 100644 (file)
@@ -39,12 +39,12 @@ def Buffer : Type<LinalgIsBufferTypePred, "buffer">;
 def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
 def View : Type<LinalgIsViewTypePred, "view">;
 
-class ParametricNativeOpTrait<string prop, string parameters> :
-  NativeOpTrait<prop # parameters>
+class LinalgParametricNativeOpTrait<string prop, string parameters> :
+  NativeOpTrait<"linalg::" # prop # parameters>
 {}
 
-class ParametricIntNativeOpTrait<string prop, list<int> parameters> :
-  ParametricNativeOpTrait<
+class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
+  LinalgParametricNativeOpTrait<
     prop,
     !strconcat("<",
                !cast<string>(!head(parameters)),
@@ -60,7 +60,7 @@ class ParametricIntNativeOpTrait<string prop, list<int> parameters> :
 // to have a specified number of inputs and outputs, all passed as operands.
 // See Linalg/LinalgTraits.h for implementation details an usage.
 class NInputsAndOutputs<int n_ins, int n_outs> :
-  ParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
+  LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
 {}
 
 // The linalg `NLoopTypes` trait provides the API for ops that are known to have
@@ -68,14 +68,14 @@ class NInputsAndOutputs<int n_ins, int n_outs> :
 // loops.
 // See Linalg/LinalgTraits.h for implementation details an usage.
 class NLoopTypes<int n_par, int n_red, int n_win> :
-ParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
+LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
 {}
 
 // The linalg `ViewRanks` trait the API for ops that are known to have a
 // specified list of view ranks.
 // See Linalg/LinalgTraits.h for implementation details an usage.
 class ViewRanks<list<int> ranks> :
-ParametricIntNativeOpTrait<"ViewRanks", ranks>
+LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
 {}
 
 // Base Tablegen class for Linalg ops.
index 4a7428b..0d557fb 100644 (file)
@@ -24,6 +24,7 @@
 
 namespace mlir {
 namespace OpTrait {
+namespace linalg {
 
 /// This class provides the API for ops that are known to have a specified
 /// number of inputs and outputs, all passed as operands. This is used as a
@@ -44,16 +45,20 @@ public:
     Value *getOutput(unsigned i) {
       return this->getOperand(getNumInputs() + i);
     }
-    ViewType getInputViewType(unsigned i) {
-      return this->getOperand(i)->getType().template cast<ViewType>();
+    mlir::linalg::ViewType getInputViewType(unsigned i) {
+      return this->getOperand(i)
+          ->getType()
+          .template cast<mlir::linalg::ViewType>();
     }
-    ViewType getOutputViewType(unsigned i) {
+    mlir::linalg::ViewType getOutputViewType(unsigned i) {
       return this->getOperand(getNumInputs() + i)
           ->getType()
-          .template cast<ViewType>();
+          .template cast<mlir::linalg::ViewType>();
     }
-    ViewType getViewType(unsigned i) {
-      return this->getOperand(i)->getType().template cast<ViewType>();
+    mlir::linalg::ViewType getViewType(unsigned i) {
+      return this->getOperand(i)
+          ->getType()
+          .template cast<mlir::linalg::ViewType>();
     }
     static LogicalResult verifyTrait(Operation *op) {
       return OpTrait::impl::verifyAtLeastNOperands(op, NInputs + NOutputs);
@@ -98,7 +103,8 @@ public:
       if (op->getNumOperands() != ranks.size())
         return op->emitError("expected " + Twine(ranks.size()) + " operands");
       for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) {
-        auto viewType = op->getOperand(i)->getType().dyn_cast<ViewType>();
+        auto viewType =
+            op->getOperand(i)->getType().dyn_cast<mlir::linalg::ViewType>();
         if (!viewType)
           return op->emitOpError("operand " + Twine(i) +
                                  " must have view type ");
@@ -111,6 +117,7 @@ public:
   };
 };
 
+} // namespace linalg
 } // namespace OpTrait
 } // namespace mlir
 
index 64f86d4..38ef3cb 100644 (file)
@@ -24,6 +24,7 @@
 namespace mlir {
 class MLIRContext;
 
+namespace linalg {
 enum LinalgTypes {
   Buffer = Type::FIRST_LINALG_TYPE,
   Range,
@@ -110,6 +111,7 @@ public:
   unsigned getRank();
 };
 
+} // namespace linalg
 } // namespace mlir
 
 #endif // MLIR_LINALG_LINALGTYPES_H_
index 7ccb788..931de90 100644 (file)
@@ -30,6 +30,8 @@ class ModulePassBase;
 
 mlir::ModulePassBase *
 createLinalgTilingPass(llvm::ArrayRef<int64_t> tileSizes = {});
+
+mlir::ModulePassBase *createLowerLinalgToLLVMPass();
 } // namespace mlir
 
 #endif // MLIR_LINALG_PASSES_H_
index 356a906..6998da5 100644 (file)
 #include "mlir/Support/STLExtras.h"
 
 using namespace mlir;
+using namespace mlir::linalg;
 
 //////////////////////////////////////////////////////////////////////////////
 // BufferAllocOp
 //////////////////////////////////////////////////////////////////////////////
-void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type,
-                                Value *size) {
+void mlir::linalg::BufferAllocOp::build(Builder *b, OperationState *result,
+                                        Type type, Value *size) {
   result->addOperands({size});
   result->addTypes(type);
 }
 
-mlir::LogicalResult mlir::BufferAllocOp::verify() {
+LogicalResult mlir::linalg::BufferAllocOp::verify() {
   if (!size() || !size()->getType().isa<IndexType>())
     return emitOpError("first operand should be of type index");
   if (!VectorType::isValidElementType(getElementType()) &&
       !getElementType().isa<VectorType>())
     return emitOpError("unsupported buffer element type");
-  return mlir::success();
+  return success();
 }
 
 // A BufferAllocOp prints as:
@@ -54,21 +55,21 @@ mlir::LogicalResult mlir::BufferAllocOp::verify() {
 // ```{.mlir}
 //   linalg.alloc %0 : !linalg.buffer<f32>
 // ```
-void mlir::BufferAllocOp::print(OpAsmPrinter *p) {
+void mlir::linalg::BufferAllocOp::print(OpAsmPrinter *p) {
   *p << getOperationName() << " " << *size() << " : " << getType();
 }
 
-ParseResult mlir::BufferAllocOp::parse(OpAsmParser *parser,
-                                       OperationState *result) {
+ParseResult mlir::linalg::BufferAllocOp::parse(OpAsmParser *parser,
+                                               OperationState *result) {
   OpAsmParser::OperandType sizeInfo;
   BufferType bufferType;
   auto indexTy = parser->getBuilder().getIndexType();
   if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
     return failure();
   if (bufferType.getElementType() != parser->getBuilder().getF32Type())
-    return parser->emitError(
-        parser->getNameLoc(),
-        "Only buffer<f32> supported until mlir::Parser pieces are exposed");
+    return parser->emitError(parser->getNameLoc(),
+                             "Only buffer<f32> supported until "
+                             "mlir::linalg::Parser pieces are exposed");
   return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
                  parser->addTypeToList(bufferType, result->types));
 }
@@ -76,15 +77,15 @@ ParseResult mlir::BufferAllocOp::parse(OpAsmParser *parser,
 //////////////////////////////////////////////////////////////////////////////
 // BufferDeallocOp
 //////////////////////////////////////////////////////////////////////////////
-void mlir::BufferDeallocOp::build(Builder *b, OperationState *result,
-                                  Value *buffer) {
+void mlir::linalg::BufferDeallocOp::build(Builder *b, OperationState *result,
+                                          Value *buffer) {
   result->addOperands({buffer});
 }
 
-mlir::LogicalResult mlir::BufferDeallocOp::verify() {
+LogicalResult mlir::linalg::BufferDeallocOp::verify() {
   if (!getBuffer()->getType())
     return emitOpError("first operand should be of type buffer");
-  return mlir::success();
+  return success();
 }
 
 // A BufferDeallocOp prints as:
@@ -92,36 +93,99 @@ mlir::LogicalResult mlir::BufferDeallocOp::verify() {
 // ```{.mlir}
 //   linalg.dealloc %0 : !linalg.buffer<f32>
 // ```
-void mlir::BufferDeallocOp::print(OpAsmPrinter *p) {
+void mlir::linalg::BufferDeallocOp::print(OpAsmPrinter *p) {
   *p << getOperationName() << " " << *getBuffer() << " : " << getBufferType();
 }
 
-ParseResult mlir::BufferDeallocOp::parse(OpAsmParser *parser,
-                                         OperationState *result) {
+ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser,
+                                                 OperationState *result) {
   OpAsmParser::OperandType sizeInfo;
   BufferType bufferType;
   return failure(
       parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
       parser->resolveOperands(sizeInfo, bufferType, result->operands));
 }
+
+////////////////////////////////////////////////////////////////////////////////
+// LoadOp.
+////////////////////////////////////////////////////////////////////////////////
+void mlir::linalg::LoadOp::build(Builder *b, OperationState *result,
+                                 Value *view, ArrayRef<Value *> indices) {
+  auto viewType = view->getType().cast<ViewType>();
+  result->addOperands(view);
+  result->addOperands(indices);
+  result->addTypes(viewType.getElementType());
+}
+
+// A LoadOp prints as:
+//
+// ```{.mlir}
+//    %0 = linalg.load %V[%c0] : !linalg.view<?xf32>
+// ```
+void mlir::linalg::LoadOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *getView() << '[';
+  p->printOperands(getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : " << getViewType();
+}
+
+ParseResult mlir::linalg::LoadOp::parse(OpAsmParser *parser,
+                                        OperationState *result) {
+  OpAsmParser::OperandType viewInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+  ViewType type;
+
+  auto affineIntTy = parser->getBuilder().getIndexType();
+  return failure(
+      parser->parseOperand(viewInfo) ||
+      parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperand(viewInfo, type, result->operands) ||
+      parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
+      parser->addTypeToList(type.getElementType(), result->types));
+}
+
+LogicalResult mlir::linalg::LoadOp::verify() {
+  if (getNumOperands() == 0)
+    return emitOpError("expected a view to load from");
+
+  auto viewType = getView()->getType().dyn_cast<ViewType>();
+  if (!viewType)
+    return emitOpError("first operand must be a view");
+
+  if (getType() != viewType.getElementType())
+    return emitOpError("result type must match element type of the view");
+
+  if (getRank() != getNumOperands() - 1)
+    return emitOpError("incorrect number of indices for load");
+
+  for (auto *idx : getIndices())
+    if (!idx->getType().isIndex())
+      return emitOpError("index to load must have 'index' type");
+
+  return success();
+}
+
 //////////////////////////////////////////////////////////////////////////////
 // RangeOp
 //////////////////////////////////////////////////////////////////////////////
-void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min,
-                          Value *max, Value *step) {
+void mlir::linalg::RangeOp::build(Builder *b, OperationState *result,
+                                  Value *min, Value *max, Value *step) {
   result->addOperands({min, max, step});
   result->addTypes({RangeType::get(b->getContext())});
 }
 
 // Verification is simply that a RangeOp takes 3 index ssa-value.
-mlir::LogicalResult mlir::RangeOp::verify() {
+LogicalResult mlir::linalg::RangeOp::verify() {
   if (!min() || !min()->getType().isa<IndexType>())
     return emitOpError("first operand should be of type index");
   if (!max() || !max()->getType().isa<IndexType>())
     return emitOpError("second operand should be of type index");
   if (!step() || !step()->getType().isa<IndexType>())
     return emitOpError("third operand should be of type index");
-  return mlir::success();
+  return success();
 }
 
 // A RangeOp prints as:
@@ -129,12 +193,13 @@ mlir::LogicalResult mlir::RangeOp::verify() {
 // ```{.mlir}
 //   linalg.range %0:%1:%2 : !linalg.range
 // ```
-void mlir::RangeOp::print(OpAsmPrinter *p) {
+void mlir::linalg::RangeOp::print(OpAsmPrinter *p) {
   *p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step()
      << " : " << getType();
 }
 
-ParseResult mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
+ParseResult mlir::linalg::RangeOp::parse(OpAsmParser *parser,
+                                         OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
   RangeType type;
   auto affineIntTy = parser->getBuilder().getIndexType();
@@ -149,8 +214,8 @@ ParseResult mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
 //////////////////////////////////////////////////////////////////////////////
 // SliceOp
 //////////////////////////////////////////////////////////////////////////////
-void mlir::SliceOp::build(Builder *b, OperationState *result, Value *base,
-                          ArrayRef<Value *> indexings) {
+void mlir::linalg::SliceOp::build(Builder *b, OperationState *result,
+                                  Value *base, ArrayRef<Value *> indexings) {
   result->addOperands({base});
   result->addOperands(indexings);
 
@@ -163,7 +228,7 @@ void mlir::SliceOp::build(Builder *b, OperationState *result, Value *base,
   result->addTypes({ViewType::get(b->getContext(), elementType, rank)});
 }
 
-LogicalResult mlir::SliceOp::verify() {
+LogicalResult mlir::linalg::SliceOp::verify() {
   if (llvm::empty(getOperands()))
     return emitOpError(
         "requires at least a view operand followed by 'rank' indices");
@@ -193,7 +258,8 @@ LogicalResult mlir::SliceOp::verify() {
   return success();
 }
 
-ParseResult mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
+ParseResult mlir::linalg::SliceOp::parse(OpAsmParser *parser,
+                                         OperationState *result) {
   OpAsmParser::OperandType baseInfo;
   SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
   SmallVector<Type, 8> types;
@@ -241,11 +307,11 @@ ParseResult mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
 //
 // Where %0 is an ssa-value holding a view created from a buffer, %1 and %2 are
 // ssa-value each holding a range.
-void mlir::SliceOp::print(OpAsmPrinter *p) {
+void mlir::linalg::SliceOp::print(OpAsmPrinter *p) {
   *p << getOperationName() << " " << *getBaseView() << "[";
   interleave(
-      getIndexings().begin(), getIndexings().end(),
-      [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
+      getIndexings().begin(), getIndexings().end(), [p](Value *v) { *p << *v; },
+      [p]() { *p << ", "; });
   *p << "] : " << getBaseViewType();
   for (auto indexing : getIndexings()) {
     *p << ", " << indexing->getType();
@@ -253,15 +319,15 @@ void mlir::SliceOp::print(OpAsmPrinter *p) {
   *p << ", " << getType();
 }
 
-ViewOp mlir::SliceOp::getBaseViewOp() {
+ViewOp mlir::linalg::SliceOp::getBaseViewOp() {
   return getOperand(0)->getDefiningOp()->cast<ViewOp>();
 }
 
-ViewType mlir::SliceOp::getBaseViewType() {
+ViewType mlir::linalg::SliceOp::getBaseViewType() {
   return getBaseViewOp().getType().cast<ViewType>();
 }
 
-SmallVector<Value *, 8> mlir::SliceOp::getRanges() {
+SmallVector<Value *, 8> mlir::linalg::SliceOp::getRanges() {
   llvm::SmallVector<Value *, 8> res;
   for (auto *operand : getIndexings()) {
     if (!operand->getType().isa<IndexType>()) {
@@ -271,11 +337,79 @@ SmallVector<Value *, 8> mlir::SliceOp::getRanges() {
   return res;
 }
 
+////////////////////////////////////////////////////////////////////////////////
+// StoreOp.
+////////////////////////////////////////////////////////////////////////////////
+void mlir::linalg::StoreOp::build(Builder *b, OperationState *result,
+                                  Value *valueToStore, Value *view,
+                                  ArrayRef<Value *> indices) {
+  result->addOperands(valueToStore);
+  result->addOperands(view);
+  result->addOperands(indices);
+}
+
+// A StoreOp prints as:
+//
+// ```{.mlir}
+//    linalg.store %f, %V[%c0] : !linalg.view<?xf32>
+// ```
+void mlir::linalg::StoreOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *getValueToStore();
+  *p << ", " << *getView() << '[';
+  p->printOperands(getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : " << getViewType();
+}
+
+ParseResult mlir::linalg::StoreOp::parse(OpAsmParser *parser,
+                                         OperationState *result) {
+  OpAsmParser::OperandType storeValueInfo;
+  OpAsmParser::OperandType viewInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+  ViewType viewType;
+
+  auto affineIntTy = parser->getBuilder().getIndexType();
+  return failure(
+      parser->parseOperand(storeValueInfo) || parser->parseComma() ||
+      parser->parseOperand(viewInfo) ||
+      parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(viewType) ||
+      parser->resolveOperand(storeValueInfo, viewType.getElementType(),
+                             result->operands) ||
+      parser->resolveOperand(viewInfo, viewType, result->operands) ||
+      parser->resolveOperands(indexInfo, affineIntTy, result->operands));
+}
+
+LogicalResult mlir::linalg::StoreOp::verify() {
+  if (getNumOperands() < 2)
+    return emitOpError("expected a value to store and a view");
+
+  // Second operand is a memref type.
+  auto viewType = getView()->getType().dyn_cast<ViewType>();
+  if (!viewType)
+    return emitOpError("second operand must be a view");
+
+  // First operand must have same type as memref element type.
+  if (getValueToStore()->getType() != viewType.getElementType())
+    return emitOpError("first operand must have same element type as the view");
+
+  if (getNumOperands() != 2 + viewType.getRank())
+    return emitOpError("store index operand count not equal to view rank");
+
+  for (auto *idx : getIndices())
+    if (!idx->getType().isIndex())
+      return emitOpError("index to store must have 'index' type");
+
+  return success();
+}
+
 //////////////////////////////////////////////////////////////////////////////
 // ViewOp
 //////////////////////////////////////////////////////////////////////////////
-void mlir::ViewOp::build(Builder *b, OperationState *result, Value *buffer,
-                         ArrayRef<Value *> indexings) {
+void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
+                                 Value *buffer, ArrayRef<Value *> indexings) {
   BufferType bufferType = buffer->getType().cast<BufferType>();
   result->addOperands({buffer});
   result->addOperands(indexings);
@@ -289,7 +423,7 @@ void mlir::ViewOp::build(Builder *b, OperationState *result, Value *buffer,
       {ViewType::get(b->getContext(), elementType, indexings.size())});
 }
 
-LogicalResult mlir::ViewOp::verify() {
+LogicalResult mlir::linalg::ViewOp::verify() {
   if (llvm::empty(getOperands()))
     return emitOpError(
         "requires at least a buffer operand followed by indexings");
@@ -309,7 +443,8 @@ LogicalResult mlir::ViewOp::verify() {
   return success();
 }
 
-ParseResult mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
+ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
+                                        OperationState *result) {
   OpAsmParser::OperandType bufferInfo;
   SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
   Type type;
@@ -345,28 +480,30 @@ ParseResult mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
 //
 // Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
 // holding a range.
-void mlir::ViewOp::print(OpAsmPrinter *p) {
+void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
   *p << getOperationName() << " " << *getSupportingBuffer() << "[";
   interleave(
-      getIndexings().begin(), getIndexings().end(),
-      [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
+      getIndexings().begin(), getIndexings().end(), [&](Value *v) { *p << *v; },
+      [&]() { *p << ", "; });
   *p << "] : " << getType();
 }
 
 namespace mlir {
+namespace linalg {
 namespace impl {
-void printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op);
+void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op);
 ParseResult parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result);
-void printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op);
+void printBufferSizeOp(OpAsmPrinter *p, Operation *op);
 ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
 } // namespace impl
+} // namespace linalg
 
 /// Buffer size prints as:
 ///
 /// ``` {.mlir}
 ///    %0 = linalg.buffer_size %arg0 : !linalg.buffer<f32>
 /// ```
-void mlir::impl::printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op) {
+void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) {
   assert(op->getAbstractOperation() && "unregistered operation");
   *p << op->cast<BufferSizeOp>().getOperationName() << " "
      << *op->getOperand(0);
@@ -374,8 +511,8 @@ void mlir::impl::printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op) {
   *p << " : " << op->getOperand(0)->getType();
 }
 
-ParseResult mlir::impl::parseBufferSizeOp(OpAsmParser *parser,
-                                          OperationState *result) {
+ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
+                                                  OperationState *result) {
   OpAsmParser::OperandType op;
   Type type;
   return failure(parser->parseOperand(op) ||
@@ -405,20 +542,20 @@ ParseResult mlir::impl::parseBufferSizeOp(OpAsmParser *parser,
 // ```
 //
 // Where %0, %1 and %2 are ssa-values of type ViewType.
-void mlir::impl::printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op) {
+void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
   assert(op->getAbstractOperation() && "unregistered operation");
   *p << op->getName().getStringRef() << "(";
   interleave(
       op->getOperands().begin(), op->getOperands().end(),
-      [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
+      [&](Value *v) { *p << *v; }, [&]() { *p << ", "; });
   *p << ") : ";
   interleave(
       op->getOperands().begin(), op->getOperands().end(),
-      [&](mlir::Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
+      [&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
 }
 
-ParseResult mlir::impl::parseLinalgLibraryOp(OpAsmParser *parser,
-                                             OperationState *result) {
+ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser,
+                                                     OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 3> ops;
   SmallVector<Type, 3> types;
   return failure(
@@ -431,7 +568,7 @@ ParseResult mlir::impl::parseLinalgLibraryOp(OpAsmParser *parser,
 
 // Ideally this should all be Tablegen'd but there is no good story for
 // AffineMap for now.
-SmallVector<AffineMap, 4> mlir::loopToOperandRangesMaps(Operation *op) {
+SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
   MLIRContext *context = op->getContext();
   auto i = getAffineDimExpr(0, context);
   auto j = getAffineDimExpr(1, context);
index 556d5d1..19105e8 100644 (file)
 #include "mlir/Support/LLVM.h"
 
 using namespace mlir;
+using namespace mlir::linalg;
 
-mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
+mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
     : Dialect("linalg", context) {
   addTypes<BufferType, RangeType, ViewType>();
-  addOperations<BufferAllocOp, BufferDeallocOp, RangeOp, SliceOp, ViewOp>();
+  addOperations<BufferAllocOp, BufferDeallocOp, LoadOp, RangeOp, StoreOp,
+                SliceOp, ViewOp>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Linalg/IR/LinalgOps.cpp.inc"
       >();
 }
 
-struct mlir::BufferTypeStorage : public mlir::TypeStorage {
+struct mlir::linalg::BufferTypeStorage : public TypeStorage {
   /// Underlying Key type to transport the payload needed to construct a custom
   /// type in a generic way.
   struct Key {
@@ -71,13 +73,17 @@ private:
   Type elementType;
 };
 
-BufferType mlir::BufferType::get(MLIRContext *context, Type elementType) {
+BufferType mlir::linalg::BufferType::get(MLIRContext *context,
+                                         Type elementType) {
   return Base::get(context, LinalgTypes::Buffer, elementType);
 }
 
-Type mlir::BufferType::getElementType() { return getImpl()->getElementType(); }
+Type mlir::linalg::BufferType::getElementType() {
+  return getImpl()->getElementType();
+}
 
-Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const {
+Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
+                                            Location loc) const {
   MLIRContext *context = getContext();
   if (spec == "range")
     return RangeType::get(getContext());
@@ -97,7 +103,7 @@ Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const {
   return (context->emitError(loc, "unknown Linalg type: " + spec), Type());
 }
 
-struct mlir::ViewTypeStorage : public mlir::TypeStorage {
+struct mlir::linalg::ViewTypeStorage : public TypeStorage {
   /// Underlying Key type to transport the payload needed to construct a custom
   /// type in a generic way.
   struct Key {
@@ -136,14 +142,16 @@ private:
   unsigned rank;
 };
 
-ViewType mlir::ViewType::get(MLIRContext *context, Type elementType,
-                             unsigned rank) {
+ViewType mlir::linalg::ViewType::get(MLIRContext *context, Type elementType,
+                                     unsigned rank) {
   return Base::get(context, LinalgTypes::View, elementType, rank);
 }
 
-Type mlir::ViewType::getElementType() { return getImpl()->getElementType(); }
+Type mlir::linalg::ViewType::getElementType() {
+  return getImpl()->getElementType();
+}
 
-unsigned mlir::ViewType::getRank() { return getImpl()->getRank(); }
+unsigned mlir::linalg::ViewType::getRank() { return getImpl()->getRank(); }
 
 /// BufferType prints as "buffer<element_type>".
 static void print(BufferType bt, raw_ostream &os) {
@@ -166,7 +174,7 @@ static void print(RangeType rt, raw_ostream &os) { os << "range"; }
 /// ```
 ///
 /// for 0-D views (a.k.a pointer to a scalar value).
-static void print(mlir::ViewType rt, raw_ostream &os) {
+static void print(mlir::linalg::ViewType rt, raw_ostream &os) {
   os << "view<";
   for (unsigned i = 0, e = rt.getRank(); i < e; ++i) {
     os << "?x";
@@ -175,7 +183,7 @@ static void print(mlir::ViewType rt, raw_ostream &os) {
   os << ">";
 }
 
-void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const {
+void mlir::linalg::LinalgDialect::printType(Type type, raw_ostream &os) const {
   switch (type.getKind()) {
   default:
     llvm_unreachable("Unhandled Linalg type");
index 816b565..cf5bd8f 100644 (file)
@@ -19,6 +19,7 @@
 #include "mlir/Linalg/IR/LinalgTypes.h"
 
 using namespace mlir;
+using namespace mlir::linalg;
 
 // Static initialization for LinalgOps dialect registration.
 static DialectRegistration<LinalgDialect> LinalgOps;
index 108da6c..90111a8 100644 (file)
@@ -30,6 +30,7 @@
 #include "mlir/LLVMIR/Transforms.h"
 #include "mlir/Linalg/IR/LinalgOps.h"
 #include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Linalg/Passes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
@@ -45,6 +46,7 @@ using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
 using namespace mlir::LLVM;
+using namespace mlir::linalg;
 
 using undef = ValueBuilder<mlir::LLVM::UndefOp>;
 using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
@@ -53,6 +55,11 @@ using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
 using add = ValueBuilder<mlir::LLVM::AddOp>;
 using sub = ValueBuilder<mlir::LLVM::SubOp>;
 using mul = ValueBuilder<mlir::LLVM::MulOp>;
+using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
+using call = OperationBuilder<mlir::LLVM::CallOp>;
+using gep = ValueBuilder<mlir::LLVM::GEPOp>;
+using llvm_load = ValueBuilder<LLVM::LoadOp>;
+using llvm_store = OperationBuilder<LLVM::StoreOp>;
 
 template <typename T>
 static llvm::Type *getPtrToElementType(T containerType,
@@ -85,8 +92,8 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
   //   Elem *ptr;
   //   int64_t size;
   // };
-  if (auto bufferTy = t.dyn_cast<BufferType>()) {
-    auto *ptrTy = getPtrToElementType(bufferTy, lowering);
+  if (auto bufferType = t.dyn_cast<BufferType>()) {
+    auto *ptrTy = getPtrToElementType(bufferType, lowering);
     auto *structTy = llvm::StructType::get(ptrTy, int64Ty);
     return LLVMType::get(context, structTy);
   }
@@ -98,7 +105,7 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
   //   int64_t max;
   //   int64_t step;
   // };
-  if (auto rangeTy = t.dyn_cast<RangeType>()) {
+  if (t.isa<RangeType>()) {
     auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
     return LLVMType::get(context, structTy);
   }
@@ -126,9 +133,9 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
   //   int64_t sizes[Rank];
   //   int64_t strides[Rank];
   // };
-  if (auto viewTy = t.dyn_cast<ViewType>()) {
-    auto *ptrTy = getPtrToElementType(viewTy, lowering);
-    auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank());
+  if (auto viewType = t.dyn_cast<ViewType>()) {
+    auto *ptrTy = getPtrToElementType(viewType, lowering);
+    auto *arrayTy = llvm::ArrayType::get(int64Ty, viewType.getRank());
     auto *structTy = llvm::StructType::get(ptrTy, int64Ty, arrayTy, arrayTy);
     return LLVMType::get(context, structTy);
   }
@@ -147,6 +154,106 @@ static ArrayAttr makePositionAttr(FuncBuilder &builder,
   return builder.getArrayAttr(attrs);
 }
 
+// BufferAllocOp creates a new `index` value.
+class BufferAllocOpConversion : public LLVMOpLowering {
+public:
+  explicit BufferAllocOpConversion(MLIRContext *context,
+                                   LLVMLowering &lowering_)
+      : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto indexType = IndexType::get(op->getContext());
+    auto voidPtrTy = LLVM::LLVMType::get(
+        op->getContext(),
+        lowering.convertType(IntegerType::get(8, op->getContext()))
+            .cast<LLVM::LLVMType>()
+            .getUnderlyingType()
+            ->getPointerTo());
+    auto int64Ty = lowering.convertType(operands[0]->getType());
+    // Insert the `malloc` declaration if it is not already present.
+    Function *mallocFunc =
+        op->getFunction()->getModule()->getNamedFunction("malloc");
+    if (!mallocFunc) {
+      auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
+      mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
+      op->getFunction()->getModule()->getFunctions().push_back(mallocFunc);
+    }
+
+    // Get MLIR types for injecting element pointer.
+    auto allocOp = op->cast<BufferAllocOp>();
+    auto elementType = allocOp.getElementType();
+    uint64_t elementSize = 0;
+    if (auto vectorType = elementType.dyn_cast<VectorType>())
+      elementSize = vectorType.getNumElements() *
+                    llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
+    else
+      elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
+    auto elementPtrType = rewriter.getType<LLVMType>(getPtrToElementType(
+        allocOp.getResult()->getType().cast<BufferType>(), lowering));
+    auto bufferDescriptorType =
+        convertLinalgType(allocOp.getResult()->getType(), lowering);
+
+    // Emit IR for creating a new buffer descriptor with an underlying malloc.
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    Value *size = operands[0];
+    Value *allocSize =
+        mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
+    Value *allocated =
+        call(voidPtrTy, rewriter.getFunctionAttr(mallocFunc), allocSize)
+            .getOperation()
+            ->getResult(0);
+    allocated = bitcast(elementPtrType, allocated);
+    Value *desc = undef(bufferDescriptorType);
+    desc = insertvalue(bufferDescriptorType, desc, allocated,
+                       makePositionAttr(rewriter, 0));
+    desc = insertvalue(bufferDescriptorType, desc, size,
+                       makePositionAttr(rewriter, 1));
+    return {desc};
+  }
+};
+
+// BufferDeallocOp creates a new `index` value.
+class BufferDeallocOpConversion : public LLVMOpLowering {
+public:
+  explicit BufferDeallocOpConversion(MLIRContext *context,
+                                     LLVMLowering &lowering_)
+      : LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
+                       lowering_) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto voidPtrTy = LLVM::LLVMType::get(
+        op->getContext(),
+        lowering.convertType(IntegerType::get(8, op->getContext()))
+            .cast<LLVM::LLVMType>()
+            .getUnderlyingType()
+            ->getPointerTo());
+    // Insert the `free` declaration if it is not already present.
+    Function *freeFunc =
+        op->getFunction()->getModule()->getNamedFunction("free");
+    if (!freeFunc) {
+      auto freeType = rewriter.getFunctionType(voidPtrTy, {});
+      freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
+      op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
+    }
+
+    // Get MLIR types for extracting element pointer.
+    auto deallocOp = op->cast<BufferDeallocOp>();
+    auto elementPtrTy = rewriter.getType<LLVMType>(getPtrToElementType(
+        deallocOp.getOperand()->getType().cast<BufferType>(), lowering));
+
+    // Emit MLIR for buffer_dealloc.
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    Value *casted =
+        bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
+                                        makePositionAttr(rewriter, 0)));
+    call(ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
+
+    return {};
+  }
+};
+
 // BufferSizeOp creates a new `index` value.
 class BufferSizeOpConversion : public LLVMOpLowering {
 public:
@@ -155,10 +262,62 @@ public:
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto bufferSizeType = lowering.convertType(operands[0]->getType());
+    auto int64Ty = lowering.convertType(operands[0]->getType());
     edsc::ScopedContext context(rewriter, op->getLoc());
-    return {extractvalue(bufferSizeType, operands[0],
-                         makePositionAttr(rewriter, 1))};
+    return {extractvalue(int64Ty, operands[0], makePositionAttr(rewriter, 1))};
+  }
+};
+
+namespace {
+// Common functionality for Linalg LoadOp and StoreOp conversion to the
+// LLVM IR Dialect.
+template <typename Op> class LoadStoreOpConversion : public LLVMOpLowering {
+public:
+  explicit LoadStoreOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+      : LLVMOpLowering(Op::getOperationName(), context, lowering_) {}
+  using Base = LoadStoreOpConversion<Op>;
+
+  // Compute the pointer to an element of the buffer underlying the view given
+  // current view indices.  Use the base offset and strides stored in the view
+  // descriptor to emit IR iteratively computing the actual offset, followed by
+  // a getelementptr. This must be called under an edsc::ScopedContext.
+  Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
+                       ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
+    auto loadOp = op->cast<Op>();
+    auto elementTy = rewriter.getType<LLVMType>(
+        getPtrToElementType(loadOp.getViewType(), lowering));
+    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
+    auto pos = [&rewriter](ArrayRef<int> values) {
+      return makePositionAttr(rewriter, values);
+    };
+
+    // Linearize subscripts as:
+    //   base_offset + SUM_i index_i * stride_i.
+    Value *base = extractvalue(elementTy, viewDescriptor, pos(0));
+    Value *offset = extractvalue(int64Ty, viewDescriptor, pos(1));
+    for (int i = 0, e = loadOp.getRank(); i < e; ++i) {
+      Value *stride = extractvalue(int64Ty, viewDescriptor, pos({3, i}));
+      Value *additionalOffset = mul(indices[i], stride);
+      offset = add(offset, additionalOffset);
+    }
+    return gep(elementTy, base, offset);
+  }
+};
+} // namespace
+
+// A load is converted into the actual address computation, getelementptr and
+// an LLVM IR load.
+class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
+  using Base::Base;
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    edsc::ScopedContext edscContext(rewriter, op->getLoc());
+    auto elementTy = lowering.convertType(*op->getResultTypes().begin());
+    Value *viewDescriptor = operands[0];
+    ArrayRef<Value *> indices = operands.drop_front();
+    auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
+    Value *element = llvm_load(elementTy, ptr);
+    return {element};
   }
 };
 
@@ -171,18 +330,18 @@ public:
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
     auto rangeOp = op->cast<RangeOp>();
-    auto rangeDescriptorType =
+    auto rangeDescriptorTy =
         convertLinalgType(rangeOp.getResult()->getType(), lowering);
 
     edsc::ScopedContext context(rewriter, op->getLoc());
 
     // Fill in an aggregate value of the descriptor.
-    Value *desc = undef(rangeDescriptorType);
-    desc = insertvalue(rangeDescriptorType, desc, operands[0],
+    Value *desc = undef(rangeDescriptorTy);
+    desc = insertvalue(rangeDescriptorTy, desc, operands[0],
                        makePositionAttr(rewriter, 0));
-    desc = insertvalue(rangeDescriptorType, desc, operands[1],
+    desc = insertvalue(rangeDescriptorTy, desc, operands[1],
                        makePositionAttr(rewriter, 1));
-    desc = insertvalue(rangeDescriptorType, desc, operands[2],
+    desc = insertvalue(rangeDescriptorTy, desc, operands[2],
                        makePositionAttr(rewriter, 2));
 
     return {desc};
@@ -197,8 +356,7 @@ public:
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
     auto sliceOp = op->cast<SliceOp>();
-    auto viewDescriptorType =
-        convertLinalgType(sliceOp.getViewType(), lowering);
+    auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
     auto viewType = sliceOp.getBaseViewType();
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
 
@@ -217,8 +375,8 @@ public:
 
     edsc::ScopedContext context(rewriter, op->getLoc());
     // Declare the view descriptor and insert data ptr.
-    Value *desc = undef(viewDescriptorType);
-    desc = insertvalue(viewDescriptorType, desc,
+    Value *desc = undef(viewDescriptorTy);
+    desc = insertvalue(viewDescriptorTy, desc,
                        getViewPtr(viewType, operands[0]), pos(0));
 
     // TODO(ntv): extract sizes and emit asserts.
@@ -238,7 +396,7 @@ public:
       Value *product = mul(min, strides[j]);
       baseOffset = add(baseOffset, product);
     }
-    desc = insertvalue(viewDescriptorType, desc, baseOffset, pos(1));
+    desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
 
     // Compute and insert view sizes (max - min along the range).  Skip the
     // non-range operands as they will be projected away from the view.
@@ -252,7 +410,7 @@ public:
       Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
       Value *size = sub(max, min);
 
-      desc = insertvalue(viewDescriptorType, desc, size, pos({2, i}));
+      desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
       ++i;
     }
 
@@ -264,7 +422,7 @@ public:
         continue;
       Value *step = extractvalue(int64Ty, operands[1 + j], pos(2));
       Value *stride = mul(strides[j], step);
-      desc = insertvalue(viewDescriptorType, desc, stride, pos({3, i}));
+      desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
       ++i;
     }
 
@@ -272,6 +430,22 @@ public:
   }
 };
 
+// A store is converted into the actual address computation, getelementptr and
+// an LLVM IR store.
+class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
+  using Base::Base;
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    edsc::ScopedContext edscContext(rewriter, op->getLoc());
+    Value *data = operands[0];
+    Value *viewDescriptor = operands[1];
+    ArrayRef<Value *> indices = operands.drop_front(2);
+    Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
+    llvm_store(data, ptr);
+    return {};
+  }
+};
+
 class ViewOpConversion : public LLVMOpLowering {
 public:
   explicit ViewOpConversion(MLIRContext *context, LLVMLowering &lowering_)
@@ -280,8 +454,8 @@ public:
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
     auto viewOp = op->cast<ViewOp>();
-    auto viewDescriptorType = convertLinalgType(viewOp.getViewType(), lowering);
-    auto elementType = rewriter.getType<LLVMType>(
+    auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
+    auto elementTy = rewriter.getType<LLVMType>(
         getPtrToElementType(viewOp.getViewType(), lowering));
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
 
@@ -294,16 +468,16 @@ public:
 
     // Declare the descriptor of the view.
     edsc::ScopedContext context(rewriter, op->getLoc());
-    Value *desc = undef(viewDescriptorType);
+    Value *desc = undef(viewDescriptorTy);
 
     // Copy the buffer pointer from the old descriptor to the new one.
-    Value *buffer = extractvalue(elementType, bufferDescriptor, pos(0));
-    desc = insertvalue(viewDescriptorType, desc, buffer, pos(0));
+    Value *buffer = extractvalue(elementTy, bufferDescriptor, pos(0));
+    desc = insertvalue(viewDescriptorTy, desc, buffer, pos(0));
 
     // Zero base offset.
     auto indexTy = rewriter.getIndexType();
     Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
-    desc = insertvalue(viewDescriptorType, desc, baseOffset, pos(1));
+    desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
 
     // Compute and insert view sizes (max - min along the range).
     int numIndexings = llvm::size(viewOp.getIndexings());
@@ -313,12 +487,12 @@ public:
       Value *rangeDescriptor = operands[1 + i];
       Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
       Value *stride = mul(runningStride, step);
-      desc = insertvalue(viewDescriptorType, desc, stride, pos({3, i}));
+      desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
       // Update size.
       Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
       Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
       Value *size = sub(max, min);
-      desc = insertvalue(viewDescriptorType, desc, size, pos({2, i}));
+      desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
       ++i;
       // Update stride for the next dimension.
       if (i < numIndexings - 1)
@@ -346,7 +520,8 @@ public:
                     "in lowering to LLVM ");
     auto fAttr = rewriter.getFunctionAttr(f);
     auto named = rewriter.getNamedAttr("callee", fAttr);
-    rewriter.create<LLVM::CallOp>(op->getLoc(), operands, ArrayRef<NamedAttribute>{named});
+    rewriter.create<LLVM::CallOp>(op->getLoc(), operands,
+                                  ArrayRef<NamedAttribute>{named});
     return {};
   }
 };
@@ -357,10 +532,11 @@ class Lowering : public LLVMLowering {
 protected:
   llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() override {
     return ConversionListBuilder<
-        BufferSizeOpConversion, DotOpConversion, RangeOpConversion,
-        SliceOpConversion, ViewOpConversion>::build(&converterStorage,
-                                                    llvmDialect->getContext(),
-                                                    *this);
+        BufferAllocOpConversion, BufferDeallocOpConversion,
+        BufferSizeOpConversion, DotOpConversion, LoadOpConversion,
+        RangeOpConversion, SliceOpConversion, StoreOpConversion,
+        ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(),
+                                 *this);
   }
 
   Type convertAdditionalType(Type t) override {
@@ -378,13 +554,17 @@ struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
 void LowerLinalgToLLVMPass::runOnModule() {
   auto &module = getModule();
 
+  PassManager pm;
+  pm.addPass(createLowerAffinePass());
+  if (failed(pm.run(&module)))
+    signalPassFailure();
+
   // Convert to the LLVM IR dialect using the converter defined above.
-  auto r = Lowering().convert(&module);
-  if (failed(r))
+  if (failed(Lowering().convert(&module)))
     signalPassFailure();
 }
 
-ModulePassBase *createLowerLinalgToLLVMPass() {
+ModulePassBase *mlir::createLowerLinalgToLLVMPass() {
   return new LowerLinalgToLLVMPass();
 }
 
index 48ddb3d..434f720 100644 (file)
@@ -37,6 +37,7 @@
 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
 using namespace llvm;
 
 static llvm::cl::OptionCategory clOptionsCategory("linalg options");
index 0052ef0..4b77ece 100644 (file)
@@ -33,6 +33,7 @@
 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
 using namespace llvm;
 
 mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
diff --git a/mlir/test/mlir-cpu-runner/simple_linalg.mlir b/mlir/test/mlir-cpu-runner/simple_linalg.mlir
new file mode 100644 (file)
index 0000000..119cea6
--- /dev/null
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e entry1 -entry-point-result=f32 | FileCheck %s
+
+func @linalg_dot(!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
+                 !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
+                 !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
+  return
+}
+
+func @dot(%arg0: !linalg.buffer<f32>, %arg1: !linalg.buffer<f32>, %arg2: !linalg.buffer<f32>) -> f32 {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %s = linalg.buffer_size %arg0 : !linalg.buffer<f32>
+  %R = linalg.range %c0:%s:%c1 : !linalg.range
+  %A = linalg.view %arg0[%R] : !linalg.view<?xf32>
+  %B = linalg.view %arg1[%R] : !linalg.view<?xf32>
+  %C = linalg.view %arg2[] : !linalg.view<f32>
+  linalg.dot(%A, %B, %C) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+  %res = linalg.load %C[] : !linalg.view<f32>
+  return %res : f32
+}
+
+func @fill_f32(%arg0 : !linalg.buffer<f32>, %f : f32) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %s = linalg.buffer_size %arg0 : !linalg.buffer<f32>
+  %R = linalg.range %c0:%s:%c1 : !linalg.range
+  %V = linalg.view %arg0[%R] : !linalg.view<?xf32>
+  affine.for %i0 = 0 to %s {
+    linalg.store %f, %V[%i0] : !linalg.view<?xf32>
+  }
+  return
+}
+
+func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<f32> {
+  %A = linalg.buffer_alloc %s : !linalg.buffer<f32>
+  call @fill_f32(%A, %f) : (!linalg.buffer<f32>, f32) -> ()
+  return %A : !linalg.buffer<f32>
+}
+
+func @entry1() -> f32 {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c16 = constant 16 : index
+  %f0 = constant 0.00000e+00 : f32
+  %f1 = constant 0.00000e+00 : f32
+  %f2 = constant 2.00000e+00 : f32
+
+  %A = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer<f32>)
+  %B = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer<f32>)
+  %C = call @alloc_filled_f32(%c1, %f0) : (index, f32) -> (!linalg.buffer<f32>)
+  %res = call @dot(%A, %B, %C) : (!linalg.buffer<f32>, !linalg.buffer<f32>, !linalg.buffer<f32>) -> (f32)
+  linalg.buffer_dealloc %C : !linalg.buffer<f32>
+  linalg.buffer_dealloc %B : !linalg.buffer<f32>
+  linalg.buffer_dealloc %A : !linalg.buffer<f32>
+  return %res : f32
+}
+
+// CHECK: 0.{{0+}}e+00
\ No newline at end of file
index eff9409..844e8db 100644 (file)
@@ -4,6 +4,7 @@ set(LIBS
   MLIREDSC
   MLIRExecutionEngine
   MLIRIR
+  MLIRLLVMIR
   MLIRParser
   MLIRTargetLLVMIR
   MLIRTransforms
index 5d65886..5deadb0 100644 (file)
@@ -26,6 +26,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
 #include "mlir/Parser.h"
 #include "mlir/Support/FileUtilities.h"
 
@@ -56,6 +57,10 @@ static llvm::cl::opt<std::string>
     mainFuncName("e", llvm::cl::desc("The function to be called"),
                  llvm::cl::value_desc("<function name>"),
                  llvm::cl::init("main"));
+static llvm::cl::opt<std::string> mainFuncType(
+    "entry-point-result",
+    llvm::cl::desc("Textual description of the function type to be called"),
+    llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
 
 static llvm::cl::OptionCategory optFlags("opt-like flags");
 
@@ -129,9 +134,9 @@ static void printMemRefArguments(ArrayRef<Type> argTypes,
   }
 }
 
-static Error
-compileAndExecute(Module *module, StringRef entryPoint,
-                  std::function<llvm::Error(llvm::Module *)> transformer) {
+static Error compileAndExecuteFunctionWithMemRefs(
+    Module *module, StringRef entryPoint,
+    std::function<llvm::Error(llvm::Module *)> transformer) {
   Function *mainFunction = module->getNamedFunction(entryPoint);
   if (!mainFunction || mainFunction->getBlocks().empty()) {
     return make_string_error("entry point not found");
@@ -167,6 +172,50 @@ compileAndExecute(Module *module, StringRef entryPoint,
   return Error::success();
 }
 
+static Error compileAndExecuteSingleFloatReturnFunction(
+    Module *module, StringRef entryPoint,
+    std::function<llvm::Error(llvm::Module *)> transformer) {
+  Function *mainFunction = module->getNamedFunction(entryPoint);
+  if (!mainFunction || mainFunction->isExternal()) {
+    return make_string_error("entry point not found");
+  }
+
+  if (!mainFunction->getType().getInputs().empty())
+    return make_string_error("function inputs not supported");
+
+  if (mainFunction->getType().getResults().size() != 1)
+    return make_string_error("only single f32 function result supported");
+
+  auto t = mainFunction->getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
+  if (!t)
+    return make_string_error("only single llvm.f32 function result supported");
+  auto *llvmTy = t.getUnderlyingType();
+  if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
+    return make_string_error("only single llvm.f32 function result supported");
+
+  auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
+  if (!expectedEngine)
+    return expectedEngine.takeError();
+
+  auto engine = std::move(*expectedEngine);
+  auto expectedFPtr = engine->lookup(entryPoint);
+  if (!expectedFPtr)
+    return expectedFPtr.takeError();
+  void (*fptr)(void **) = *expectedFPtr;
+
+  float res;
+  struct {
+    void *data;
+  } data;
+  data.data = &res;
+  (*fptr)((void **)&data);
+
+  // Intentional printing of the output so we can test.
+  llvm::outs() << res;
+
+  return Error::success();
+}
+
 int main(int argc, char **argv) {
   llvm::PrettyStackTraceProgram x(argc, argv);
   llvm::InitLLVM y(argc, argv);
@@ -212,7 +261,11 @@ int main(int argc, char **argv) {
 
   auto transformer =
       mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition);
-  auto error = compileAndExecute(m.get(), mainFuncName.getValue(), transformer);
+  auto error = mainFuncType.getValue() == "f32"
+                   ? compileAndExecuteSingleFloatReturnFunction(
+                         m.get(), mainFuncName.getValue(), transformer)
+                   : compileAndExecuteFunctionWithMemRefs(
+                         m.get(), mainFuncName.getValue(), transformer);
   int exitCode = EXIT_SUCCESS;
   llvm::handleAllErrors(std::move(error),
                         [&exitCode](const llvm::ErrorInfoBase &info) {