#define MLIR_LINALG_LINALGOPS_H_
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Linalg/LinalgTypes.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
+/// A BufferAllocOp is used to create a 1-D !linalg.buffer upon which a base
+/// view can be laid out. The size argument is an `i64` (and not an index), so
+/// that we can
+class BufferAllocOp
+ : public Op<BufferAllocOp, OpTrait::OneOperand, OpTrait::OneResult> {
+public:
+ using Op::Op;
+
+ // Hooks to customize the behavior of this op.
+ static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; }
+ static void build(Builder *b, OperationState *result, Type type, Value *size);
+ LogicalResult verify();
+ static bool parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p);
+
+ // Op-specific functionality.
+ Value *size() { return getOperand(); }
+ BufferType getBufferType() { return getType().cast<BufferType>(); }
+ Type getElementType() { return getBufferType().getElementType(); }
+};
+
+/// A BufferDeallocOp is used to free a !linalg.buffer.
+class BufferDeallocOp
+ : public Op<BufferDeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
+public:
+ using Op::Op;
+
+ // Hooks to customize the behavior of this op.
+ static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; }
+ static void build(Builder *b, OperationState *result, Value *buffer);
+ LogicalResult verify();
+ static bool parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p);
+
+ // Op-specific functionality.
+ Value *getBuffer() { return getOperand(); }
+ BufferType getBufferType() {
+ return getOperand()->getType().cast<BufferType>();
+ }
+};
+
/// A RangeOp is used to create a value of RangeType from 3 values of type index
/// that represent the min, max and step values of the range.
class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
public:
using Op::Op;
- //////////////////////////////////////////////////////////////////////////////
// Hooks to customize the behavior of this op.
- //////////////////////////////////////////////////////////////////////////////
static llvm::StringRef getOperationName() { return "linalg.range"; }
static void build(Builder *b, OperationState *result, Value *min, Value *max,
Value *step);
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
- //////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
- //////////////////////////////////////////////////////////////////////////////
Value *min() { return getOperand(0); }
Value *max() { return getOperand(1); }
Value *step() { return getOperand(2); }
class MLIRContext;
enum LinalgTypes {
- Range = Type::FIRST_LINALG_TYPE,
+ Buffer = Type::FIRST_LINALG_TYPE,
+ Range,
LAST_USED_LINALG_TYPE = Range,
};
void printType(Type type, llvm::raw_ostream &os) const override;
};
+/// A BufferType represents a minimal range abstraction (min, max, step).
+class BufferTypeStorage;
+class BufferType : public Type::TypeBase<BufferType, Type, BufferTypeStorage> {
+public:
+ // Used for generic hooks in TypeBase.
+ using Base::Base;
+ /// Construction hook.
+ static BufferType get(MLIRContext *context, Type elementType);
+ /// Used to implement llvm-style cast.
+ static bool kindof(unsigned kind) { return kind == LinalgTypes::Buffer; }
+ //////////////////////////////////////////////////////////////////////////////
+ // Type-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ Type getElementType();
+};
+
/// A RangeType represents a minimal range abstraction (min, max, step).
class RangeType : public Type::TypeBase<RangeType, Type> {
public:
using namespace mlir;
+//////////////////////////////////////////////////////////////////////////////
+// RangeOp
+//////////////////////////////////////////////////////////////////////////////
void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min,
Value *max, Value *step) {
result->addOperands({min, max, step});
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type, result->types);
}
+
+//////////////////////////////////////////////////////////////////////////////
+// BufferAllocOp
+//////////////////////////////////////////////////////////////////////////////
+void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type,
+ Value *size) {
+ result->addOperands({size});
+ result->addTypes(type);
+}
+
+mlir::LogicalResult mlir::BufferAllocOp::verify() {
+ if (!size() || !size()->getType().isa<IntegerType>() ||
+ !size()->getType().cast<IntegerType>().isInteger(64))
+ return emitOpError("first operand should be of type i64");
+ if (!VectorType::isValidElementType(getElementType()) &&
+ !getElementType().isa<VectorType>())
+ return emitOpError("unsupported buffer element type");
+ return mlir::success();
+}
+
+// A BufferAllocOp prints as:
+//
+// ```{.mlir}
+// linalg.alloc %0 : !linalg.buffer<f32>
+// ```
+void mlir::BufferAllocOp::print(OpAsmPrinter *p) {
+ *p << getOperationName() << " " << *size() << " : " << getType();
+}
+
+bool mlir::BufferAllocOp::parse(OpAsmParser *parser, OperationState *result) {
+ OpAsmParser::OperandType sizeInfo;
+ BufferType bufferType;
+ auto int64Ty = parser->getBuilder().getIntegerType(64);
+ if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
+ return true;
+ if (bufferType.getElementType() != parser->getBuilder().getF32Type())
+ return parser->emitError(
+ parser->getNameLoc(),
+ "Only buffer<f32> supported until mlir::Parser pieces are exposed");
+ return parser->resolveOperands(sizeInfo, int64Ty, result->operands) ||
+ parser->addTypeToList(bufferType, result->types);
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// BufferDeallocOp
+//////////////////////////////////////////////////////////////////////////////
+void mlir::BufferDeallocOp::build(Builder *b, OperationState *result,
+ Value *buffer) {
+ result->addOperands({buffer});
+}
+
+mlir::LogicalResult mlir::BufferDeallocOp::verify() {
+ if (!getBuffer()->getType())
+ return emitOpError("first operand should be of type buffer");
+ return mlir::success();
+}
+
+// A BufferDeallocOp prints as:
+//
+// ```{.mlir}
+// linalg.dealloc %0 : !linalg.buffer<f32>
+// ```
+void mlir::BufferDeallocOp::print(OpAsmPrinter *p) {
+ *p << getOperationName() << " " << *getBuffer() << " : " << getBufferType();
+}
+
+bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) {
+ OpAsmParser::OperandType sizeInfo;
+ BufferType bufferType;
+ return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
+ parser->resolveOperands(sizeInfo, bufferType, result->operands);
+}
#include "mlir/Linalg/LinalgTypes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/StandardTypes.h"
#include "mlir/Linalg/LinalgOps.h"
#include "mlir/Support/LLVM.h"
mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect("linalg", context) {
- addTypes<RangeType>();
- addOperations<RangeOp>();
+ addTypes<BufferType, RangeType>();
+ addOperations<BufferAllocOp, BufferDeallocOp, RangeOp>();
}
+struct mlir::BufferTypeStorage : public mlir::TypeStorage {
+ /// Underlying Key type to transport the payload needed to construct a custom
+ /// type in a generic way.
+ struct Key {
+ Key(Type elementType) : elementType(elementType) {}
+ Type elementType;
+ };
+ /// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing.
+ using KeyTy = Key;
+
+ /// Construction in the llvm::BumpPtrAllocator given a key.
+ static BufferTypeStorage *construct(TypeStorageAllocator &allocator,
+ const Key &key) {
+ return new (allocator.allocate<BufferTypeStorage>()) BufferTypeStorage(key);
+ }
+
+ /// Equality operator for hashing.
+ bool operator==(const Key &key) const {
+ return elementType == key.elementType;
+ }
+
+ /// Hashing for unique'ing.
+ static unsigned hashKey(const Key &key) {
+ return llvm::hash_combine(key.elementType);
+ }
+
+ Type getElementType() { return elementType; };
+
+private:
+ BufferTypeStorage(const Key &key) : elementType(key.elementType) {}
+
+ Type elementType;
+};
+
+BufferType mlir::BufferType::get(MLIRContext *context, Type elementType) {
+ return Base::get(context, LinalgTypes::Buffer, elementType);
+}
+
+Type mlir::BufferType::getElementType() { return getImpl()->getElementType(); }
+
Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const {
MLIRContext *context = getContext();
if (spec == "range")
return RangeType::get(getContext());
+ // TODO(ntv): reuse mlir Parser once exposed.
+ if (spec == "buffer<f32>")
+ return BufferType::get(getContext(), FloatType::getF32(getContext()));
return (context->emitError(loc, "unknown Linalg type: " + spec), Type());
}
/// RangeType prints as just "range".
+static void print(BufferType bt, raw_ostream &os) {
+ os << "buffer<" << bt.getElementType() << ">";
+}
static void print(RangeType rt, raw_ostream &os) { os << "range"; }
void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const {
switch (type.getKind()) {
default:
llvm_unreachable("Unhandled Linalg type");
+ case LinalgTypes::Buffer:
+ print(type.cast<BufferType>(), os);
+ break;
case LinalgTypes::Range:
print(type.cast<RangeType>(), os);
break;
return
}
// CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) {
-// CHECK-NEXT: %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
\ No newline at end of file
+// CHECK-NEXT: %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
+
+func @buffer(%arg0: i64, %arg1: i64) {
+ %0 = muli %arg0, %arg0 : i64
+ %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+ linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+ return
+}
+// CHECK-LABEL: func @buffer(%arg0: i64, %arg1: i64) {
+// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64
+// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
\ No newline at end of file