namespace linalg {
enum LinalgTypes {
- Buffer = Type::FIRST_LINALG_TYPE,
- Range,
+ Range = Type::FIRST_LINALG_TYPE,
LAST_USED_LINALG_TYPE = Range,
};
void printType(Type type, DialectAsmPrinter &os) const override;
};
-/// A BufferType represents a contiguous block of memory that can be allocated
-/// and deallocated. A buffer cannot be indexed directly, a view must be
-/// laid out on a buffer to give it indexing semantics.
-struct BufferTypeStorage;
-class BufferType : public Type::TypeBase<BufferType, Type, BufferTypeStorage> {
-public:
- // Used for generic hooks in TypeBase.
- using Base::Base;
- /// Construction hook.
- static BufferType get(MLIRContext *context, Type elementType,
- int64_t bufferSize = -1);
- /// Used to implement llvm-style cast.
- static bool kindof(unsigned kind) { return kind == LinalgTypes::Buffer; }
-
- // Type-specific functionality.
- Type getElementType();
- bool hasConstantSize();
- Optional<int64_t> getBufferSize();
-};
-
/// A RangeType represents a minimal range abstraction (min, max, step).
/// It is constructed by calling the linalg.range op with three values index of
/// index type:
auto int64Ty = lowering.convertType(IntegerType::get(64, context))
.cast<LLVM::LLVMType>();
- // A buffer descriptor contains the pointer to a flat region of storage and
- // the size of the region.
- //
- // template <typename Elem, size_t Rank>
- // struct {
- // void *baseAlloc;
- // Elem *ptr;
- // int64_t size;
- // };
- if (auto bufferType = t.dyn_cast<BufferType>()) {
- auto voidPtrTy = LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
- auto ptrTy = getPtrToElementType(bufferType, lowering);
- return LLVMType::getStructTy(voidPtrTy, ptrTy, int64Ty);
- }
-
// Range descriptor contains the range bounds and the step as 64-bit integers.
//
// struct {
return std::make_unique<ConvertLinalgToLLVMPass>();
}
-static PassRegistration<ConvertLinalgToLLVMPass>
- pass("convert-linalg-to-llvm",
- "Convert the operations from the linalg dialect into the LLVM dialect");
+static PassRegistration<ConvertLinalgToLLVMPass> pass(
+ "convert-linalg-to-llvm",
+ "Convert the operations from the linalg dialect into the LLVM dialect");
mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
- addTypes<BufferType, RangeType>();
+ addTypes<RangeType>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"
>();
}
-
-struct mlir::linalg::BufferTypeStorage : public TypeStorage {
- /// Underlying Key type to transport the payload needed to construct a custom
- /// type in a generic way.
- struct Key {
- Key(Type elementType, int64_t bufferSize = -1)
- : elementType(elementType), bufferSize(bufferSize) {}
- Type elementType;
- int64_t bufferSize;
- };
- /// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing.
- using KeyTy = Key;
-
- /// Construction in the llvm::BumpPtrAllocator given a key.
- static BufferTypeStorage *construct(TypeStorageAllocator &allocator,
- const Key &key) {
- return new (allocator.allocate<BufferTypeStorage>()) BufferTypeStorage(key);
- }
-
- /// Equality operator for hashing.
- bool operator==(const Key &key) const {
- return elementType == key.elementType && bufferSize == key.bufferSize;
- }
-
- /// Hashing for unique'ing.
- static unsigned hashKey(const Key &key) {
- return llvm::hash_combine(key.elementType, key.bufferSize);
- }
-
- Type getElementType() { return elementType; }
- bool hasConstantSize() { return bufferSize >= 0; }
- Optional<int64_t> getBufferSize() {
- if (hasConstantSize()) {
- return bufferSize;
- }
- return llvm::None;
- }
-
-private:
- BufferTypeStorage(const Key &key)
- : elementType(key.elementType), bufferSize(key.bufferSize) {}
-
- Type elementType;
- int64_t bufferSize;
-};
-
-BufferType mlir::linalg::BufferType::get(MLIRContext *context, Type elementType,
- int64_t bufferSize) {
- return Base::get(context, LinalgTypes::Buffer, elementType, bufferSize);
-}
-
-Type mlir::linalg::BufferType::getElementType() {
- return getImpl()->getElementType();
-}
-
-bool mlir::linalg::BufferType::hasConstantSize() {
- return getImpl()->hasConstantSize();
-}
-
-Optional<int64_t> mlir::linalg::BufferType::getBufferSize() {
- return getImpl()->getBufferSize();
-}
-
Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
// Parse the main keyword for the type.
StringRef keyword;
if (keyword == "range")
return RangeType::get(context);
- // Handle 'buffer' types.
- if (keyword == "buffer") {
- llvm::SMLoc dimensionLoc;
- SmallVector<int64_t, 1> size;
- Type type;
- if (parser.parseLess() || parser.getCurrentLocation(&dimensionLoc) ||
- parser.parseDimensionList(size) || parser.parseType(type) ||
- parser.parseGreater())
- return Type();
-
- if (size.size() != 1) {
- parser.emitError(dimensionLoc, "expected single element in size list");
- return Type();
- }
-
- return (size.front() == -1 ? BufferType::get(context, type)
- : BufferType::get(context, type, size.front()));
- }
-
parser.emitError(parser.getNameLoc(), "unknown Linalg type: " + keyword);
return Type();
}
-/// BufferType prints as "buffer<size x element_type>".
-static void print(BufferType bt, DialectAsmPrinter &os) {
- os << "buffer<";
- if (Optional<int64_t> bs = bt.getBufferSize())
- os << bs.getValue();
- else
- os << "?";
- os << "x" << bt.getElementType() << ">";
-}
-
/// RangeType prints as just "range".
static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; }
switch (type.getKind()) {
default:
llvm_unreachable("Unhandled Linalg type");
- case LinalgTypes::Buffer:
- print(type.cast<BufferType>(), os);
- break;
case LinalgTypes::Range:
print(type.cast<RangeType>(), os);
break;