From 75be1fe82b124b2679ff4c34f16e79ea58ee2d01 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 18 Apr 2019 13:56:18 -0700 Subject: [PATCH] [Linalg] Add a simple buffer type with alloc/dealloc ops This CL adds a linalg.buffer type and buffer_alloc / buffer_dealloc ops with the proper roundtripping test. -- PiperOrigin-RevId: 244252306 --- mlir/include/mlir/Linalg/LinalgOps.h | 46 +++++++++++++++++++-- mlir/include/mlir/Linalg/LinalgTypes.h | 19 ++++++++- mlir/lib/Linalg/LinalgOps.cpp | 75 ++++++++++++++++++++++++++++++++++ mlir/lib/Linalg/LinalgTypes.cpp | 54 +++++++++++++++++++++++- mlir/test/Linalg/roundtrip.mlir | 13 +++++- 5 files changed, 199 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Linalg/LinalgOps.h b/mlir/include/mlir/Linalg/LinalgOps.h index 7921822d..8e83d7c 100644 --- a/mlir/include/mlir/Linalg/LinalgOps.h +++ b/mlir/include/mlir/Linalg/LinalgOps.h @@ -19,10 +19,52 @@ #define MLIR_LINALG_LINALGOPS_H_ #include "mlir/IR/OpDefinition.h" +#include "mlir/Linalg/LinalgTypes.h" #include "mlir/Support/LLVM.h" namespace mlir { +/// A BufferAllocOp is used to create a 1-D !linalg.buffer upon which a base +/// view can be laid out. The size argument is an `i64` (and not an index), so +/// that we can +class BufferAllocOp + : public Op { +public: + using Op::Op; + + // Hooks to customize the behavior of this op. + static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; } + static void build(Builder *b, OperationState *result, Type type, Value *size); + LogicalResult verify(); + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + + // Op-specific functionality. + Value *size() { return getOperand(); } + BufferType getBufferType() { return getType().cast(); } + Type getElementType() { return getBufferType().getElementType(); } +}; + +/// A BufferDeallocOp is used to free a !linalg.buffer. +class BufferDeallocOp + : public Op { +public: + using Op::Op; + + // Hooks to customize the behavior of this op. + static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; } + static void build(Builder *b, OperationState *result, Value *buffer); + LogicalResult verify(); + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + + // Op-specific functionality. + Value *getBuffer() { return getOperand(); } + BufferType getBufferType() { + return getOperand()->getType().cast(); + } +}; + /// A RangeOp is used to create a value of RangeType from 3 values of type index /// that represent the min, max and step values of the range. class RangeOp : public Op::Impl, @@ -30,9 +72,7 @@ class RangeOp : public Op::Impl, public: using Op::Op; - ////////////////////////////////////////////////////////////////////////////// // Hooks to customize the behavior of this op. - ////////////////////////////////////////////////////////////////////////////// static llvm::StringRef getOperationName() { return "linalg.range"; } static void build(Builder *b, OperationState *result, Value *min, Value *max, Value *step); @@ -40,9 +80,7 @@ public: static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); - ////////////////////////////////////////////////////////////////////////////// // Op-specific functionality. - ////////////////////////////////////////////////////////////////////////////// Value *min() { return getOperand(0); } Value *max() { return getOperand(1); } Value *step() { return getOperand(2); } diff --git a/mlir/include/mlir/Linalg/LinalgTypes.h b/mlir/include/mlir/Linalg/LinalgTypes.h index 2d2c74e..f08d071 100644 --- a/mlir/include/mlir/Linalg/LinalgTypes.h +++ b/mlir/include/mlir/Linalg/LinalgTypes.h @@ -25,7 +25,8 @@ namespace mlir { class MLIRContext; enum LinalgTypes { - Range = Type::FIRST_LINALG_TYPE, + Buffer = Type::FIRST_LINALG_TYPE, + Range, LAST_USED_LINALG_TYPE = Range, }; @@ -40,6 +41,22 @@ public: void printType(Type type, llvm::raw_ostream &os) const override; }; +/// A BufferType represents a minimal range abstraction (min, max, step). +class BufferTypeStorage; +class BufferType : public Type::TypeBase { +public: + // Used for generic hooks in TypeBase. + using Base::Base; + /// Construction hook. + static BufferType get(MLIRContext *context, Type elementType); + /// Used to implement llvm-style cast. + static bool kindof(unsigned kind) { return kind == LinalgTypes::Buffer; } + ////////////////////////////////////////////////////////////////////////////// + // Type-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + Type getElementType(); +}; + /// A RangeType represents a minimal range abstraction (min, max, step). class RangeType : public Type::TypeBase { public: diff --git a/mlir/lib/Linalg/LinalgOps.cpp b/mlir/lib/Linalg/LinalgOps.cpp index bba47fb..c6260d8 100644 --- a/mlir/lib/Linalg/LinalgOps.cpp +++ b/mlir/lib/Linalg/LinalgOps.cpp @@ -28,6 +28,9 @@ using namespace mlir; +////////////////////////////////////////////////////////////////////////////// +// RangeOp +////////////////////////////////////////////////////////////////////////////// void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min, Value *max, Value *step) { result->addOperands({min, max, step}); @@ -65,3 +68,75 @@ bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) { parser->resolveOperands(rangeInfo, affineIntTy, result->operands) || parser->addTypeToList(type, result->types); } + +////////////////////////////////////////////////////////////////////////////// +// BufferAllocOp +////////////////////////////////////////////////////////////////////////////// +void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type, + Value *size) { + result->addOperands({size}); + result->addTypes(type); +} + +mlir::LogicalResult mlir::BufferAllocOp::verify() { + if (!size() || !size()->getType().isa() || + !size()->getType().cast().isInteger(64)) + return emitOpError("first operand should be of type i64"); + if (!VectorType::isValidElementType(getElementType()) && + !getElementType().isa()) + return emitOpError("unsupported buffer element type"); + return mlir::success(); +} + +// A BufferAllocOp prints as: +// +// ```{.mlir} +// linalg.alloc %0 : !linalg.buffer +// ``` +void mlir::BufferAllocOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *size() << " : " << getType(); +} + +bool mlir::BufferAllocOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType sizeInfo; + BufferType bufferType; + auto int64Ty = parser->getBuilder().getIntegerType(64); + if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType)) + return true; + if (bufferType.getElementType() != parser->getBuilder().getF32Type()) + return parser->emitError( + parser->getNameLoc(), + "Only buffer supported until mlir::Parser pieces are exposed"); + return parser->resolveOperands(sizeInfo, int64Ty, result->operands) || + parser->addTypeToList(bufferType, result->types); +} + +////////////////////////////////////////////////////////////////////////////// +// BufferDeallocOp +////////////////////////////////////////////////////////////////////////////// +void mlir::BufferDeallocOp::build(Builder *b, OperationState *result, + Value *buffer) { + result->addOperands({buffer}); +} + +mlir::LogicalResult mlir::BufferDeallocOp::verify() { + if (!getBuffer()->getType()) + return emitOpError("first operand should be of type buffer"); + return mlir::success(); +} + +// A BufferDeallocOp prints as: +// +// ```{.mlir} +// linalg.dealloc %0 : !linalg.buffer +// ``` +void mlir::BufferDeallocOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *getBuffer() << " : " << getBufferType(); +} + +bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType sizeInfo; + BufferType bufferType; + return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) || + parser->resolveOperands(sizeInfo, bufferType, result->operands); +} diff --git a/mlir/lib/Linalg/LinalgTypes.cpp b/mlir/lib/Linalg/LinalgTypes.cpp index 7aabfd2..822cd70 100644 --- a/mlir/lib/Linalg/LinalgTypes.cpp +++ b/mlir/lib/Linalg/LinalgTypes.cpp @@ -21,6 +21,7 @@ #include "mlir/Linalg/LinalgTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/StandardTypes.h" #include "mlir/Linalg/LinalgOps.h" #include "mlir/Support/LLVM.h" @@ -28,24 +29,73 @@ using namespace mlir; mlir::LinalgDialect::LinalgDialect(MLIRContext *context) : Dialect("linalg", context) { - addTypes(); - addOperations(); + addTypes(); + addOperations(); } +struct mlir::BufferTypeStorage : public mlir::TypeStorage { + /// Underlying Key type to transport the payload needed to construct a custom + /// type in a generic way. + struct Key { + Key(Type elementType) : elementType(elementType) {} + Type elementType; + }; + /// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing. + using KeyTy = Key; + + /// Construction in the llvm::BumpPtrAllocator given a key. + static BufferTypeStorage *construct(TypeStorageAllocator &allocator, + const Key &key) { + return new (allocator.allocate()) BufferTypeStorage(key); + } + + /// Equality operator for hashing. + bool operator==(const Key &key) const { + return elementType == key.elementType; + } + + /// Hashing for unique'ing. + static unsigned hashKey(const Key &key) { + return llvm::hash_combine(key.elementType); + } + + Type getElementType() { return elementType; }; + +private: + BufferTypeStorage(const Key &key) : elementType(key.elementType) {} + + Type elementType; +}; + +BufferType mlir::BufferType::get(MLIRContext *context, Type elementType) { + return Base::get(context, LinalgTypes::Buffer, elementType); +} + +Type mlir::BufferType::getElementType() { return getImpl()->getElementType(); } + Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const { MLIRContext *context = getContext(); if (spec == "range") return RangeType::get(getContext()); + // TODO(ntv): reuse mlir Parser once exposed. + if (spec == "buffer") + return BufferType::get(getContext(), FloatType::getF32(getContext())); return (context->emitError(loc, "unknown Linalg type: " + spec), Type()); } /// RangeType prints as just "range". +static void print(BufferType bt, raw_ostream &os) { + os << "buffer<" << bt.getElementType() << ">"; +} static void print(RangeType rt, raw_ostream &os) { os << "range"; } void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const { switch (type.getKind()) { default: llvm_unreachable("Unhandled Linalg type"); + case LinalgTypes::Buffer: + print(type.cast(), os); + break; case LinalgTypes::Range: print(type.cast(), os); break; diff --git a/mlir/test/Linalg/roundtrip.mlir b/mlir/test/Linalg/roundtrip.mlir index f98558a..8544ed2 100644 --- a/mlir/test/Linalg/roundtrip.mlir +++ b/mlir/test/Linalg/roundtrip.mlir @@ -5,4 +5,15 @@ func @range(%arg0: index, %arg1: index, %arg2: index) { return } // CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) { -// CHECK-NEXT: %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range \ No newline at end of file +// CHECK-NEXT: %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range + +func @buffer(%arg0: i64, %arg1: i64) { + %0 = muli %arg0, %arg0 : i64 + %1 = linalg.buffer_alloc %0 : !linalg.buffer + linalg.buffer_dealloc %1 : !linalg.buffer + return +} +// CHECK-LABEL: func @buffer(%arg0: i64, %arg1: i64) { +// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64 +// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer +// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer \ No newline at end of file -- 2.7.4