From 33449c3e6c0cbff11a168e7a05c2ecd342f1b80b Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 9 May 2019 12:34:04 -0700 Subject: [PATCH] Pipe Linalg to LLVM via mlir-cpu-runner This CL adds support for functions in the Linalg dialect to run with mlir-cpu-runner. For this purpose, this CL adds BufferAllocOp, BufferDeallocOp, LoadOp and StoreOp to the Linalg dialect as well as their lowering to LLVM. To avoid collisions with mlir::LoadOp/StoreOp (which should really become mlir::affine::LoadOp/StoreOp), the mlir::linalg namespace is added. The execution uses a dummy linalg_dot function that just returns for now. In the future a proper library call will be used. -- PiperOrigin-RevId: 247476061 --- mlir/include/mlir/Linalg/IR/LinalgOps.h | 114 +++++++--- mlir/include/mlir/Linalg/IR/LinalgOps.td | 14 +- mlir/include/mlir/Linalg/IR/LinalgTraits.h | 21 +- mlir/include/mlir/Linalg/IR/LinalgTypes.h | 2 + mlir/include/mlir/Linalg/Passes.h | 2 + mlir/lib/Linalg/IR/LinalgOps.cpp | 239 +++++++++++++++----- mlir/lib/Linalg/IR/LinalgTypes.cpp | 34 +-- mlir/lib/Linalg/LinalgRegistration.cpp | 1 + mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 254 ++++++++++++++++++---- mlir/lib/Linalg/Transforms/Tiling.cpp | 1 + mlir/lib/Linalg/Utils/Utils.cpp | 1 + mlir/test/mlir-cpu-runner/simple_linalg.mlir | 58 +++++ mlir/tools/mlir-cpu-runner/CMakeLists.txt | 1 + mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp | 61 +++++- 14 files changed, 653 insertions(+), 150 deletions(-) create mode 100644 mlir/test/mlir-cpu-runner/simple_linalg.mlir diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Linalg/IR/LinalgOps.h index 9472c71..f468b96 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.h @@ -24,6 +24,7 @@ #include "mlir/Support/LLVM.h" namespace mlir { +namespace linalg { /// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type, /// upon which a base view can be laid out to give it indexing semantics. @@ -77,6 +78,35 @@ public: } }; +/// A linalg.LoadOp is the counterpart of load but operating on ViewType +/// instead of MemRefType. +/// +/// ```{.mlir} +/// %0 = linalg.load %V[%c0] : !linalg.view +/// ``` +class LoadOp + : public Op { +public: + friend Operation; + using Op::Op; + + // Hooks to customize the behavior of this op. + static llvm::StringRef getOperationName() { return "linalg.load"; } + static void build(Builder *b, OperationState *result, Value *view, + ArrayRef indices = {}); + LogicalResult verify(); + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + + // Op-specific functionality. + unsigned getRank() { return getViewType().getRank(); } + ViewType getViewType() { return getView()->getType().cast(); } + Value *getView() { return getOperand(0); } + Operation::operand_range getIndices() { + return {operand_begin() + 1, operand_end()}; + } +}; + /// The "linalg.range" op creates a linalg.range from 3 values of type `index` /// that represent the min, max and step values of the range. /// @@ -142,9 +172,8 @@ public: /// !linalg.view /// ``` class ViewOp; -class SliceOp : public mlir::Op { +class SliceOp : public Op { enum { FirstIndexingOperand = 1 }; public: @@ -153,33 +182,60 @@ public: // Hooks to customize the behavior of this op. static llvm::StringRef getOperationName() { return "linalg.slice"; } - static void build(mlir::Builder *b, mlir::OperationState *result, - mlir::Value *base, llvm::ArrayRef indexings); - mlir::LogicalResult verify(); - static ParseResult parse(mlir::OpAsmParser *parser, - mlir::OperationState *result); - void print(mlir::OpAsmPrinter *p); + static void build(Builder *b, OperationState *result, Value *base, + llvm::ArrayRef indexings); + LogicalResult verify(); + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); // Op-specific functionality. unsigned getRank() { return getViewType().getRank(); } - mlir::Type getElementType() { return getViewType().getElementType(); } + Type getElementType() { return getViewType().getElementType(); } ViewType getViewType() { return getType().cast(); } Value *getBaseView() { return getOperand(0); } ViewOp getBaseViewOp(); ViewType getBaseViewType(); unsigned getBaseViewRank() { return getBaseViewType().getRank(); } // Get the underlying indexing at a given rank. - mlir::Value *getIndexing(unsigned rank) { - return *(getIndexings().begin() + rank); - } + Value *getIndexing(unsigned rank) { return *(getIndexings().begin() + rank); } // Get all the indexings in this view. - mlir::Operation::operand_range getIndexings() { + Operation::operand_range getIndexings() { return {operand_begin() + SliceOp::FirstIndexingOperand, operand_end()}; } // Get the subset of indexings that are of RangeType. SmallVector getRanges(); }; +/// A linalg.StoreOp is the counterpart of affine.store but operating on +/// ViewType instead of MemRefType. +/// +/// ```{.mlir} +/// linalg.store %f, %V[%c0] : !linalg.view +/// ``` +class StoreOp + : public Op { +public: + friend Operation; + using Op::Op; + + // Hooks to customize the behavior of this op. + static llvm::StringRef getOperationName() { return "linalg.store"; } + static void build(Builder *b, OperationState *result, Value *valueToStore, + Value *view, ArrayRef indices = {}); + LogicalResult verify(); + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + + // Op-specific functionality. + unsigned getRank() { return getViewType().getRank(); } + ViewType getViewType() { return getView()->getType().cast(); } + Value *getValueToStore() { return getOperand(0); } + Value *getView() { return getOperand(1); } + Operation::operand_range getIndices() { + return {operand_begin() + 2, operand_end()}; + } +}; + /// The "linalg.view" op produces a linalg.view which is a multi-dimensional /// range abstraction on top of an underlying linalg.buffer. This gives an /// indexing structure to an otherwise non-indexable linalg.buffer. @@ -193,9 +249,8 @@ public: /// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range /// %3 = linalg.view %1[%2, %2] : !linalg.view /// ``` -class ViewOp : public mlir::Op { +class ViewOp : public Op { enum { FirstIndexingOperand = 1 }; public: @@ -204,25 +259,21 @@ public: // Hooks to customize the behavior of this op. static llvm::StringRef getOperationName() { return "linalg.view"; } - static void build(mlir::Builder *b, mlir::OperationState *result, - mlir::Value *buffer, - llvm::ArrayRef indexings); - mlir::LogicalResult verify(); - static ParseResult parse(mlir::OpAsmParser *parser, - mlir::OperationState *result); - void print(mlir::OpAsmPrinter *p); + static void build(Builder *b, OperationState *result, Value *buffer, + llvm::ArrayRef indexings); + LogicalResult verify(); + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); // Op-specific functionality. unsigned getRank() { return getViewType().getRank(); } - mlir::Type getElementType() { return getViewType().getElementType(); } + Type getElementType() { return getViewType().getElementType(); } ViewType getViewType() { return getType().cast(); } - mlir::Value *getSupportingBuffer() { return getOperand(0); } + Value *getSupportingBuffer() { return getOperand(0); } // Get the underlying indexing at a given rank. - mlir::Value *getIndexing(unsigned rank) { - return *(getIndexings().begin() + rank); - } + Value *getIndexing(unsigned rank) { return *(getIndexings().begin() + rank); } // Get all the indexings in this view. - mlir::Operation::operand_range getIndexings() { + Operation::operand_range getIndexings() { return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()}; } }; @@ -245,9 +296,10 @@ public: /// ) /// ``` /// -/// Only permutation maps are currently supported. +/// Only permutation maps are currently supported. SmallVector loopToOperandRangesMaps(Operation *op); +} // namespace linalg } // namespace mlir #endif // MLIR_LINALG_LINALGOPS_H_ diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Linalg/IR/LinalgOps.td index fd07f6c..2aa1e43 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.td @@ -39,12 +39,12 @@ def Buffer : Type; def LinalgIsViewTypePred : CPred<"$_self.isa()">; def View : Type; -class ParametricNativeOpTrait : - NativeOpTrait +class LinalgParametricNativeOpTrait : + NativeOpTrait<"linalg::" # prop # parameters> {} -class ParametricIntNativeOpTrait parameters> : - ParametricNativeOpTrait< +class LinalgParametricIntNativeOpTrait parameters> : + LinalgParametricNativeOpTrait< prop, !strconcat("<", !cast(!head(parameters)), @@ -60,7 +60,7 @@ class ParametricIntNativeOpTrait parameters> : // to have a specified number of inputs and outputs, all passed as operands. // See Linalg/LinalgTraits.h for implementation details an usage. class NInputsAndOutputs : - ParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]> + LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]> {} // The linalg `NLoopTypes` trait provides the API for ops that are known to have @@ -68,14 +68,14 @@ class NInputsAndOutputs : // loops. // See Linalg/LinalgTraits.h for implementation details an usage. class NLoopTypes : -ParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]> +LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]> {} // The linalg `ViewRanks` trait the API for ops that are known to have a // specified list of view ranks. // See Linalg/LinalgTraits.h for implementation details an usage. class ViewRanks ranks> : -ParametricIntNativeOpTrait<"ViewRanks", ranks> +LinalgParametricIntNativeOpTrait<"ViewRanks", ranks> {} // Base Tablegen class for Linalg ops. diff --git a/mlir/include/mlir/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Linalg/IR/LinalgTraits.h index 4a7428b..0d557fb 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Linalg/IR/LinalgTraits.h @@ -24,6 +24,7 @@ namespace mlir { namespace OpTrait { +namespace linalg { /// This class provides the API for ops that are known to have a specified /// number of inputs and outputs, all passed as operands. This is used as a @@ -44,16 +45,20 @@ public: Value *getOutput(unsigned i) { return this->getOperand(getNumInputs() + i); } - ViewType getInputViewType(unsigned i) { - return this->getOperand(i)->getType().template cast(); + mlir::linalg::ViewType getInputViewType(unsigned i) { + return this->getOperand(i) + ->getType() + .template cast(); } - ViewType getOutputViewType(unsigned i) { + mlir::linalg::ViewType getOutputViewType(unsigned i) { return this->getOperand(getNumInputs() + i) ->getType() - .template cast(); + .template cast(); } - ViewType getViewType(unsigned i) { - return this->getOperand(i)->getType().template cast(); + mlir::linalg::ViewType getViewType(unsigned i) { + return this->getOperand(i) + ->getType() + .template cast(); } static LogicalResult verifyTrait(Operation *op) { return OpTrait::impl::verifyAtLeastNOperands(op, NInputs + NOutputs); @@ -98,7 +103,8 @@ public: if (op->getNumOperands() != ranks.size()) return op->emitError("expected " + Twine(ranks.size()) + " operands"); for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) { - auto viewType = op->getOperand(i)->getType().dyn_cast(); + auto viewType = + op->getOperand(i)->getType().dyn_cast(); if (!viewType) return op->emitOpError("operand " + Twine(i) + " must have view type "); @@ -111,6 +117,7 @@ public: }; }; +} // namespace linalg } // namespace OpTrait } // namespace mlir diff --git a/mlir/include/mlir/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Linalg/IR/LinalgTypes.h index 64f86d4..38ef3cb 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Linalg/IR/LinalgTypes.h @@ -24,6 +24,7 @@ namespace mlir { class MLIRContext; +namespace linalg { enum LinalgTypes { Buffer = Type::FIRST_LINALG_TYPE, Range, @@ -110,6 +111,7 @@ public: unsigned getRank(); }; +} // namespace linalg } // namespace mlir #endif // MLIR_LINALG_LINALGTYPES_H_ diff --git a/mlir/include/mlir/Linalg/Passes.h b/mlir/include/mlir/Linalg/Passes.h index 7ccb788..931de90 100644 --- a/mlir/include/mlir/Linalg/Passes.h +++ b/mlir/include/mlir/Linalg/Passes.h @@ -30,6 +30,8 @@ class ModulePassBase; mlir::ModulePassBase * createLinalgTilingPass(llvm::ArrayRef tileSizes = {}); + +mlir::ModulePassBase *createLowerLinalgToLLVMPass(); } // namespace mlir #endif // MLIR_LINALG_PASSES_H_ diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index 356a906..6998da5 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -30,23 +30,24 @@ #include "mlir/Support/STLExtras.h" using namespace mlir; +using namespace mlir::linalg; ////////////////////////////////////////////////////////////////////////////// // BufferAllocOp ////////////////////////////////////////////////////////////////////////////// -void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type, - Value *size) { +void mlir::linalg::BufferAllocOp::build(Builder *b, OperationState *result, + Type type, Value *size) { result->addOperands({size}); result->addTypes(type); } -mlir::LogicalResult mlir::BufferAllocOp::verify() { +LogicalResult mlir::linalg::BufferAllocOp::verify() { if (!size() || !size()->getType().isa()) return emitOpError("first operand should be of type index"); if (!VectorType::isValidElementType(getElementType()) && !getElementType().isa()) return emitOpError("unsupported buffer element type"); - return mlir::success(); + return success(); } // A BufferAllocOp prints as: @@ -54,21 +55,21 @@ mlir::LogicalResult mlir::BufferAllocOp::verify() { // ```{.mlir} // linalg.alloc %0 : !linalg.buffer // ``` -void mlir::BufferAllocOp::print(OpAsmPrinter *p) { +void mlir::linalg::BufferAllocOp::print(OpAsmPrinter *p) { *p << getOperationName() << " " << *size() << " : " << getType(); } -ParseResult mlir::BufferAllocOp::parse(OpAsmParser *parser, - OperationState *result) { +ParseResult mlir::linalg::BufferAllocOp::parse(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType sizeInfo; BufferType bufferType; auto indexTy = parser->getBuilder().getIndexType(); if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType)) return failure(); if (bufferType.getElementType() != parser->getBuilder().getF32Type()) - return parser->emitError( - parser->getNameLoc(), - "Only buffer supported until mlir::Parser pieces are exposed"); + return parser->emitError(parser->getNameLoc(), + "Only buffer supported until " + "mlir::linalg::Parser pieces are exposed"); return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) || parser->addTypeToList(bufferType, result->types)); } @@ -76,15 +77,15 @@ ParseResult mlir::BufferAllocOp::parse(OpAsmParser *parser, ////////////////////////////////////////////////////////////////////////////// // BufferDeallocOp ////////////////////////////////////////////////////////////////////////////// -void mlir::BufferDeallocOp::build(Builder *b, OperationState *result, - Value *buffer) { +void mlir::linalg::BufferDeallocOp::build(Builder *b, OperationState *result, + Value *buffer) { result->addOperands({buffer}); } -mlir::LogicalResult mlir::BufferDeallocOp::verify() { +LogicalResult mlir::linalg::BufferDeallocOp::verify() { if (!getBuffer()->getType()) return emitOpError("first operand should be of type buffer"); - return mlir::success(); + return success(); } // A BufferDeallocOp prints as: @@ -92,36 +93,99 @@ mlir::LogicalResult mlir::BufferDeallocOp::verify() { // ```{.mlir} // linalg.dealloc %0 : !linalg.buffer // ``` -void mlir::BufferDeallocOp::print(OpAsmPrinter *p) { +void mlir::linalg::BufferDeallocOp::print(OpAsmPrinter *p) { *p << getOperationName() << " " << *getBuffer() << " : " << getBufferType(); } -ParseResult mlir::BufferDeallocOp::parse(OpAsmParser *parser, - OperationState *result) { +ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType sizeInfo; BufferType bufferType; return failure( parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) || parser->resolveOperands(sizeInfo, bufferType, result->operands)); } + +//////////////////////////////////////////////////////////////////////////////// +// LoadOp. +//////////////////////////////////////////////////////////////////////////////// +void mlir::linalg::LoadOp::build(Builder *b, OperationState *result, + Value *view, ArrayRef indices) { + auto viewType = view->getType().cast(); + result->addOperands(view); + result->addOperands(indices); + result->addTypes(viewType.getElementType()); +} + +// A LoadOp prints as: +// +// ```{.mlir} +// %0 = linalg.load %V[%c0] : !linalg.view +// ``` +void mlir::linalg::LoadOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *getView() << '['; + p->printOperands(getIndices()); + *p << ']'; + p->printOptionalAttrDict(getAttrs()); + *p << " : " << getViewType(); +} + +ParseResult mlir::linalg::LoadOp::parse(OpAsmParser *parser, + OperationState *result) { + OpAsmParser::OperandType viewInfo; + SmallVector indexInfo; + ViewType type; + + auto affineIntTy = parser->getBuilder().getIndexType(); + return failure( + parser->parseOperand(viewInfo) || + parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(viewInfo, type, result->operands) || + parser->resolveOperands(indexInfo, affineIntTy, result->operands) || + parser->addTypeToList(type.getElementType(), result->types)); +} + +LogicalResult mlir::linalg::LoadOp::verify() { + if (getNumOperands() == 0) + return emitOpError("expected a view to load from"); + + auto viewType = getView()->getType().dyn_cast(); + if (!viewType) + return emitOpError("first operand must be a view"); + + if (getType() != viewType.getElementType()) + return emitOpError("result type must match element type of the view"); + + if (getRank() != getNumOperands() - 1) + return emitOpError("incorrect number of indices for load"); + + for (auto *idx : getIndices()) + if (!idx->getType().isIndex()) + return emitOpError("index to load must have 'index' type"); + + return success(); +} + ////////////////////////////////////////////////////////////////////////////// // RangeOp ////////////////////////////////////////////////////////////////////////////// -void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min, - Value *max, Value *step) { +void mlir::linalg::RangeOp::build(Builder *b, OperationState *result, + Value *min, Value *max, Value *step) { result->addOperands({min, max, step}); result->addTypes({RangeType::get(b->getContext())}); } // Verification is simply that a RangeOp takes 3 index ssa-value. -mlir::LogicalResult mlir::RangeOp::verify() { +LogicalResult mlir::linalg::RangeOp::verify() { if (!min() || !min()->getType().isa()) return emitOpError("first operand should be of type index"); if (!max() || !max()->getType().isa()) return emitOpError("second operand should be of type index"); if (!step() || !step()->getType().isa()) return emitOpError("third operand should be of type index"); - return mlir::success(); + return success(); } // A RangeOp prints as: @@ -129,12 +193,13 @@ mlir::LogicalResult mlir::RangeOp::verify() { // ```{.mlir} // linalg.range %0:%1:%2 : !linalg.range // ``` -void mlir::RangeOp::print(OpAsmPrinter *p) { +void mlir::linalg::RangeOp::print(OpAsmPrinter *p) { *p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step() << " : " << getType(); } -ParseResult mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult mlir::linalg::RangeOp::parse(OpAsmParser *parser, + OperationState *result) { SmallVector rangeInfo(3); RangeType type; auto affineIntTy = parser->getBuilder().getIndexType(); @@ -149,8 +214,8 @@ ParseResult mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) { ////////////////////////////////////////////////////////////////////////////// // SliceOp ////////////////////////////////////////////////////////////////////////////// -void mlir::SliceOp::build(Builder *b, OperationState *result, Value *base, - ArrayRef indexings) { +void mlir::linalg::SliceOp::build(Builder *b, OperationState *result, + Value *base, ArrayRef indexings) { result->addOperands({base}); result->addOperands(indexings); @@ -163,7 +228,7 @@ void mlir::SliceOp::build(Builder *b, OperationState *result, Value *base, result->addTypes({ViewType::get(b->getContext(), elementType, rank)}); } -LogicalResult mlir::SliceOp::verify() { +LogicalResult mlir::linalg::SliceOp::verify() { if (llvm::empty(getOperands())) return emitOpError( "requires at least a view operand followed by 'rank' indices"); @@ -193,7 +258,8 @@ LogicalResult mlir::SliceOp::verify() { return success(); } -ParseResult mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult mlir::linalg::SliceOp::parse(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType baseInfo; SmallVector indexingsInfo; SmallVector types; @@ -241,11 +307,11 @@ ParseResult mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) { // // Where %0 is an ssa-value holding a view created from a buffer, %1 and %2 are // ssa-value each holding a range. -void mlir::SliceOp::print(OpAsmPrinter *p) { +void mlir::linalg::SliceOp::print(OpAsmPrinter *p) { *p << getOperationName() << " " << *getBaseView() << "["; interleave( - getIndexings().begin(), getIndexings().end(), - [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; }); + getIndexings().begin(), getIndexings().end(), [p](Value *v) { *p << *v; }, + [p]() { *p << ", "; }); *p << "] : " << getBaseViewType(); for (auto indexing : getIndexings()) { *p << ", " << indexing->getType(); @@ -253,15 +319,15 @@ void mlir::SliceOp::print(OpAsmPrinter *p) { *p << ", " << getType(); } -ViewOp mlir::SliceOp::getBaseViewOp() { +ViewOp mlir::linalg::SliceOp::getBaseViewOp() { return getOperand(0)->getDefiningOp()->cast(); } -ViewType mlir::SliceOp::getBaseViewType() { +ViewType mlir::linalg::SliceOp::getBaseViewType() { return getBaseViewOp().getType().cast(); } -SmallVector mlir::SliceOp::getRanges() { +SmallVector mlir::linalg::SliceOp::getRanges() { llvm::SmallVector res; for (auto *operand : getIndexings()) { if (!operand->getType().isa()) { @@ -271,11 +337,79 @@ SmallVector mlir::SliceOp::getRanges() { return res; } +//////////////////////////////////////////////////////////////////////////////// +// StoreOp. +//////////////////////////////////////////////////////////////////////////////// +void mlir::linalg::StoreOp::build(Builder *b, OperationState *result, + Value *valueToStore, Value *view, + ArrayRef indices) { + result->addOperands(valueToStore); + result->addOperands(view); + result->addOperands(indices); +} + +// A StoreOp prints as: +// +// ```{.mlir} +// linalg.store %f, %V[%c0] : !linalg.view +// ``` +void mlir::linalg::StoreOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *getValueToStore(); + *p << ", " << *getView() << '['; + p->printOperands(getIndices()); + *p << ']'; + p->printOptionalAttrDict(getAttrs()); + *p << " : " << getViewType(); +} + +ParseResult mlir::linalg::StoreOp::parse(OpAsmParser *parser, + OperationState *result) { + OpAsmParser::OperandType storeValueInfo; + OpAsmParser::OperandType viewInfo; + SmallVector indexInfo; + ViewType viewType; + + auto affineIntTy = parser->getBuilder().getIndexType(); + return failure( + parser->parseOperand(storeValueInfo) || parser->parseComma() || + parser->parseOperand(viewInfo) || + parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(viewType) || + parser->resolveOperand(storeValueInfo, viewType.getElementType(), + result->operands) || + parser->resolveOperand(viewInfo, viewType, result->operands) || + parser->resolveOperands(indexInfo, affineIntTy, result->operands)); +} + +LogicalResult mlir::linalg::StoreOp::verify() { + if (getNumOperands() < 2) + return emitOpError("expected a value to store and a view"); + + // Second operand is a memref type. + auto viewType = getView()->getType().dyn_cast(); + if (!viewType) + return emitOpError("second operand must be a view"); + + // First operand must have same type as memref element type. + if (getValueToStore()->getType() != viewType.getElementType()) + return emitOpError("first operand must have same element type as the view"); + + if (getNumOperands() != 2 + viewType.getRank()) + return emitOpError("store index operand count not equal to view rank"); + + for (auto *idx : getIndices()) + if (!idx->getType().isIndex()) + return emitOpError("index to store must have 'index' type"); + + return success(); +} + ////////////////////////////////////////////////////////////////////////////// // ViewOp ////////////////////////////////////////////////////////////////////////////// -void mlir::ViewOp::build(Builder *b, OperationState *result, Value *buffer, - ArrayRef indexings) { +void mlir::linalg::ViewOp::build(Builder *b, OperationState *result, + Value *buffer, ArrayRef indexings) { BufferType bufferType = buffer->getType().cast(); result->addOperands({buffer}); result->addOperands(indexings); @@ -289,7 +423,7 @@ void mlir::ViewOp::build(Builder *b, OperationState *result, Value *buffer, {ViewType::get(b->getContext(), elementType, indexings.size())}); } -LogicalResult mlir::ViewOp::verify() { +LogicalResult mlir::linalg::ViewOp::verify() { if (llvm::empty(getOperands())) return emitOpError( "requires at least a buffer operand followed by indexings"); @@ -309,7 +443,8 @@ LogicalResult mlir::ViewOp::verify() { return success(); } -ParseResult mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType bufferInfo; SmallVector indexingsInfo; Type type; @@ -345,28 +480,30 @@ ParseResult mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) { // // Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each // holding a range. -void mlir::ViewOp::print(OpAsmPrinter *p) { +void mlir::linalg::ViewOp::print(OpAsmPrinter *p) { *p << getOperationName() << " " << *getSupportingBuffer() << "["; interleave( - getIndexings().begin(), getIndexings().end(), - [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; }); + getIndexings().begin(), getIndexings().end(), [&](Value *v) { *p << *v; }, + [&]() { *p << ", "; }); *p << "] : " << getType(); } namespace mlir { +namespace linalg { namespace impl { -void printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op); +void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op); ParseResult parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result); -void printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op); +void printBufferSizeOp(OpAsmPrinter *p, Operation *op); ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result); } // namespace impl +} // namespace linalg /// Buffer size prints as: /// /// ``` {.mlir} /// %0 = linalg.buffer_size %arg0 : !linalg.buffer /// ``` -void mlir::impl::printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op) { +void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) { assert(op->getAbstractOperation() && "unregistered operation"); *p << op->cast().getOperationName() << " " << *op->getOperand(0); @@ -374,8 +511,8 @@ void mlir::impl::printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op) { *p << " : " << op->getOperand(0)->getType(); } -ParseResult mlir::impl::parseBufferSizeOp(OpAsmParser *parser, - OperationState *result) { +ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType op; Type type; return failure(parser->parseOperand(op) || @@ -405,20 +542,20 @@ ParseResult mlir::impl::parseBufferSizeOp(OpAsmParser *parser, // ``` // // Where %0, %1 and %2 are ssa-values of type ViewType. -void mlir::impl::printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op) { +void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) { assert(op->getAbstractOperation() && "unregistered operation"); *p << op->getName().getStringRef() << "("; interleave( op->getOperands().begin(), op->getOperands().end(), - [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; }); + [&](Value *v) { *p << *v; }, [&]() { *p << ", "; }); *p << ") : "; interleave( op->getOperands().begin(), op->getOperands().end(), - [&](mlir::Value *v) { *p << v->getType(); }, [&]() { *p << ", "; }); + [&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; }); } -ParseResult mlir::impl::parseLinalgLibraryOp(OpAsmParser *parser, - OperationState *result) { +ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser, + OperationState *result) { SmallVector ops; SmallVector types; return failure( @@ -431,7 +568,7 @@ ParseResult mlir::impl::parseLinalgLibraryOp(OpAsmParser *parser, // Ideally this should all be Tablegen'd but there is no good story for // AffineMap for now. -SmallVector mlir::loopToOperandRangesMaps(Operation *op) { +SmallVector mlir::linalg::loopToOperandRangesMaps(Operation *op) { MLIRContext *context = op->getContext(); auto i = getAffineDimExpr(0, context); auto j = getAffineDimExpr(1, context); diff --git a/mlir/lib/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Linalg/IR/LinalgTypes.cpp index 556d5d1..19105e8 100644 --- a/mlir/lib/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Linalg/IR/LinalgTypes.cpp @@ -26,18 +26,20 @@ #include "mlir/Support/LLVM.h" using namespace mlir; +using namespace mlir::linalg; -mlir::LinalgDialect::LinalgDialect(MLIRContext *context) +mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context) : Dialect("linalg", context) { addTypes(); - addOperations(); + addOperations(); addOperations< #define GET_OP_LIST #include "mlir/Linalg/IR/LinalgOps.cpp.inc" >(); } -struct mlir::BufferTypeStorage : public mlir::TypeStorage { +struct mlir::linalg::BufferTypeStorage : public TypeStorage { /// Underlying Key type to transport the payload needed to construct a custom /// type in a generic way. struct Key { @@ -71,13 +73,17 @@ private: Type elementType; }; -BufferType mlir::BufferType::get(MLIRContext *context, Type elementType) { +BufferType mlir::linalg::BufferType::get(MLIRContext *context, + Type elementType) { return Base::get(context, LinalgTypes::Buffer, elementType); } -Type mlir::BufferType::getElementType() { return getImpl()->getElementType(); } +Type mlir::linalg::BufferType::getElementType() { + return getImpl()->getElementType(); +} -Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const { +Type mlir::linalg::LinalgDialect::parseType(StringRef spec, + Location loc) const { MLIRContext *context = getContext(); if (spec == "range") return RangeType::get(getContext()); @@ -97,7 +103,7 @@ Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const { return (context->emitError(loc, "unknown Linalg type: " + spec), Type()); } -struct mlir::ViewTypeStorage : public mlir::TypeStorage { +struct mlir::linalg::ViewTypeStorage : public TypeStorage { /// Underlying Key type to transport the payload needed to construct a custom /// type in a generic way. struct Key { @@ -136,14 +142,16 @@ private: unsigned rank; }; -ViewType mlir::ViewType::get(MLIRContext *context, Type elementType, - unsigned rank) { +ViewType mlir::linalg::ViewType::get(MLIRContext *context, Type elementType, + unsigned rank) { return Base::get(context, LinalgTypes::View, elementType, rank); } -Type mlir::ViewType::getElementType() { return getImpl()->getElementType(); } +Type mlir::linalg::ViewType::getElementType() { + return getImpl()->getElementType(); +} -unsigned mlir::ViewType::getRank() { return getImpl()->getRank(); } +unsigned mlir::linalg::ViewType::getRank() { return getImpl()->getRank(); } /// BufferType prints as "buffer". static void print(BufferType bt, raw_ostream &os) { @@ -166,7 +174,7 @@ static void print(RangeType rt, raw_ostream &os) { os << "range"; } /// ``` /// /// for 0-D views (a.k.a pointer to a scalar value). -static void print(mlir::ViewType rt, raw_ostream &os) { +static void print(mlir::linalg::ViewType rt, raw_ostream &os) { os << "view<"; for (unsigned i = 0, e = rt.getRank(); i < e; ++i) { os << "?x"; @@ -175,7 +183,7 @@ static void print(mlir::ViewType rt, raw_ostream &os) { os << ">"; } -void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const { +void mlir::linalg::LinalgDialect::printType(Type type, raw_ostream &os) const { switch (type.getKind()) { default: llvm_unreachable("Unhandled Linalg type"); diff --git a/mlir/lib/Linalg/LinalgRegistration.cpp b/mlir/lib/Linalg/LinalgRegistration.cpp index 816b565..cf5bd8f 100644 --- a/mlir/lib/Linalg/LinalgRegistration.cpp +++ b/mlir/lib/Linalg/LinalgRegistration.cpp @@ -19,6 +19,7 @@ #include "mlir/Linalg/IR/LinalgTypes.h" using namespace mlir; +using namespace mlir::linalg; // Static initialization for LinalgOps dialect registration. static DialectRegistration LinalgOps; diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 108da6c..90111a8 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -30,6 +30,7 @@ #include "mlir/LLVMIR/Transforms.h" #include "mlir/Linalg/IR/LinalgOps.h" #include "mlir/Linalg/IR/LinalgTypes.h" +#include "mlir/Linalg/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" @@ -45,6 +46,7 @@ using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; using namespace mlir::LLVM; +using namespace mlir::linalg; using undef = ValueBuilder; using insertvalue = ValueBuilder; @@ -53,6 +55,11 @@ using constant = ValueBuilder; using add = ValueBuilder; using sub = ValueBuilder; using mul = ValueBuilder; +using bitcast = ValueBuilder; +using call = OperationBuilder; +using gep = ValueBuilder; +using llvm_load = ValueBuilder; +using llvm_store = OperationBuilder; template static llvm::Type *getPtrToElementType(T containerType, @@ -85,8 +92,8 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) { // Elem *ptr; // int64_t size; // }; - if (auto bufferTy = t.dyn_cast()) { - auto *ptrTy = getPtrToElementType(bufferTy, lowering); + if (auto bufferType = t.dyn_cast()) { + auto *ptrTy = getPtrToElementType(bufferType, lowering); auto *structTy = llvm::StructType::get(ptrTy, int64Ty); return LLVMType::get(context, structTy); } @@ -98,7 +105,7 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) { // int64_t max; // int64_t step; // }; - if (auto rangeTy = t.dyn_cast()) { + if (t.isa()) { auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty); return LLVMType::get(context, structTy); } @@ -126,9 +133,9 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) { // int64_t sizes[Rank]; // int64_t strides[Rank]; // }; - if (auto viewTy = t.dyn_cast()) { - auto *ptrTy = getPtrToElementType(viewTy, lowering); - auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank()); + if (auto viewType = t.dyn_cast()) { + auto *ptrTy = getPtrToElementType(viewType, lowering); + auto *arrayTy = llvm::ArrayType::get(int64Ty, viewType.getRank()); auto *structTy = llvm::StructType::get(ptrTy, int64Ty, arrayTy, arrayTy); return LLVMType::get(context, structTy); } @@ -147,6 +154,106 @@ static ArrayAttr makePositionAttr(FuncBuilder &builder, return builder.getArrayAttr(attrs); } +// BufferAllocOp creates a new `index` value. +class BufferAllocOpConversion : public LLVMOpLowering { +public: + explicit BufferAllocOpConversion(MLIRContext *context, + LLVMLowering &lowering_) + : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {} + + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + auto indexType = IndexType::get(op->getContext()); + auto voidPtrTy = LLVM::LLVMType::get( + op->getContext(), + lowering.convertType(IntegerType::get(8, op->getContext())) + .cast() + .getUnderlyingType() + ->getPointerTo()); + auto int64Ty = lowering.convertType(operands[0]->getType()); + // Insert the `malloc` declaration if it is not already present. + Function *mallocFunc = + op->getFunction()->getModule()->getNamedFunction("malloc"); + if (!mallocFunc) { + auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy); + mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType); + op->getFunction()->getModule()->getFunctions().push_back(mallocFunc); + } + + // Get MLIR types for injecting element pointer. + auto allocOp = op->cast(); + auto elementType = allocOp.getElementType(); + uint64_t elementSize = 0; + if (auto vectorType = elementType.dyn_cast()) + elementSize = vectorType.getNumElements() * + llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8); + else + elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); + auto elementPtrType = rewriter.getType(getPtrToElementType( + allocOp.getResult()->getType().cast(), lowering)); + auto bufferDescriptorType = + convertLinalgType(allocOp.getResult()->getType(), lowering); + + // Emit IR for creating a new buffer descriptor with an underlying malloc. + edsc::ScopedContext context(rewriter, op->getLoc()); + Value *size = operands[0]; + Value *allocSize = + mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize))); + Value *allocated = + call(voidPtrTy, rewriter.getFunctionAttr(mallocFunc), allocSize) + .getOperation() + ->getResult(0); + allocated = bitcast(elementPtrType, allocated); + Value *desc = undef(bufferDescriptorType); + desc = insertvalue(bufferDescriptorType, desc, allocated, + makePositionAttr(rewriter, 0)); + desc = insertvalue(bufferDescriptorType, desc, size, + makePositionAttr(rewriter, 1)); + return {desc}; + } +}; + +// BufferDeallocOp creates a new `index` value. +class BufferDeallocOpConversion : public LLVMOpLowering { +public: + explicit BufferDeallocOpConversion(MLIRContext *context, + LLVMLowering &lowering_) + : LLVMOpLowering(BufferDeallocOp::getOperationName(), context, + lowering_) {} + + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + auto voidPtrTy = LLVM::LLVMType::get( + op->getContext(), + lowering.convertType(IntegerType::get(8, op->getContext())) + .cast() + .getUnderlyingType() + ->getPointerTo()); + // Insert the `free` declaration if it is not already present. + Function *freeFunc = + op->getFunction()->getModule()->getNamedFunction("free"); + if (!freeFunc) { + auto freeType = rewriter.getFunctionType(voidPtrTy, {}); + freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType); + op->getFunction()->getModule()->getFunctions().push_back(freeFunc); + } + + // Get MLIR types for extracting element pointer. + auto deallocOp = op->cast(); + auto elementPtrTy = rewriter.getType(getPtrToElementType( + deallocOp.getOperand()->getType().cast(), lowering)); + + // Emit MLIR for buffer_dealloc. + edsc::ScopedContext context(rewriter, op->getLoc()); + Value *casted = + bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0], + makePositionAttr(rewriter, 0))); + call(ArrayRef(), rewriter.getFunctionAttr(freeFunc), casted); + + return {}; + } +}; + // BufferSizeOp creates a new `index` value. class BufferSizeOpConversion : public LLVMOpLowering { public: @@ -155,10 +262,62 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { - auto bufferSizeType = lowering.convertType(operands[0]->getType()); + auto int64Ty = lowering.convertType(operands[0]->getType()); edsc::ScopedContext context(rewriter, op->getLoc()); - return {extractvalue(bufferSizeType, operands[0], - makePositionAttr(rewriter, 1))}; + return {extractvalue(int64Ty, operands[0], makePositionAttr(rewriter, 1))}; + } +}; + +namespace { +// Common functionality for Linalg LoadOp and StoreOp conversion to the +// LLVM IR Dialect. +template class LoadStoreOpConversion : public LLVMOpLowering { +public: + explicit LoadStoreOpConversion(MLIRContext *context, LLVMLowering &lowering_) + : LLVMOpLowering(Op::getOperationName(), context, lowering_) {} + using Base = LoadStoreOpConversion; + + // Compute the pointer to an element of the buffer underlying the view given + // current view indices. Use the base offset and strides stored in the view + // descriptor to emit IR iteratively computing the actual offset, followed by + // a getelementptr. This must be called under an edsc::ScopedContext. + Value *obtainDataPtr(Operation *op, Value *viewDescriptor, + ArrayRef indices, FuncBuilder &rewriter) const { + auto loadOp = op->cast(); + auto elementTy = rewriter.getType( + getPtrToElementType(loadOp.getViewType(), lowering)); + auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); + auto pos = [&rewriter](ArrayRef values) { + return makePositionAttr(rewriter, values); + }; + + // Linearize subscripts as: + // base_offset + SUM_i index_i * stride_i. + Value *base = extractvalue(elementTy, viewDescriptor, pos(0)); + Value *offset = extractvalue(int64Ty, viewDescriptor, pos(1)); + for (int i = 0, e = loadOp.getRank(); i < e; ++i) { + Value *stride = extractvalue(int64Ty, viewDescriptor, pos({3, i})); + Value *additionalOffset = mul(indices[i], stride); + offset = add(offset, additionalOffset); + } + return gep(elementTy, base, offset); + } +}; +} // namespace + +// A load is converted into the actual address computation, getelementptr and +// an LLVM IR load. +class LoadOpConversion : public LoadStoreOpConversion { + using Base::Base; + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + edsc::ScopedContext edscContext(rewriter, op->getLoc()); + auto elementTy = lowering.convertType(*op->getResultTypes().begin()); + Value *viewDescriptor = operands[0]; + ArrayRef indices = operands.drop_front(); + auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); + Value *element = llvm_load(elementTy, ptr); + return {element}; } }; @@ -171,18 +330,18 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto rangeOp = op->cast(); - auto rangeDescriptorType = + auto rangeDescriptorTy = convertLinalgType(rangeOp.getResult()->getType(), lowering); edsc::ScopedContext context(rewriter, op->getLoc()); // Fill in an aggregate value of the descriptor. - Value *desc = undef(rangeDescriptorType); - desc = insertvalue(rangeDescriptorType, desc, operands[0], + Value *desc = undef(rangeDescriptorTy); + desc = insertvalue(rangeDescriptorTy, desc, operands[0], makePositionAttr(rewriter, 0)); - desc = insertvalue(rangeDescriptorType, desc, operands[1], + desc = insertvalue(rangeDescriptorTy, desc, operands[1], makePositionAttr(rewriter, 1)); - desc = insertvalue(rangeDescriptorType, desc, operands[2], + desc = insertvalue(rangeDescriptorTy, desc, operands[2], makePositionAttr(rewriter, 2)); return {desc}; @@ -197,8 +356,7 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto sliceOp = op->cast(); - auto viewDescriptorType = - convertLinalgType(sliceOp.getViewType(), lowering); + auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering); auto viewType = sliceOp.getBaseViewType(); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); @@ -217,8 +375,8 @@ public: edsc::ScopedContext context(rewriter, op->getLoc()); // Declare the view descriptor and insert data ptr. - Value *desc = undef(viewDescriptorType); - desc = insertvalue(viewDescriptorType, desc, + Value *desc = undef(viewDescriptorTy); + desc = insertvalue(viewDescriptorTy, desc, getViewPtr(viewType, operands[0]), pos(0)); // TODO(ntv): extract sizes and emit asserts. @@ -238,7 +396,7 @@ public: Value *product = mul(min, strides[j]); baseOffset = add(baseOffset, product); } - desc = insertvalue(viewDescriptorType, desc, baseOffset, pos(1)); + desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1)); // Compute and insert view sizes (max - min along the range). Skip the // non-range operands as they will be projected away from the view. @@ -252,7 +410,7 @@ public: Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1)); Value *size = sub(max, min); - desc = insertvalue(viewDescriptorType, desc, size, pos({2, i})); + desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i})); ++i; } @@ -264,7 +422,7 @@ public: continue; Value *step = extractvalue(int64Ty, operands[1 + j], pos(2)); Value *stride = mul(strides[j], step); - desc = insertvalue(viewDescriptorType, desc, stride, pos({3, i})); + desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i})); ++i; } @@ -272,6 +430,22 @@ public: } }; +// A store is converted into the actual address computation, getelementptr and +// an LLVM IR store. +class StoreOpConversion : public LoadStoreOpConversion { + using Base::Base; + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + edsc::ScopedContext edscContext(rewriter, op->getLoc()); + Value *data = operands[0]; + Value *viewDescriptor = operands[1]; + ArrayRef indices = operands.drop_front(2); + Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); + llvm_store(data, ptr); + return {}; + } +}; + class ViewOpConversion : public LLVMOpLowering { public: explicit ViewOpConversion(MLIRContext *context, LLVMLowering &lowering_) @@ -280,8 +454,8 @@ public: SmallVector rewrite(Operation *op, ArrayRef operands, FuncBuilder &rewriter) const override { auto viewOp = op->cast(); - auto viewDescriptorType = convertLinalgType(viewOp.getViewType(), lowering); - auto elementType = rewriter.getType( + auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering); + auto elementTy = rewriter.getType( getPtrToElementType(viewOp.getViewType(), lowering)); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); @@ -294,16 +468,16 @@ public: // Declare the descriptor of the view. edsc::ScopedContext context(rewriter, op->getLoc()); - Value *desc = undef(viewDescriptorType); + Value *desc = undef(viewDescriptorTy); // Copy the buffer pointer from the old descriptor to the new one. - Value *buffer = extractvalue(elementType, bufferDescriptor, pos(0)); - desc = insertvalue(viewDescriptorType, desc, buffer, pos(0)); + Value *buffer = extractvalue(elementTy, bufferDescriptor, pos(0)); + desc = insertvalue(viewDescriptorTy, desc, buffer, pos(0)); // Zero base offset. auto indexTy = rewriter.getIndexType(); Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0)); - desc = insertvalue(viewDescriptorType, desc, baseOffset, pos(1)); + desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1)); // Compute and insert view sizes (max - min along the range). int numIndexings = llvm::size(viewOp.getIndexings()); @@ -313,12 +487,12 @@ public: Value *rangeDescriptor = operands[1 + i]; Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2)); Value *stride = mul(runningStride, step); - desc = insertvalue(viewDescriptorType, desc, stride, pos({3, i})); + desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i})); // Update size. Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0)); Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1)); Value *size = sub(max, min); - desc = insertvalue(viewDescriptorType, desc, size, pos({2, i})); + desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i})); ++i; // Update stride for the next dimension. if (i < numIndexings - 1) @@ -346,7 +520,8 @@ public: "in lowering to LLVM "); auto fAttr = rewriter.getFunctionAttr(f); auto named = rewriter.getNamedAttr("callee", fAttr); - rewriter.create(op->getLoc(), operands, ArrayRef{named}); + rewriter.create(op->getLoc(), operands, + ArrayRef{named}); return {}; } }; @@ -357,10 +532,11 @@ class Lowering : public LLVMLowering { protected: llvm::DenseSet initAdditionalConverters() override { return ConversionListBuilder< - BufferSizeOpConversion, DotOpConversion, RangeOpConversion, - SliceOpConversion, ViewOpConversion>::build(&converterStorage, - llvmDialect->getContext(), - *this); + BufferAllocOpConversion, BufferDeallocOpConversion, + BufferSizeOpConversion, DotOpConversion, LoadOpConversion, + RangeOpConversion, SliceOpConversion, StoreOpConversion, + ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(), + *this); } Type convertAdditionalType(Type t) override { @@ -378,13 +554,17 @@ struct LowerLinalgToLLVMPass : public ModulePass { void LowerLinalgToLLVMPass::runOnModule() { auto &module = getModule(); + PassManager pm; + pm.addPass(createLowerAffinePass()); + if (failed(pm.run(&module))) + signalPassFailure(); + // Convert to the LLVM IR dialect using the converter defined above. - auto r = Lowering().convert(&module); - if (failed(r)) + if (failed(Lowering().convert(&module))) signalPassFailure(); } -ModulePassBase *createLowerLinalgToLLVMPass() { +ModulePassBase *mlir::createLowerLinalgToLLVMPass() { return new LowerLinalgToLLVMPass(); } diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index 48ddb3d..434f720 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -37,6 +37,7 @@ using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; using namespace llvm; static llvm::cl::OptionCategory clOptionsCategory("linalg options"); diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp index 0052ef0..4b77ece 100644 --- a/mlir/lib/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Linalg/Utils/Utils.cpp @@ -33,6 +33,7 @@ using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; using namespace llvm; mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder( diff --git a/mlir/test/mlir-cpu-runner/simple_linalg.mlir b/mlir/test/mlir-cpu-runner/simple_linalg.mlir new file mode 100644 index 0000000..119cea6 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/simple_linalg.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e entry1 -entry-point-result=f32 | FileCheck %s + +func @linalg_dot(!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, + !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, + !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) { + return +} + +func @dot(%arg0: !linalg.buffer, %arg1: !linalg.buffer, %arg2: !linalg.buffer) -> f32 { + %c0 = constant 0 : index + %c1 = constant 1 : index + %s = linalg.buffer_size %arg0 : !linalg.buffer + %R = linalg.range %c0:%s:%c1 : !linalg.range + %A = linalg.view %arg0[%R] : !linalg.view + %B = linalg.view %arg1[%R] : !linalg.view + %C = linalg.view %arg2[] : !linalg.view + linalg.dot(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view + %res = linalg.load %C[] : !linalg.view + return %res : f32 +} + +func @fill_f32(%arg0 : !linalg.buffer, %f : f32) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %s = linalg.buffer_size %arg0 : !linalg.buffer + %R = linalg.range %c0:%s:%c1 : !linalg.range + %V = linalg.view %arg0[%R] : !linalg.view + affine.for %i0 = 0 to %s { + linalg.store %f, %V[%i0] : !linalg.view + } + return +} + +func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer { + %A = linalg.buffer_alloc %s : !linalg.buffer + call @fill_f32(%A, %f) : (!linalg.buffer, f32) -> () + return %A : !linalg.buffer +} + +func @entry1() -> f32 { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c16 = constant 16 : index + %f0 = constant 0.00000e+00 : f32 + %f1 = constant 0.00000e+00 : f32 + %f2 = constant 2.00000e+00 : f32 + + %A = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer) + %B = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer) + %C = call @alloc_filled_f32(%c1, %f0) : (index, f32) -> (!linalg.buffer) + %res = call @dot(%A, %B, %C) : (!linalg.buffer, !linalg.buffer, !linalg.buffer) -> (f32) + linalg.buffer_dealloc %C : !linalg.buffer + linalg.buffer_dealloc %B : !linalg.buffer + linalg.buffer_dealloc %A : !linalg.buffer + return %res : f32 +} + +// CHECK: 0.{{0+}}e+00 \ No newline at end of file diff --git a/mlir/tools/mlir-cpu-runner/CMakeLists.txt b/mlir/tools/mlir-cpu-runner/CMakeLists.txt index eff9409..844e8db 100644 --- a/mlir/tools/mlir-cpu-runner/CMakeLists.txt +++ b/mlir/tools/mlir-cpu-runner/CMakeLists.txt @@ -4,6 +4,7 @@ set(LIBS MLIREDSC MLIRExecutionEngine MLIRIR + MLIRLLVMIR MLIRParser MLIRTargetLLVMIR MLIRTransforms diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp index 5d65886..5deadb0 100644 --- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp +++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp @@ -26,6 +26,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Parser.h" #include "mlir/Support/FileUtilities.h" @@ -56,6 +57,10 @@ static llvm::cl::opt mainFuncName("e", llvm::cl::desc("The function to be called"), llvm::cl::value_desc(""), llvm::cl::init("main")); +static llvm::cl::opt mainFuncType( + "entry-point-result", + llvm::cl::desc("Textual description of the function type to be called"), + llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs")); static llvm::cl::OptionCategory optFlags("opt-like flags"); @@ -129,9 +134,9 @@ static void printMemRefArguments(ArrayRef argTypes, } } -static Error -compileAndExecute(Module *module, StringRef entryPoint, - std::function transformer) { +static Error compileAndExecuteFunctionWithMemRefs( + Module *module, StringRef entryPoint, + std::function transformer) { Function *mainFunction = module->getNamedFunction(entryPoint); if (!mainFunction || mainFunction->getBlocks().empty()) { return make_string_error("entry point not found"); @@ -167,6 +172,50 @@ compileAndExecute(Module *module, StringRef entryPoint, return Error::success(); } +static Error compileAndExecuteSingleFloatReturnFunction( + Module *module, StringRef entryPoint, + std::function transformer) { + Function *mainFunction = module->getNamedFunction(entryPoint); + if (!mainFunction || mainFunction->isExternal()) { + return make_string_error("entry point not found"); + } + + if (!mainFunction->getType().getInputs().empty()) + return make_string_error("function inputs not supported"); + + if (mainFunction->getType().getResults().size() != 1) + return make_string_error("only single f32 function result supported"); + + auto t = mainFunction->getType().getResults()[0].dyn_cast(); + if (!t) + return make_string_error("only single llvm.f32 function result supported"); + auto *llvmTy = t.getUnderlyingType(); + if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext())) + return make_string_error("only single llvm.f32 function result supported"); + + auto expectedEngine = mlir::ExecutionEngine::create(module, transformer); + if (!expectedEngine) + return expectedEngine.takeError(); + + auto engine = std::move(*expectedEngine); + auto expectedFPtr = engine->lookup(entryPoint); + if (!expectedFPtr) + return expectedFPtr.takeError(); + void (*fptr)(void **) = *expectedFPtr; + + float res; + struct { + void *data; + } data; + data.data = &res; + (*fptr)((void **)&data); + + // Intentional printing of the output so we can test. + llvm::outs() << res; + + return Error::success(); +} + int main(int argc, char **argv) { llvm::PrettyStackTraceProgram x(argc, argv); llvm::InitLLVM y(argc, argv); @@ -212,7 +261,11 @@ int main(int argc, char **argv) { auto transformer = mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition); - auto error = compileAndExecute(m.get(), mainFuncName.getValue(), transformer); + auto error = mainFuncType.getValue() == "f32" + ? compileAndExecuteSingleFloatReturnFunction( + m.get(), mainFuncName.getValue(), transformer) + : compileAndExecuteFunctionWithMemRefs( + m.get(), mainFuncName.getValue(), transformer); int exitCode = EXIT_SUCCESS; llvm::handleAllErrors(std::move(error), [&exitCode](const llvm::ErrorInfoBase &info) { -- 2.7.4