From 1d5dc840e7678af1f64b772f4d838e44fa10bef7 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 19 Apr 2019 09:56:11 -0700 Subject: [PATCH] [Linalg] Add a view type with base_view op This CL adds a linalg.view type and base_view op with the proper roundtripping test. The parser will be improved in a subsequent CL once portions of the mlir::Parser are exposed. For now this only supports dynamic views, static views will be introduced at a later time when they are needed. -- PiperOrigin-RevId: 244374180 --- mlir/include/mlir/Linalg/LinalgOps.h | 44 ++++++++++ mlir/include/mlir/Linalg/LinalgTypes.h | 37 ++++++++- mlir/lib/Linalg/LinalgOps.cpp | 144 ++++++++++++++++++++++++++------- mlir/lib/Linalg/LinalgTypes.cpp | 91 ++++++++++++++++++++- mlir/test/Linalg/roundtrip.mlir | 17 +++- 5 files changed, 296 insertions(+), 37 deletions(-) diff --git a/mlir/include/mlir/Linalg/LinalgOps.h b/mlir/include/mlir/Linalg/LinalgOps.h index 8e83d7c..925129e 100644 --- a/mlir/include/mlir/Linalg/LinalgOps.h +++ b/mlir/include/mlir/Linalg/LinalgOps.h @@ -24,6 +24,50 @@ namespace mlir { +/// A `BaseViewOp` produces a `ViewType` which is a multi-dimensional range +/// abstraction on top of an underlying linalg.buffer. A BaseViewOp gives a +/// buffer an indexing structure. +/// +/// A new value of ViewType is constructed from a buffer with a base_view op and +/// ranges: +/// +/// ```{.mlir} +/// %1 = linalg.buffer_alloc %0 : !linalg.buffer +/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range +/// %3 = linalg.base_view %1[%2, %2] : !linalg.view +/// ``` +class BaseViewOp : public mlir::Op { + enum { FirstIndexingOperand = 1 }; + +public: + using Op::Op; + + // Hooks to customize the behavior of this op. + static llvm::StringRef getOperationName() { return "linalg.base_view"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *buffer, + llvm::ArrayRef indexings); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + // Op-specific functionality. + unsigned getRank() { return getViewType().getRank(); } + mlir::Type getElementType() { return getViewType().getElementType(); } + ViewType getViewType() { return getType().cast(); } + mlir::Value *getSupportingBuffer() { return getOperand(0); } + // Get the underlying indexing at a given rank. + mlir::Value *getIndexing(unsigned rank) { + return *(getIndexings().begin() + rank); + } + // Get all the indexings in this view. + mlir::Operation::operand_range getIndexings() { + return {operand_begin() + BaseViewOp::FirstIndexingOperand, operand_end()}; + } +}; + /// A BufferAllocOp is used to create a 1-D !linalg.buffer upon which a base /// view can be laid out. The size argument is an `i64` (and not an index), so /// that we can diff --git a/mlir/include/mlir/Linalg/LinalgTypes.h b/mlir/include/mlir/Linalg/LinalgTypes.h index f08d071..b294bb3 100644 --- a/mlir/include/mlir/Linalg/LinalgTypes.h +++ b/mlir/include/mlir/Linalg/LinalgTypes.h @@ -27,7 +27,8 @@ class MLIRContext; enum LinalgTypes { Buffer = Type::FIRST_LINALG_TYPE, Range, - LAST_USED_LINALG_TYPE = Range, + View, + LAST_USED_LINALG_TYPE = View, }; class LinalgDialect : public Dialect { @@ -51,9 +52,8 @@ public: static BufferType get(MLIRContext *context, Type elementType); /// Used to implement llvm-style cast. static bool kindof(unsigned kind) { return kind == LinalgTypes::Buffer; } - ////////////////////////////////////////////////////////////////////////////// + // Type-specific functionality. - ////////////////////////////////////////////////////////////////////////////// Type getElementType(); }; @@ -71,6 +71,37 @@ public: static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; } }; +/// A ViewType represents a multi-dimensional range abstraction on top of an +/// underlying storage type. It is parameterizable by the underlying element +/// type and the rank of the view. +/// A new value of ViewType is constructed from a buffer with a base_view op and +/// passing it ranges: +/// +/// ```{.mlir} +/// %1 = linalg.buffer_alloc %0 : !linalg.buffer +/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range +/// %3 = linalg.base_view %1[%2, %2] : !linalg.view +/// ``` +class ViewTypeStorage; +class ViewType + : public mlir::Type::TypeBase { +public: + // Used for generic hooks in TypeBase. + using Base::Base; + /// Construction hook. + static ViewType get(mlir::MLIRContext *context, mlir::Type elementType, + unsigned rank); + // Used to implement llvm-style cast. + static bool kindof(unsigned kind) { return kind == LinalgTypes::View; } + + // Type-specific functionality. + /// Return the underlying elemental type. + mlir::Type getElementType(); + /// Return the rank of the view. + /// This is the number of indexings needed to reach an underlying element. + unsigned getRank(); +}; + } // namespace mlir #endif // MLIR_LINALG_LINALGTYPES_H_ diff --git a/mlir/lib/Linalg/LinalgOps.cpp b/mlir/lib/Linalg/LinalgOps.cpp index c6260d8..47aae67 100644 --- a/mlir/lib/Linalg/LinalgOps.cpp +++ b/mlir/lib/Linalg/LinalgOps.cpp @@ -25,48 +25,91 @@ #include "mlir/IR/StandardTypes.h" #include "mlir/Linalg/LinalgTypes.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/STLExtras.h" using namespace mlir; ////////////////////////////////////////////////////////////////////////////// -// RangeOp +// BaseViewOp ////////////////////////////////////////////////////////////////////////////// -void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min, - Value *max, Value *step) { - result->addOperands({min, max, step}); - result->addTypes({RangeType::get(b->getContext())}); +void mlir::BaseViewOp::build(Builder *b, OperationState *result, Value *buffer, + ArrayRef indexings) { + BufferType bufferType = buffer->getType().cast(); + result->addOperands({buffer}); + result->addOperands(indexings); + assert( + std::none_of(indexings.begin(), indexings.end(), + [](Value *v) { return !v->getType().isa(); }) && + "linalg.base_view takes only arguments of type linalg.range"); + + Type elementType = bufferType.getElementType(); + result->addTypes( + {ViewType::get(b->getContext(), elementType, indexings.size())}); } -// Verification is simply that a RangeOp takes 3 index ssa-value. -mlir::LogicalResult mlir::RangeOp::verify() { - if (!min() || !min()->getType().isa()) - return emitOpError("first operand should be of type index"); - if (!max() || !max()->getType().isa()) - return emitOpError("second operand should be of type index"); - if (!step() || !step()->getType().isa()) - return emitOpError("third operand should be of type index"); - return mlir::success(); +LogicalResult mlir::BaseViewOp::verify() { + if (llvm::empty(getOperands())) + return emitOpError( + "requires at least a buffer operand followed by indexings"); + auto bufferType = getOperand(0)->getType().dyn_cast(); + if (!bufferType) + return emitOpError("first operand must be of BufferType"); + unsigned index = 0; + for (auto indexing : getIndexings()) { + if (!indexing->getType().isa()) { + return emitOpError(Twine(index) + "^th index must be of range type"); + } + ++index; + } + if (getViewType().getRank() != index) + return emitOpError( + "the rank of the base view must be the number of its indexings"); + return success(); } -// A RangeOp prints as: +bool mlir::BaseViewOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType bufferInfo; + SmallVector indexingsInfo; + Type type; + if (parser->parseOperand(bufferInfo) || + parser->parseOperandList(indexingsInfo, -1, + OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type)) + return true; + + ViewType viewType = type.dyn_cast(); + if (!viewType) + return parser->emitError(parser->getNameLoc(), "view type expected"); + if (viewType.getRank() != indexingsInfo.size()) + return parser->emitError(parser->getNameLoc(), + "expected" + Twine(viewType.getRank()) + + " range indexings"); + return parser->resolveOperand( + bufferInfo, + BufferType::get(type.getContext(), viewType.getElementType()), + result->operands) || + (!indexingsInfo.empty() && + parser->resolveOperands(indexingsInfo, + RangeType::get(type.getContext()), + result->operands)) || + parser->addTypeToList(viewType, result->types); +} + +// A BaseViewOp prints as: // // ```{.mlir} -// linalg.range %0:%1:%2 : !linalg.range +// linalg.base_view %0[%1, %2] : !linalg.view // ``` -void mlir::RangeOp::print(OpAsmPrinter *p) { - *p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step() - << " : " << getType(); -} - -bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) { - SmallVector rangeInfo(3); - RangeType type; - auto affineIntTy = parser->getBuilder().getIndexType(); - return parser->parseOperand(rangeInfo[0]) || parser->parseColon() || - parser->parseOperand(rangeInfo[1]) || parser->parseColon() || - parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) || - parser->resolveOperands(rangeInfo, affineIntTy, result->operands) || - parser->addTypeToList(type, result->types); +// +// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each +// holding a range. +void mlir::BaseViewOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *getSupportingBuffer() << "["; + interleave( + getIndexings().begin(), getIndexings().end(), + [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; }); + *p << "] : " << getType(); } ////////////////////////////////////////////////////////////////////////////// @@ -140,3 +183,44 @@ bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) { return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) || parser->resolveOperands(sizeInfo, bufferType, result->operands); } + +////////////////////////////////////////////////////////////////////////////// +// RangeOp +////////////////////////////////////////////////////////////////////////////// +void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min, + Value *max, Value *step) { + result->addOperands({min, max, step}); + result->addTypes({RangeType::get(b->getContext())}); +} + +// Verification is simply that a RangeOp takes 3 index ssa-value. +mlir::LogicalResult mlir::RangeOp::verify() { + if (!min() || !min()->getType().isa()) + return emitOpError("first operand should be of type index"); + if (!max() || !max()->getType().isa()) + return emitOpError("second operand should be of type index"); + if (!step() || !step()->getType().isa()) + return emitOpError("third operand should be of type index"); + return mlir::success(); +} + +// A RangeOp prints as: +// +// ```{.mlir} +// linalg.range %0:%1:%2 : !linalg.range +// ``` +void mlir::RangeOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step() + << " : " << getType(); +} + +bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) { + SmallVector rangeInfo(3); + RangeType type; + auto affineIntTy = parser->getBuilder().getIndexType(); + return parser->parseOperand(rangeInfo[0]) || parser->parseColon() || + parser->parseOperand(rangeInfo[1]) || parser->parseColon() || + parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) || + parser->resolveOperands(rangeInfo, affineIntTy, result->operands) || + parser->addTypeToList(type, result->types); +} diff --git a/mlir/lib/Linalg/LinalgTypes.cpp b/mlir/lib/Linalg/LinalgTypes.cpp index 822cd70..e164b0b 100644 --- a/mlir/lib/Linalg/LinalgTypes.cpp +++ b/mlir/lib/Linalg/LinalgTypes.cpp @@ -29,8 +29,8 @@ using namespace mlir; mlir::LinalgDialect::LinalgDialect(MLIRContext *context) : Dialect("linalg", context) { - addTypes(); - addOperations(); + addTypes(); + addOperations(); } struct mlir::BufferTypeStorage : public mlir::TypeStorage { @@ -80,15 +80,97 @@ Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const { // TODO(ntv): reuse mlir Parser once exposed. if (spec == "buffer") return BufferType::get(getContext(), FloatType::getF32(getContext())); + // TODO(ntv): reuse mlir Parser once exposed. + if (spec.startswith("view")) { + spec.consume_front("view"); + // Just count the number of ? to get the rank, the type must be f32 for now. + unsigned rank = 0; + for (unsigned i = 0, e = spec.size(); i < e; ++i) + if (spec[i] == '?') + ++rank; + return ViewType::get(context, FloatType::getF32(context), rank); + } return (context->emitError(loc, "unknown Linalg type: " + spec), Type()); } -/// RangeType prints as just "range". +struct mlir::ViewTypeStorage : public mlir::TypeStorage { + /// Underlying Key type to transport the payload needed to construct a custom + /// type in a generic way. + struct Key { + Key(Type elementType, unsigned rank) + : elementType(elementType), rank(rank) {} + Type elementType; + unsigned rank; + }; + /// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing. + using KeyTy = Key; + + /// Construction in the llvm::BumpPtrAllocator given a key. + static ViewTypeStorage *construct(TypeStorageAllocator &allocator, + const Key &key) { + return new (allocator.allocate()) ViewTypeStorage(key); + } + + /// Equality operator for hashing. + bool operator==(const Key &key) const { + return elementType == key.elementType && rank == key.rank; + } + + /// Hashing for unique'ing. + static unsigned hashKey(const Key &key) { + return llvm::hash_combine(key.elementType, key.rank); + } + + unsigned getRank() { return rank; }; + Type getElementType() { return elementType; }; + +private: + ViewTypeStorage(const Key &key) + : elementType(key.elementType), rank(key.rank) {} + + Type elementType; + unsigned rank; +}; + +ViewType mlir::ViewType::get(MLIRContext *context, Type elementType, + unsigned rank) { + return Base::get(context, LinalgTypes::View, elementType, rank); +} + +Type mlir::ViewType::getElementType() { return getImpl()->getElementType(); } + +unsigned mlir::ViewType::getRank() { return getImpl()->getRank(); } + +/// BufferType prints as "buffer". static void print(BufferType bt, raw_ostream &os) { os << "buffer<" << bt.getElementType() << ">"; } + +/// RangeType prints as just "range". static void print(RangeType rt, raw_ostream &os) { os << "range"; } +/// ViewType prints as: +/// +/// ```{.mlir} +/// view +/// ``` +/// +/// or +/// +/// ```{.mlir} +/// view +/// ``` +/// +/// for 0-D views (a.k.a pointer to a scalar value). +static void print(mlir::ViewType rt, raw_ostream &os) { + os << "view<"; + for (unsigned i = 0, e = rt.getRank(); i < e; ++i) { + os << "?x"; + } + os << rt.getElementType(); + os << ">"; +} + void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const { switch (type.getKind()) { default: @@ -99,5 +181,8 @@ void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const { case LinalgTypes::Range: print(type.cast(), os); break; + case LinalgTypes::View: + print(type.cast(), os); + break; } } diff --git a/mlir/test/Linalg/roundtrip.mlir b/mlir/test/Linalg/roundtrip.mlir index 8544ed2..eab2aa3 100644 --- a/mlir/test/Linalg/roundtrip.mlir +++ b/mlir/test/Linalg/roundtrip.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -verify | mlir-opt | FileCheck %s +// RUN: mlir-opt %s -verify | FileCheck %s func @range(%arg0: index, %arg1: index, %arg2: index) { %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range @@ -16,4 +16,19 @@ func @buffer(%arg0: i64, %arg1: i64) { // CHECK-LABEL: func @buffer(%arg0: i64, %arg1: i64) { // CHECK-NEXT: %0 = muli %arg0, %arg0 : i64 // CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer +// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer + +func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) { + %0 = muli %arg0, %arg0 : i64 + %1 = linalg.buffer_alloc %0 : !linalg.buffer + %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range + %3 = linalg.base_view %1[%2, %2] : !linalg.view + linalg.buffer_dealloc %1 : !linalg.buffer + return +} +// CHECK-LABEL: func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) { +// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64 +// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer +// CHECK-NEXT: %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range +// CHECK-NEXT: %3 = linalg.base_view %1[%2, %2] : !linalg.view // CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer \ No newline at end of file -- 2.7.4