#include "mlir/Support/LLVM.h"
namespace mlir {
+namespace linalg {
/// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
/// upon which a base view can be laid out to give it indexing semantics.
}
};
+/// A linalg.LoadOp is the counterpart of load but operating on ViewType
+/// instead of MemRefType.
+///
+/// ```{.mlir}
+/// %0 = linalg.load %V[%c0] : !linalg.view<?xf32>
+/// ```
+class LoadOp
+ : public Op<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
+public:
+ friend Operation;
+ using Op::Op;
+
+ // Hooks to customize the behavior of this op.
+ static llvm::StringRef getOperationName() { return "linalg.load"; }
+ static void build(Builder *b, OperationState *result, Value *view,
+ ArrayRef<Value *> indices = {});
+ LogicalResult verify();
+ static ParseResult parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p);
+
+ // Op-specific functionality.
+ unsigned getRank() { return getViewType().getRank(); }
+ ViewType getViewType() { return getView()->getType().cast<ViewType>(); }
+ Value *getView() { return getOperand(0); }
+ Operation::operand_range getIndices() {
+ return {operand_begin() + 1, operand_end()};
+ }
+};
+
/// The "linalg.range" op creates a linalg.range from 3 values of type `index`
/// that represent the min, max and step values of the range.
///
/// !linalg.view<f32>
/// ```
class ViewOp;
-class SliceOp : public mlir::Op<SliceOp, mlir::OpTrait::VariadicOperands,
- mlir::OpTrait::OneResult,
- mlir::OpTrait::HasNoSideEffect> {
+class SliceOp : public Op<SliceOp, OpTrait::VariadicOperands,
+ OpTrait::OneResult, OpTrait::HasNoSideEffect> {
enum { FirstIndexingOperand = 1 };
public:
// Hooks to customize the behavior of this op.
static llvm::StringRef getOperationName() { return "linalg.slice"; }
- static void build(mlir::Builder *b, mlir::OperationState *result,
- mlir::Value *base, llvm::ArrayRef<mlir::Value *> indexings);
- mlir::LogicalResult verify();
- static ParseResult parse(mlir::OpAsmParser *parser,
- mlir::OperationState *result);
- void print(mlir::OpAsmPrinter *p);
+ static void build(Builder *b, OperationState *result, Value *base,
+ llvm::ArrayRef<Value *> indexings);
+ LogicalResult verify();
+ static ParseResult parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p);
// Op-specific functionality.
unsigned getRank() { return getViewType().getRank(); }
- mlir::Type getElementType() { return getViewType().getElementType(); }
+ Type getElementType() { return getViewType().getElementType(); }
ViewType getViewType() { return getType().cast<ViewType>(); }
Value *getBaseView() { return getOperand(0); }
ViewOp getBaseViewOp();
ViewType getBaseViewType();
unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
// Get the underlying indexing at a given rank.
- mlir::Value *getIndexing(unsigned rank) {
- return *(getIndexings().begin() + rank);
- }
+ Value *getIndexing(unsigned rank) { return *(getIndexings().begin() + rank); }
// Get all the indexings in this view.
- mlir::Operation::operand_range getIndexings() {
+ Operation::operand_range getIndexings() {
return {operand_begin() + SliceOp::FirstIndexingOperand, operand_end()};
}
// Get the subset of indexings that are of RangeType.
SmallVector<Value *, 8> getRanges();
};
+/// A linalg.StoreOp is the counterpart of affine.store but operating on
+/// ViewType instead of MemRefType.
+///
+/// ```{.mlir}
+/// linalg.store %f, %V[%c0] : !linalg.view<?xf32>
+/// ```
+class StoreOp
+ : public Op<StoreOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+public:
+ friend Operation;
+ using Op::Op;
+
+ // Hooks to customize the behavior of this op.
+ static llvm::StringRef getOperationName() { return "linalg.store"; }
+ static void build(Builder *b, OperationState *result, Value *valueToStore,
+ Value *view, ArrayRef<Value *> indices = {});
+ LogicalResult verify();
+ static ParseResult parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p);
+
+ // Op-specific functionality.
+ unsigned getRank() { return getViewType().getRank(); }
+ ViewType getViewType() { return getView()->getType().cast<ViewType>(); }
+ Value *getValueToStore() { return getOperand(0); }
+ Value *getView() { return getOperand(1); }
+ Operation::operand_range getIndices() {
+ return {operand_begin() + 2, operand_end()};
+ }
+};
+
/// The "linalg.view" op produces a linalg.view which is a multi-dimensional
/// range abstraction on top of an underlying linalg.buffer. This gives an
/// indexing structure to an otherwise non-indexable linalg.buffer.
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
/// %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
/// ```
-class ViewOp : public mlir::Op<ViewOp, mlir::OpTrait::VariadicOperands,
- mlir::OpTrait::OneResult,
- mlir::OpTrait::HasNoSideEffect> {
+class ViewOp : public Op<ViewOp, OpTrait::VariadicOperands, OpTrait::OneResult,
+ OpTrait::HasNoSideEffect> {
enum { FirstIndexingOperand = 1 };
public:
// Hooks to customize the behavior of this op.
static llvm::StringRef getOperationName() { return "linalg.view"; }
- static void build(mlir::Builder *b, mlir::OperationState *result,
- mlir::Value *buffer,
- llvm::ArrayRef<mlir::Value *> indexings);
- mlir::LogicalResult verify();
- static ParseResult parse(mlir::OpAsmParser *parser,
- mlir::OperationState *result);
- void print(mlir::OpAsmPrinter *p);
+ static void build(Builder *b, OperationState *result, Value *buffer,
+ llvm::ArrayRef<Value *> indexings);
+ LogicalResult verify();
+ static ParseResult parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p);
// Op-specific functionality.
unsigned getRank() { return getViewType().getRank(); }
- mlir::Type getElementType() { return getViewType().getElementType(); }
+ Type getElementType() { return getViewType().getElementType(); }
ViewType getViewType() { return getType().cast<ViewType>(); }
- mlir::Value *getSupportingBuffer() { return getOperand(0); }
+ Value *getSupportingBuffer() { return getOperand(0); }
// Get the underlying indexing at a given rank.
- mlir::Value *getIndexing(unsigned rank) {
- return *(getIndexings().begin() + rank);
- }
+ Value *getIndexing(unsigned rank) { return *(getIndexings().begin() + rank); }
// Get all the indexings in this view.
- mlir::Operation::operand_range getIndexings() {
+ Operation::operand_range getIndexings() {
return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
}
};
/// )
/// ```
///
-/// Only permutation maps are currently supported.
+/// Only permutation maps are currently supported.
SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
+} // namespace linalg
} // namespace mlir
#endif // MLIR_LINALG_LINALGOPS_H_
def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
def View : Type<LinalgIsViewTypePred, "view">;
-class ParametricNativeOpTrait<string prop, string parameters> :
- NativeOpTrait<prop # parameters>
+class LinalgParametricNativeOpTrait<string prop, string parameters> :
+ NativeOpTrait<"linalg::" # prop # parameters>
{}
-class ParametricIntNativeOpTrait<string prop, list<int> parameters> :
- ParametricNativeOpTrait<
+class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
+ LinalgParametricNativeOpTrait<
prop,
!strconcat("<",
!cast<string>(!head(parameters)),
// to have a specified number of inputs and outputs, all passed as operands.
// See Linalg/LinalgTraits.h for implementation details an usage.
class NInputsAndOutputs<int n_ins, int n_outs> :
- ParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
+ LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
{}
// The linalg `NLoopTypes` trait provides the API for ops that are known to have
// loops.
// See Linalg/LinalgTraits.h for implementation details an usage.
class NLoopTypes<int n_par, int n_red, int n_win> :
-ParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
+LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
{}
// The linalg `ViewRanks` trait the API for ops that are known to have a
// specified list of view ranks.
// See Linalg/LinalgTraits.h for implementation details an usage.
class ViewRanks<list<int> ranks> :
-ParametricIntNativeOpTrait<"ViewRanks", ranks>
+LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
{}
// Base Tablegen class for Linalg ops.
namespace mlir {
namespace OpTrait {
+namespace linalg {
/// This class provides the API for ops that are known to have a specified
/// number of inputs and outputs, all passed as operands. This is used as a
Value *getOutput(unsigned i) {
return this->getOperand(getNumInputs() + i);
}
- ViewType getInputViewType(unsigned i) {
- return this->getOperand(i)->getType().template cast<ViewType>();
+ mlir::linalg::ViewType getInputViewType(unsigned i) {
+ return this->getOperand(i)
+ ->getType()
+ .template cast<mlir::linalg::ViewType>();
}
- ViewType getOutputViewType(unsigned i) {
+ mlir::linalg::ViewType getOutputViewType(unsigned i) {
return this->getOperand(getNumInputs() + i)
->getType()
- .template cast<ViewType>();
+ .template cast<mlir::linalg::ViewType>();
}
- ViewType getViewType(unsigned i) {
- return this->getOperand(i)->getType().template cast<ViewType>();
+ mlir::linalg::ViewType getViewType(unsigned i) {
+ return this->getOperand(i)
+ ->getType()
+ .template cast<mlir::linalg::ViewType>();
}
static LogicalResult verifyTrait(Operation *op) {
return OpTrait::impl::verifyAtLeastNOperands(op, NInputs + NOutputs);
if (op->getNumOperands() != ranks.size())
return op->emitError("expected " + Twine(ranks.size()) + " operands");
for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) {
- auto viewType = op->getOperand(i)->getType().dyn_cast<ViewType>();
+ auto viewType =
+ op->getOperand(i)->getType().dyn_cast<mlir::linalg::ViewType>();
if (!viewType)
return op->emitOpError("operand " + Twine(i) +
" must have view type ");
};
};
+} // namespace linalg
} // namespace OpTrait
} // namespace mlir
namespace mlir {
class MLIRContext;
+namespace linalg {
enum LinalgTypes {
Buffer = Type::FIRST_LINALG_TYPE,
Range,
unsigned getRank();
};
+} // namespace linalg
} // namespace mlir
#endif // MLIR_LINALG_LINALGTYPES_H_
mlir::ModulePassBase *
createLinalgTilingPass(llvm::ArrayRef<int64_t> tileSizes = {});
+
+mlir::ModulePassBase *createLowerLinalgToLLVMPass();
} // namespace mlir
#endif // MLIR_LINALG_PASSES_H_
#include "mlir/Support/STLExtras.h"
using namespace mlir;
+using namespace mlir::linalg;
//////////////////////////////////////////////////////////////////////////////
// BufferAllocOp
//////////////////////////////////////////////////////////////////////////////
-void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type,
- Value *size) {
+void mlir::linalg::BufferAllocOp::build(Builder *b, OperationState *result,
+ Type type, Value *size) {
result->addOperands({size});
result->addTypes(type);
}
-mlir::LogicalResult mlir::BufferAllocOp::verify() {
+LogicalResult mlir::linalg::BufferAllocOp::verify() {
if (!size() || !size()->getType().isa<IndexType>())
return emitOpError("first operand should be of type index");
if (!VectorType::isValidElementType(getElementType()) &&
!getElementType().isa<VectorType>())
return emitOpError("unsupported buffer element type");
- return mlir::success();
+ return success();
}
// A BufferAllocOp prints as:
// ```{.mlir}
// linalg.alloc %0 : !linalg.buffer<f32>
// ```
-void mlir::BufferAllocOp::print(OpAsmPrinter *p) {
+void mlir::linalg::BufferAllocOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *size() << " : " << getType();
}
-ParseResult mlir::BufferAllocOp::parse(OpAsmParser *parser,
- OperationState *result) {
+ParseResult mlir::linalg::BufferAllocOp::parse(OpAsmParser *parser,
+ OperationState *result) {
OpAsmParser::OperandType sizeInfo;
BufferType bufferType;
auto indexTy = parser->getBuilder().getIndexType();
if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
return failure();
if (bufferType.getElementType() != parser->getBuilder().getF32Type())
- return parser->emitError(
- parser->getNameLoc(),
- "Only buffer<f32> supported until mlir::Parser pieces are exposed");
+ return parser->emitError(parser->getNameLoc(),
+ "Only buffer<f32> supported until "
+ "mlir::linalg::Parser pieces are exposed");
return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
parser->addTypeToList(bufferType, result->types));
}
//////////////////////////////////////////////////////////////////////////////
// BufferDeallocOp
//////////////////////////////////////////////////////////////////////////////
-void mlir::BufferDeallocOp::build(Builder *b, OperationState *result,
- Value *buffer) {
+void mlir::linalg::BufferDeallocOp::build(Builder *b, OperationState *result,
+ Value *buffer) {
result->addOperands({buffer});
}
-mlir::LogicalResult mlir::BufferDeallocOp::verify() {
+LogicalResult mlir::linalg::BufferDeallocOp::verify() {
if (!getBuffer()->getType())
return emitOpError("first operand should be of type buffer");
- return mlir::success();
+ return success();
}
// A BufferDeallocOp prints as:
// ```{.mlir}
// linalg.dealloc %0 : !linalg.buffer<f32>
// ```
-void mlir::BufferDeallocOp::print(OpAsmPrinter *p) {
+void mlir::linalg::BufferDeallocOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getBuffer() << " : " << getBufferType();
}
-ParseResult mlir::BufferDeallocOp::parse(OpAsmParser *parser,
- OperationState *result) {
+ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser,
+ OperationState *result) {
OpAsmParser::OperandType sizeInfo;
BufferType bufferType;
return failure(
parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
parser->resolveOperands(sizeInfo, bufferType, result->operands));
}
+
+////////////////////////////////////////////////////////////////////////////////
+// LoadOp.
+////////////////////////////////////////////////////////////////////////////////
+void mlir::linalg::LoadOp::build(Builder *b, OperationState *result,
+ Value *view, ArrayRef<Value *> indices) {
+ auto viewType = view->getType().cast<ViewType>();
+ result->addOperands(view);
+ result->addOperands(indices);
+ result->addTypes(viewType.getElementType());
+}
+
+// A LoadOp prints as:
+//
+// ```{.mlir}
+// %0 = linalg.load %V[%c0] : !linalg.view<?xf32>
+// ```
+void mlir::linalg::LoadOp::print(OpAsmPrinter *p) {
+ *p << getOperationName() << " " << *getView() << '[';
+ p->printOperands(getIndices());
+ *p << ']';
+ p->printOptionalAttrDict(getAttrs());
+ *p << " : " << getViewType();
+}
+
+ParseResult mlir::linalg::LoadOp::parse(OpAsmParser *parser,
+ OperationState *result) {
+ OpAsmParser::OperandType viewInfo;
+ SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+ ViewType type;
+
+ auto affineIntTy = parser->getBuilder().getIndexType();
+ return failure(
+ parser->parseOperand(viewInfo) ||
+ parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperand(viewInfo, type, result->operands) ||
+ parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
+ parser->addTypeToList(type.getElementType(), result->types));
+}
+
+LogicalResult mlir::linalg::LoadOp::verify() {
+ if (getNumOperands() == 0)
+ return emitOpError("expected a view to load from");
+
+ auto viewType = getView()->getType().dyn_cast<ViewType>();
+ if (!viewType)
+ return emitOpError("first operand must be a view");
+
+ if (getType() != viewType.getElementType())
+ return emitOpError("result type must match element type of the view");
+
+ if (getRank() != getNumOperands() - 1)
+ return emitOpError("incorrect number of indices for load");
+
+ for (auto *idx : getIndices())
+ if (!idx->getType().isIndex())
+ return emitOpError("index to load must have 'index' type");
+
+ return success();
+}
+
//////////////////////////////////////////////////////////////////////////////
// RangeOp
//////////////////////////////////////////////////////////////////////////////
-void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min,
- Value *max, Value *step) {
+void mlir::linalg::RangeOp::build(Builder *b, OperationState *result,
+ Value *min, Value *max, Value *step) {
result->addOperands({min, max, step});
result->addTypes({RangeType::get(b->getContext())});
}
// Verification is simply that a RangeOp takes 3 index ssa-value.
-mlir::LogicalResult mlir::RangeOp::verify() {
+LogicalResult mlir::linalg::RangeOp::verify() {
if (!min() || !min()->getType().isa<IndexType>())
return emitOpError("first operand should be of type index");
if (!max() || !max()->getType().isa<IndexType>())
return emitOpError("second operand should be of type index");
if (!step() || !step()->getType().isa<IndexType>())
return emitOpError("third operand should be of type index");
- return mlir::success();
+ return success();
}
// A RangeOp prints as:
// ```{.mlir}
// linalg.range %0:%1:%2 : !linalg.range
// ```
-void mlir::RangeOp::print(OpAsmPrinter *p) {
+void mlir::linalg::RangeOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step()
<< " : " << getType();
}
-ParseResult mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
+ParseResult mlir::linalg::RangeOp::parse(OpAsmParser *parser,
+ OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
RangeType type;
auto affineIntTy = parser->getBuilder().getIndexType();
//////////////////////////////////////////////////////////////////////////////
// SliceOp
//////////////////////////////////////////////////////////////////////////////
-void mlir::SliceOp::build(Builder *b, OperationState *result, Value *base,
- ArrayRef<Value *> indexings) {
+void mlir::linalg::SliceOp::build(Builder *b, OperationState *result,
+ Value *base, ArrayRef<Value *> indexings) {
result->addOperands({base});
result->addOperands(indexings);
result->addTypes({ViewType::get(b->getContext(), elementType, rank)});
}
-LogicalResult mlir::SliceOp::verify() {
+LogicalResult mlir::linalg::SliceOp::verify() {
if (llvm::empty(getOperands()))
return emitOpError(
"requires at least a view operand followed by 'rank' indices");
return success();
}
-ParseResult mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
+ParseResult mlir::linalg::SliceOp::parse(OpAsmParser *parser,
+ OperationState *result) {
OpAsmParser::OperandType baseInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
SmallVector<Type, 8> types;
//
// Where %0 is an ssa-value holding a view created from a buffer, %1 and %2 are
// ssa-value each holding a range.
-void mlir::SliceOp::print(OpAsmPrinter *p) {
+void mlir::linalg::SliceOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getBaseView() << "[";
interleave(
- getIndexings().begin(), getIndexings().end(),
- [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
+ getIndexings().begin(), getIndexings().end(), [p](Value *v) { *p << *v; },
+ [p]() { *p << ", "; });
*p << "] : " << getBaseViewType();
for (auto indexing : getIndexings()) {
*p << ", " << indexing->getType();
*p << ", " << getType();
}
-ViewOp mlir::SliceOp::getBaseViewOp() {
+ViewOp mlir::linalg::SliceOp::getBaseViewOp() {
return getOperand(0)->getDefiningOp()->cast<ViewOp>();
}
-ViewType mlir::SliceOp::getBaseViewType() {
+ViewType mlir::linalg::SliceOp::getBaseViewType() {
return getBaseViewOp().getType().cast<ViewType>();
}
-SmallVector<Value *, 8> mlir::SliceOp::getRanges() {
+SmallVector<Value *, 8> mlir::linalg::SliceOp::getRanges() {
llvm::SmallVector<Value *, 8> res;
for (auto *operand : getIndexings()) {
if (!operand->getType().isa<IndexType>()) {
return res;
}
+////////////////////////////////////////////////////////////////////////////////
+// StoreOp.
+////////////////////////////////////////////////////////////////////////////////
+void mlir::linalg::StoreOp::build(Builder *b, OperationState *result,
+ Value *valueToStore, Value *view,
+ ArrayRef<Value *> indices) {
+ result->addOperands(valueToStore);
+ result->addOperands(view);
+ result->addOperands(indices);
+}
+
+// A StoreOp prints as:
+//
+// ```{.mlir}
+// linalg.store %f, %V[%c0] : !linalg.view<?xf32>
+// ```
+void mlir::linalg::StoreOp::print(OpAsmPrinter *p) {
+ *p << getOperationName() << " " << *getValueToStore();
+ *p << ", " << *getView() << '[';
+ p->printOperands(getIndices());
+ *p << ']';
+ p->printOptionalAttrDict(getAttrs());
+ *p << " : " << getViewType();
+}
+
+ParseResult mlir::linalg::StoreOp::parse(OpAsmParser *parser,
+ OperationState *result) {
+ OpAsmParser::OperandType storeValueInfo;
+ OpAsmParser::OperandType viewInfo;
+ SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+ ViewType viewType;
+
+ auto affineIntTy = parser->getBuilder().getIndexType();
+ return failure(
+ parser->parseOperand(storeValueInfo) || parser->parseComma() ||
+ parser->parseOperand(viewInfo) ||
+ parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(viewType) ||
+ parser->resolveOperand(storeValueInfo, viewType.getElementType(),
+ result->operands) ||
+ parser->resolveOperand(viewInfo, viewType, result->operands) ||
+ parser->resolveOperands(indexInfo, affineIntTy, result->operands));
+}
+
+LogicalResult mlir::linalg::StoreOp::verify() {
+ if (getNumOperands() < 2)
+ return emitOpError("expected a value to store and a view");
+
+ // Second operand is a memref type.
+ auto viewType = getView()->getType().dyn_cast<ViewType>();
+ if (!viewType)
+ return emitOpError("second operand must be a view");
+
+ // First operand must have same type as memref element type.
+ if (getValueToStore()->getType() != viewType.getElementType())
+ return emitOpError("first operand must have same element type as the view");
+
+ if (getNumOperands() != 2 + viewType.getRank())
+ return emitOpError("store index operand count not equal to view rank");
+
+ for (auto *idx : getIndices())
+ if (!idx->getType().isIndex())
+ return emitOpError("index to store must have 'index' type");
+
+ return success();
+}
+
//////////////////////////////////////////////////////////////////////////////
// ViewOp
//////////////////////////////////////////////////////////////////////////////
-void mlir::ViewOp::build(Builder *b, OperationState *result, Value *buffer,
- ArrayRef<Value *> indexings) {
+void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
+ Value *buffer, ArrayRef<Value *> indexings) {
BufferType bufferType = buffer->getType().cast<BufferType>();
result->addOperands({buffer});
result->addOperands(indexings);
{ViewType::get(b->getContext(), elementType, indexings.size())});
}
-LogicalResult mlir::ViewOp::verify() {
+LogicalResult mlir::linalg::ViewOp::verify() {
if (llvm::empty(getOperands()))
return emitOpError(
"requires at least a buffer operand followed by indexings");
return success();
}
-ParseResult mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
+ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
+ OperationState *result) {
OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
Type type;
//
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
// holding a range.
-void mlir::ViewOp::print(OpAsmPrinter *p) {
+void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getSupportingBuffer() << "[";
interleave(
- getIndexings().begin(), getIndexings().end(),
- [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
+ getIndexings().begin(), getIndexings().end(), [&](Value *v) { *p << *v; },
+ [&]() { *p << ", "; });
*p << "] : " << getType();
}
namespace mlir {
+namespace linalg {
namespace impl {
-void printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op);
+void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op);
ParseResult parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result);
-void printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op);
+void printBufferSizeOp(OpAsmPrinter *p, Operation *op);
ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
} // namespace impl
+} // namespace linalg
/// Buffer size prints as:
///
/// ``` {.mlir}
/// %0 = linalg.buffer_size %arg0 : !linalg.buffer<f32>
/// ```
-void mlir::impl::printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op) {
+void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) {
assert(op->getAbstractOperation() && "unregistered operation");
*p << op->cast<BufferSizeOp>().getOperationName() << " "
<< *op->getOperand(0);
*p << " : " << op->getOperand(0)->getType();
}
-ParseResult mlir::impl::parseBufferSizeOp(OpAsmParser *parser,
- OperationState *result) {
+ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
+ OperationState *result) {
OpAsmParser::OperandType op;
Type type;
return failure(parser->parseOperand(op) ||
// ```
//
// Where %0, %1 and %2 are ssa-values of type ViewType.
-void mlir::impl::printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op) {
+void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
assert(op->getAbstractOperation() && "unregistered operation");
*p << op->getName().getStringRef() << "(";
interleave(
op->getOperands().begin(), op->getOperands().end(),
- [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
+ [&](Value *v) { *p << *v; }, [&]() { *p << ", "; });
*p << ") : ";
interleave(
op->getOperands().begin(), op->getOperands().end(),
- [&](mlir::Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
+ [&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
}
-ParseResult mlir::impl::parseLinalgLibraryOp(OpAsmParser *parser,
- OperationState *result) {
+ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser,
+ OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<Type, 3> types;
return failure(
// Ideally this should all be Tablegen'd but there is no good story for
// AffineMap for now.
-SmallVector<AffineMap, 4> mlir::loopToOperandRangesMaps(Operation *op) {
+SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
MLIRContext *context = op->getContext();
auto i = getAffineDimExpr(0, context);
auto j = getAffineDimExpr(1, context);
#include "mlir/Support/LLVM.h"
using namespace mlir;
+using namespace mlir::linalg;
-mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
+mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect("linalg", context) {
addTypes<BufferType, RangeType, ViewType>();
- addOperations<BufferAllocOp, BufferDeallocOp, RangeOp, SliceOp, ViewOp>();
+ addOperations<BufferAllocOp, BufferDeallocOp, LoadOp, RangeOp, StoreOp,
+ SliceOp, ViewOp>();
addOperations<
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
>();
}
-struct mlir::BufferTypeStorage : public mlir::TypeStorage {
+struct mlir::linalg::BufferTypeStorage : public TypeStorage {
/// Underlying Key type to transport the payload needed to construct a custom
/// type in a generic way.
struct Key {
Type elementType;
};
-BufferType mlir::BufferType::get(MLIRContext *context, Type elementType) {
+BufferType mlir::linalg::BufferType::get(MLIRContext *context,
+ Type elementType) {
return Base::get(context, LinalgTypes::Buffer, elementType);
}
-Type mlir::BufferType::getElementType() { return getImpl()->getElementType(); }
+Type mlir::linalg::BufferType::getElementType() {
+ return getImpl()->getElementType();
+}
-Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const {
+Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
+ Location loc) const {
MLIRContext *context = getContext();
if (spec == "range")
return RangeType::get(getContext());
return (context->emitError(loc, "unknown Linalg type: " + spec), Type());
}
-struct mlir::ViewTypeStorage : public mlir::TypeStorage {
+struct mlir::linalg::ViewTypeStorage : public TypeStorage {
/// Underlying Key type to transport the payload needed to construct a custom
/// type in a generic way.
struct Key {
unsigned rank;
};
-ViewType mlir::ViewType::get(MLIRContext *context, Type elementType,
- unsigned rank) {
+ViewType mlir::linalg::ViewType::get(MLIRContext *context, Type elementType,
+ unsigned rank) {
return Base::get(context, LinalgTypes::View, elementType, rank);
}
-Type mlir::ViewType::getElementType() { return getImpl()->getElementType(); }
+Type mlir::linalg::ViewType::getElementType() {
+ return getImpl()->getElementType();
+}
-unsigned mlir::ViewType::getRank() { return getImpl()->getRank(); }
+unsigned mlir::linalg::ViewType::getRank() { return getImpl()->getRank(); }
/// BufferType prints as "buffer<element_type>".
static void print(BufferType bt, raw_ostream &os) {
/// ```
///
/// for 0-D views (a.k.a pointer to a scalar value).
-static void print(mlir::ViewType rt, raw_ostream &os) {
+static void print(mlir::linalg::ViewType rt, raw_ostream &os) {
os << "view<";
for (unsigned i = 0, e = rt.getRank(); i < e; ++i) {
os << "?x";
os << ">";
}
-void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const {
+void mlir::linalg::LinalgDialect::printType(Type type, raw_ostream &os) const {
switch (type.getKind()) {
default:
llvm_unreachable("Unhandled Linalg type");
#include "mlir/Linalg/IR/LinalgTypes.h"
using namespace mlir;
+using namespace mlir::linalg;
// Static initialization for LinalgOps dialect registration.
static DialectRegistration<LinalgDialect> LinalgOps;
#include "mlir/LLVMIR/Transforms.h"
#include "mlir/Linalg/IR/LinalgOps.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Linalg/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::LLVM;
+using namespace mlir::linalg;
using undef = ValueBuilder<mlir::LLVM::UndefOp>;
using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
using add = ValueBuilder<mlir::LLVM::AddOp>;
using sub = ValueBuilder<mlir::LLVM::SubOp>;
using mul = ValueBuilder<mlir::LLVM::MulOp>;
+using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
+using call = OperationBuilder<mlir::LLVM::CallOp>;
+using gep = ValueBuilder<mlir::LLVM::GEPOp>;
+using llvm_load = ValueBuilder<LLVM::LoadOp>;
+using llvm_store = OperationBuilder<LLVM::StoreOp>;
template <typename T>
static llvm::Type *getPtrToElementType(T containerType,
// Elem *ptr;
// int64_t size;
// };
- if (auto bufferTy = t.dyn_cast<BufferType>()) {
- auto *ptrTy = getPtrToElementType(bufferTy, lowering);
+ if (auto bufferType = t.dyn_cast<BufferType>()) {
+ auto *ptrTy = getPtrToElementType(bufferType, lowering);
auto *structTy = llvm::StructType::get(ptrTy, int64Ty);
return LLVMType::get(context, structTy);
}
// int64_t max;
// int64_t step;
// };
- if (auto rangeTy = t.dyn_cast<RangeType>()) {
+ if (t.isa<RangeType>()) {
auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
return LLVMType::get(context, structTy);
}
// int64_t sizes[Rank];
// int64_t strides[Rank];
// };
- if (auto viewTy = t.dyn_cast<ViewType>()) {
- auto *ptrTy = getPtrToElementType(viewTy, lowering);
- auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank());
+ if (auto viewType = t.dyn_cast<ViewType>()) {
+ auto *ptrTy = getPtrToElementType(viewType, lowering);
+ auto *arrayTy = llvm::ArrayType::get(int64Ty, viewType.getRank());
auto *structTy = llvm::StructType::get(ptrTy, int64Ty, arrayTy, arrayTy);
return LLVMType::get(context, structTy);
}
return builder.getArrayAttr(attrs);
}
+// BufferAllocOp creates a new `index` value.
+class BufferAllocOpConversion : public LLVMOpLowering {
+public:
+ explicit BufferAllocOpConversion(MLIRContext *context,
+ LLVMLowering &lowering_)
+ : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}
+
+ SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+ FuncBuilder &rewriter) const override {
+ auto indexType = IndexType::get(op->getContext());
+ auto voidPtrTy = LLVM::LLVMType::get(
+ op->getContext(),
+ lowering.convertType(IntegerType::get(8, op->getContext()))
+ .cast<LLVM::LLVMType>()
+ .getUnderlyingType()
+ ->getPointerTo());
+ auto int64Ty = lowering.convertType(operands[0]->getType());
+ // Insert the `malloc` declaration if it is not already present.
+ Function *mallocFunc =
+ op->getFunction()->getModule()->getNamedFunction("malloc");
+ if (!mallocFunc) {
+ auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
+ mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
+ op->getFunction()->getModule()->getFunctions().push_back(mallocFunc);
+ }
+
+ // Get MLIR types for injecting element pointer.
+ auto allocOp = op->cast<BufferAllocOp>();
+ auto elementType = allocOp.getElementType();
+ uint64_t elementSize = 0;
+ if (auto vectorType = elementType.dyn_cast<VectorType>())
+ elementSize = vectorType.getNumElements() *
+ llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
+ else
+ elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
+ auto elementPtrType = rewriter.getType<LLVMType>(getPtrToElementType(
+ allocOp.getResult()->getType().cast<BufferType>(), lowering));
+ auto bufferDescriptorType =
+ convertLinalgType(allocOp.getResult()->getType(), lowering);
+
+ // Emit IR for creating a new buffer descriptor with an underlying malloc.
+ edsc::ScopedContext context(rewriter, op->getLoc());
+ Value *size = operands[0];
+ Value *allocSize =
+ mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
+ Value *allocated =
+ call(voidPtrTy, rewriter.getFunctionAttr(mallocFunc), allocSize)
+ .getOperation()
+ ->getResult(0);
+ allocated = bitcast(elementPtrType, allocated);
+ Value *desc = undef(bufferDescriptorType);
+ desc = insertvalue(bufferDescriptorType, desc, allocated,
+ makePositionAttr(rewriter, 0));
+ desc = insertvalue(bufferDescriptorType, desc, size,
+ makePositionAttr(rewriter, 1));
+ return {desc};
+ }
+};
+
+// BufferDeallocOp creates a new `index` value.
+class BufferDeallocOpConversion : public LLVMOpLowering {
+public:
+ explicit BufferDeallocOpConversion(MLIRContext *context,
+ LLVMLowering &lowering_)
+ : LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
+ lowering_) {}
+
+ SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+ FuncBuilder &rewriter) const override {
+ auto voidPtrTy = LLVM::LLVMType::get(
+ op->getContext(),
+ lowering.convertType(IntegerType::get(8, op->getContext()))
+ .cast<LLVM::LLVMType>()
+ .getUnderlyingType()
+ ->getPointerTo());
+ // Insert the `free` declaration if it is not already present.
+ Function *freeFunc =
+ op->getFunction()->getModule()->getNamedFunction("free");
+ if (!freeFunc) {
+ auto freeType = rewriter.getFunctionType(voidPtrTy, {});
+ freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
+ op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
+ }
+
+ // Get MLIR types for extracting element pointer.
+ auto deallocOp = op->cast<BufferDeallocOp>();
+ auto elementPtrTy = rewriter.getType<LLVMType>(getPtrToElementType(
+ deallocOp.getOperand()->getType().cast<BufferType>(), lowering));
+
+ // Emit MLIR for buffer_dealloc.
+ edsc::ScopedContext context(rewriter, op->getLoc());
+ Value *casted =
+ bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
+ makePositionAttr(rewriter, 0)));
+ call(ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
+
+ return {};
+ }
+};
+
// BufferSizeOp creates a new `index` value.
class BufferSizeOpConversion : public LLVMOpLowering {
public:
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto bufferSizeType = lowering.convertType(operands[0]->getType());
+ auto int64Ty = lowering.convertType(operands[0]->getType());
edsc::ScopedContext context(rewriter, op->getLoc());
- return {extractvalue(bufferSizeType, operands[0],
- makePositionAttr(rewriter, 1))};
+ return {extractvalue(int64Ty, operands[0], makePositionAttr(rewriter, 1))};
+ }
+};
+
+namespace {
+// Common functionality for Linalg LoadOp and StoreOp conversion to the
+// LLVM IR Dialect.
+template <typename Op> class LoadStoreOpConversion : public LLVMOpLowering {
+public:
+ explicit LoadStoreOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+ : LLVMOpLowering(Op::getOperationName(), context, lowering_) {}
+ using Base = LoadStoreOpConversion<Op>;
+
+ // Compute the pointer to an element of the buffer underlying the view given
+ // current view indices. Use the base offset and strides stored in the view
+ // descriptor to emit IR iteratively computing the actual offset, followed by
+ // a getelementptr. This must be called under an edsc::ScopedContext.
+ Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
+ ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
+ auto loadOp = op->cast<Op>();
+ auto elementTy = rewriter.getType<LLVMType>(
+ getPtrToElementType(loadOp.getViewType(), lowering));
+ auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
+ auto pos = [&rewriter](ArrayRef<int> values) {
+ return makePositionAttr(rewriter, values);
+ };
+
+ // Linearize subscripts as:
+ // base_offset + SUM_i index_i * stride_i.
+ Value *base = extractvalue(elementTy, viewDescriptor, pos(0));
+ Value *offset = extractvalue(int64Ty, viewDescriptor, pos(1));
+ for (int i = 0, e = loadOp.getRank(); i < e; ++i) {
+ Value *stride = extractvalue(int64Ty, viewDescriptor, pos({3, i}));
+ Value *additionalOffset = mul(indices[i], stride);
+ offset = add(offset, additionalOffset);
+ }
+ return gep(elementTy, base, offset);
+ }
+};
+} // namespace
+
+// A load is converted into the actual address computation, getelementptr and
+// an LLVM IR load.
+class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
+ using Base::Base;
+ SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+ FuncBuilder &rewriter) const override {
+ edsc::ScopedContext edscContext(rewriter, op->getLoc());
+ auto elementTy = lowering.convertType(*op->getResultTypes().begin());
+ Value *viewDescriptor = operands[0];
+ ArrayRef<Value *> indices = operands.drop_front();
+ auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
+ Value *element = llvm_load(elementTy, ptr);
+ return {element};
}
};
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
auto rangeOp = op->cast<RangeOp>();
- auto rangeDescriptorType =
+ auto rangeDescriptorTy =
convertLinalgType(rangeOp.getResult()->getType(), lowering);
edsc::ScopedContext context(rewriter, op->getLoc());
// Fill in an aggregate value of the descriptor.
- Value *desc = undef(rangeDescriptorType);
- desc = insertvalue(rangeDescriptorType, desc, operands[0],
+ Value *desc = undef(rangeDescriptorTy);
+ desc = insertvalue(rangeDescriptorTy, desc, operands[0],
makePositionAttr(rewriter, 0));
- desc = insertvalue(rangeDescriptorType, desc, operands[1],
+ desc = insertvalue(rangeDescriptorTy, desc, operands[1],
makePositionAttr(rewriter, 1));
- desc = insertvalue(rangeDescriptorType, desc, operands[2],
+ desc = insertvalue(rangeDescriptorTy, desc, operands[2],
makePositionAttr(rewriter, 2));
return {desc};
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
auto sliceOp = op->cast<SliceOp>();
- auto viewDescriptorType =
- convertLinalgType(sliceOp.getViewType(), lowering);
+ auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
auto viewType = sliceOp.getBaseViewType();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
edsc::ScopedContext context(rewriter, op->getLoc());
// Declare the view descriptor and insert data ptr.
- Value *desc = undef(viewDescriptorType);
- desc = insertvalue(viewDescriptorType, desc,
+ Value *desc = undef(viewDescriptorTy);
+ desc = insertvalue(viewDescriptorTy, desc,
getViewPtr(viewType, operands[0]), pos(0));
// TODO(ntv): extract sizes and emit asserts.
Value *product = mul(min, strides[j]);
baseOffset = add(baseOffset, product);
}
- desc = insertvalue(viewDescriptorType, desc, baseOffset, pos(1));
+ desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
// Compute and insert view sizes (max - min along the range). Skip the
// non-range operands as they will be projected away from the view.
Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
Value *size = sub(max, min);
- desc = insertvalue(viewDescriptorType, desc, size, pos({2, i}));
+ desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
++i;
}
continue;
Value *step = extractvalue(int64Ty, operands[1 + j], pos(2));
Value *stride = mul(strides[j], step);
- desc = insertvalue(viewDescriptorType, desc, stride, pos({3, i}));
+ desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
++i;
}
}
};
+// A store is converted into the actual address computation, getelementptr and
+// an LLVM IR store.
+class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
+ using Base::Base;
+ SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+ FuncBuilder &rewriter) const override {
+ edsc::ScopedContext edscContext(rewriter, op->getLoc());
+ Value *data = operands[0];
+ Value *viewDescriptor = operands[1];
+ ArrayRef<Value *> indices = operands.drop_front(2);
+ Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
+ llvm_store(data, ptr);
+ return {};
+ }
+};
+
class ViewOpConversion : public LLVMOpLowering {
public:
explicit ViewOpConversion(MLIRContext *context, LLVMLowering &lowering_)
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
auto viewOp = op->cast<ViewOp>();
- auto viewDescriptorType = convertLinalgType(viewOp.getViewType(), lowering);
- auto elementType = rewriter.getType<LLVMType>(
+ auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
+ auto elementTy = rewriter.getType<LLVMType>(
getPtrToElementType(viewOp.getViewType(), lowering));
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// Declare the descriptor of the view.
edsc::ScopedContext context(rewriter, op->getLoc());
- Value *desc = undef(viewDescriptorType);
+ Value *desc = undef(viewDescriptorTy);
// Copy the buffer pointer from the old descriptor to the new one.
- Value *buffer = extractvalue(elementType, bufferDescriptor, pos(0));
- desc = insertvalue(viewDescriptorType, desc, buffer, pos(0));
+ Value *buffer = extractvalue(elementTy, bufferDescriptor, pos(0));
+ desc = insertvalue(viewDescriptorTy, desc, buffer, pos(0));
// Zero base offset.
auto indexTy = rewriter.getIndexType();
Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
- desc = insertvalue(viewDescriptorType, desc, baseOffset, pos(1));
+ desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
// Compute and insert view sizes (max - min along the range).
int numIndexings = llvm::size(viewOp.getIndexings());
Value *rangeDescriptor = operands[1 + i];
Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
Value *stride = mul(runningStride, step);
- desc = insertvalue(viewDescriptorType, desc, stride, pos({3, i}));
+ desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
// Update size.
Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
Value *size = sub(max, min);
- desc = insertvalue(viewDescriptorType, desc, size, pos({2, i}));
+ desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
++i;
// Update stride for the next dimension.
if (i < numIndexings - 1)
"in lowering to LLVM ");
auto fAttr = rewriter.getFunctionAttr(f);
auto named = rewriter.getNamedAttr("callee", fAttr);
- rewriter.create<LLVM::CallOp>(op->getLoc(), operands, ArrayRef<NamedAttribute>{named});
+ rewriter.create<LLVM::CallOp>(op->getLoc(), operands,
+ ArrayRef<NamedAttribute>{named});
return {};
}
};
protected:
llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() override {
return ConversionListBuilder<
- BufferSizeOpConversion, DotOpConversion, RangeOpConversion,
- SliceOpConversion, ViewOpConversion>::build(&converterStorage,
- llvmDialect->getContext(),
- *this);
+ BufferAllocOpConversion, BufferDeallocOpConversion,
+ BufferSizeOpConversion, DotOpConversion, LoadOpConversion,
+ RangeOpConversion, SliceOpConversion, StoreOpConversion,
+ ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(),
+ *this);
}
Type convertAdditionalType(Type t) override {
void LowerLinalgToLLVMPass::runOnModule() {
auto &module = getModule();
+ PassManager pm;
+ pm.addPass(createLowerAffinePass());
+ if (failed(pm.run(&module)))
+ signalPassFailure();
+
// Convert to the LLVM IR dialect using the converter defined above.
- auto r = Lowering().convert(&module);
- if (failed(r))
+ if (failed(Lowering().convert(&module)))
signalPassFailure();
}
-ModulePassBase *createLowerLinalgToLLVMPass() {
+ModulePassBase *mlir::createLowerLinalgToLLVMPass() {
return new LowerLinalgToLLVMPass();
}
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
using namespace llvm;
static llvm::cl::OptionCategory clOptionsCategory("linalg options");
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
using namespace llvm;
mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
--- /dev/null
+// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e entry1 -entry-point-result=f32 | FileCheck %s
+
+func @linalg_dot(!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
+ !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
+ !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
+ return
+}
+
+func @dot(%arg0: !linalg.buffer<f32>, %arg1: !linalg.buffer<f32>, %arg2: !linalg.buffer<f32>) -> f32 {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %s = linalg.buffer_size %arg0 : !linalg.buffer<f32>
+ %R = linalg.range %c0:%s:%c1 : !linalg.range
+ %A = linalg.view %arg0[%R] : !linalg.view<?xf32>
+ %B = linalg.view %arg1[%R] : !linalg.view<?xf32>
+ %C = linalg.view %arg2[] : !linalg.view<f32>
+ linalg.dot(%A, %B, %C) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+ %res = linalg.load %C[] : !linalg.view<f32>
+ return %res : f32
+}
+
+func @fill_f32(%arg0 : !linalg.buffer<f32>, %f : f32) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %s = linalg.buffer_size %arg0 : !linalg.buffer<f32>
+ %R = linalg.range %c0:%s:%c1 : !linalg.range
+ %V = linalg.view %arg0[%R] : !linalg.view<?xf32>
+ affine.for %i0 = 0 to %s {
+ linalg.store %f, %V[%i0] : !linalg.view<?xf32>
+ }
+ return
+}
+
+func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<f32> {
+ %A = linalg.buffer_alloc %s : !linalg.buffer<f32>
+ call @fill_f32(%A, %f) : (!linalg.buffer<f32>, f32) -> ()
+ return %A : !linalg.buffer<f32>
+}
+
+func @entry1() -> f32 {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c16 = constant 16 : index
+ %f0 = constant 0.00000e+00 : f32
+ %f1 = constant 0.00000e+00 : f32
+ %f2 = constant 2.00000e+00 : f32
+
+ %A = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer<f32>)
+ %B = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer<f32>)
+ %C = call @alloc_filled_f32(%c1, %f0) : (index, f32) -> (!linalg.buffer<f32>)
+ %res = call @dot(%A, %B, %C) : (!linalg.buffer<f32>, !linalg.buffer<f32>, !linalg.buffer<f32>) -> (f32)
+ linalg.buffer_dealloc %C : !linalg.buffer<f32>
+ linalg.buffer_dealloc %B : !linalg.buffer<f32>
+ linalg.buffer_dealloc %A : !linalg.buffer<f32>
+ return %res : f32
+}
+
+// CHECK: 0.{{0+}}e+00
\ No newline at end of file
MLIREDSC
MLIRExecutionEngine
MLIRIR
+ MLIRLLVMIR
MLIRParser
MLIRTargetLLVMIR
MLIRTransforms
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
mainFuncName("e", llvm::cl::desc("The function to be called"),
llvm::cl::value_desc("<function name>"),
llvm::cl::init("main"));
+static llvm::cl::opt<std::string> mainFuncType(
+ "entry-point-result",
+ llvm::cl::desc("Textual description of the function type to be called"),
+ llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
static llvm::cl::OptionCategory optFlags("opt-like flags");
}
}
-static Error
-compileAndExecute(Module *module, StringRef entryPoint,
- std::function<llvm::Error(llvm::Module *)> transformer) {
+static Error compileAndExecuteFunctionWithMemRefs(
+ Module *module, StringRef entryPoint,
+ std::function<llvm::Error(llvm::Module *)> transformer) {
Function *mainFunction = module->getNamedFunction(entryPoint);
if (!mainFunction || mainFunction->getBlocks().empty()) {
return make_string_error("entry point not found");
return Error::success();
}
+static Error compileAndExecuteSingleFloatReturnFunction(
+ Module *module, StringRef entryPoint,
+ std::function<llvm::Error(llvm::Module *)> transformer) {
+ Function *mainFunction = module->getNamedFunction(entryPoint);
+ if (!mainFunction || mainFunction->isExternal()) {
+ return make_string_error("entry point not found");
+ }
+
+ if (!mainFunction->getType().getInputs().empty())
+ return make_string_error("function inputs not supported");
+
+ if (mainFunction->getType().getResults().size() != 1)
+ return make_string_error("only single f32 function result supported");
+
+ auto t = mainFunction->getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
+ if (!t)
+ return make_string_error("only single llvm.f32 function result supported");
+ auto *llvmTy = t.getUnderlyingType();
+ if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
+ return make_string_error("only single llvm.f32 function result supported");
+
+ auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
+ if (!expectedEngine)
+ return expectedEngine.takeError();
+
+ auto engine = std::move(*expectedEngine);
+ auto expectedFPtr = engine->lookup(entryPoint);
+ if (!expectedFPtr)
+ return expectedFPtr.takeError();
+ void (*fptr)(void **) = *expectedFPtr;
+
+ float res;
+ struct {
+ void *data;
+ } data;
+ data.data = &res;
+ (*fptr)((void **)&data);
+
+ // Intentional printing of the output so we can test.
+ llvm::outs() << res;
+
+ return Error::success();
+}
+
int main(int argc, char **argv) {
llvm::PrettyStackTraceProgram x(argc, argv);
llvm::InitLLVM y(argc, argv);
auto transformer =
mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition);
- auto error = compileAndExecute(m.get(), mainFuncName.getValue(), transformer);
+ auto error = mainFuncType.getValue() == "f32"
+ ? compileAndExecuteSingleFloatReturnFunction(
+ m.get(), mainFuncName.getValue(), transformer)
+ : compileAndExecuteFunctionWithMemRefs(
+ m.get(), mainFuncName.getValue(), transformer);
int exitCode = EXIT_SUCCESS;
llvm::handleAllErrors(std::move(error),
[&exitCode](const llvm::ErrorInfoBase &info) {