From 0b47f740376e273a8be42fd0ae7b103623a61af5 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 19 Apr 2019 12:55:34 -0700 Subject: [PATCH] [Linalg] Add a slice op This CL adds a linalg.slice op with the proper roundtripping test. A slice op allows taking subviews that may be rank-reducing (if some indexing is of index type) or not (if all indexings are of linalg.range type). A slice must be constructed directly from a base view (no chains of slices may exist in the IR). Helper functions that fold will be provided for construction if/when necessary. This also renames base_view to view. -- PiperOrigin-RevId: 244406827 --- mlir/include/mlir/Linalg/LinalgOps.h | 181 ++++++++++++++----- mlir/include/mlir/Linalg/LinalgTypes.h | 16 +- mlir/lib/Linalg/LinalgOps.cpp | 308 +++++++++++++++++++++++---------- mlir/lib/Linalg/LinalgTypes.cpp | 2 +- mlir/test/Linalg/roundtrip.mlir | 28 +-- 5 files changed, 386 insertions(+), 149 deletions(-) diff --git a/mlir/include/mlir/Linalg/LinalgOps.h b/mlir/include/mlir/Linalg/LinalgOps.h index 925129e..2142459 100644 --- a/mlir/include/mlir/Linalg/LinalgOps.h +++ b/mlir/include/mlir/Linalg/LinalgOps.h @@ -24,53 +24,14 @@ namespace mlir { -/// A `BaseViewOp` produces a `ViewType` which is a multi-dimensional range -/// abstraction on top of an underlying linalg.buffer. A BaseViewOp gives a -/// buffer an indexing structure. -/// -/// A new value of ViewType is constructed from a buffer with a base_view op and -/// ranges: +/// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type, +/// upon which a base view can be laid out to give it indexing semantics. +/// "buffer_alloc" takes a single argument, the size of the buffer to allocate +/// (in number of elements). /// /// ```{.mlir} -/// %1 = linalg.buffer_alloc %0 : !linalg.buffer -/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range -/// %3 = linalg.base_view %1[%2, %2] : !linalg.view +/// %0 = linalg.buffer_alloc %arg0 : !linalg.buffer /// ``` -class BaseViewOp : public mlir::Op { - enum { FirstIndexingOperand = 1 }; - -public: - using Op::Op; - - // Hooks to customize the behavior of this op. - static llvm::StringRef getOperationName() { return "linalg.base_view"; } - static void build(mlir::Builder *b, mlir::OperationState *result, - mlir::Value *buffer, - llvm::ArrayRef indexings); - mlir::LogicalResult verify(); - static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); - void print(mlir::OpAsmPrinter *p); - - // Op-specific functionality. - unsigned getRank() { return getViewType().getRank(); } - mlir::Type getElementType() { return getViewType().getElementType(); } - ViewType getViewType() { return getType().cast(); } - mlir::Value *getSupportingBuffer() { return getOperand(0); } - // Get the underlying indexing at a given rank. - mlir::Value *getIndexing(unsigned rank) { - return *(getIndexings().begin() + rank); - } - // Get all the indexings in this view. - mlir::Operation::operand_range getIndexings() { - return {operand_begin() + BaseViewOp::FirstIndexingOperand, operand_end()}; - } -}; - -/// A BufferAllocOp is used to create a 1-D !linalg.buffer upon which a base -/// view can be laid out. The size argument is an `i64` (and not an index), so -/// that we can class BufferAllocOp : public Op { public: @@ -89,7 +50,11 @@ public: Type getElementType() { return getBufferType().getElementType(); } }; -/// A BufferDeallocOp is used to free a !linalg.buffer. +/// The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type. +/// +/// ```{.mlir} +/// linalg.buffer_dealloc %0 : !linalg.buffer +/// ``` class BufferDeallocOp : public Op { public: @@ -109,8 +74,12 @@ public: } }; -/// A RangeOp is used to create a value of RangeType from 3 values of type index +/// The "linalg.range" op creates a linalg.range from 3 values of type `index` /// that represent the min, max and step values of the range. +/// +/// ```{.mlir} +/// %3 = linalg.range %0:%1:%2 : !linalg.range +/// ``` class RangeOp : public Op::Impl, OpTrait::OneResult, OpTrait::HasNoSideEffect> { public: @@ -130,6 +99,126 @@ public: Value *step() { return getOperand(2); } }; +/// The "linalg.slice" op produces a linalg.view which is a subview of a given +/// base view. This allows defining a subregion within the underlying buffer to +/// operate on only a subset of the buffer. +/// +/// A "linalg.slice" op takes a base view and a variadic number of indexings and +/// produces a linalg.view of the same elemental type as the buffer. An indexing +/// is either: +/// 1. a linalg.range, in which case it does not reduce the rank of the parent +/// view. +/// 2. an index, in which case it reduces the rank of the parent view by one. +/// +/// The parent view must be a base view (i.e. either a function argument or has +/// been produced by a linalg.view op). In other words, chains of +/// linalg.slice operations cannot be constructed in the IR. This defines away +/// problems related to keeping track of which dimensions of the base view have +/// been rank-reduced. +/// +/// Examples: +/// 1. rank-preserving slice: +/// +/// ```{.mlir} +/// %4 = linalg.slice %0[%1, %2] : !linalg.view, !linalg.range, +/// !linalg.range, !linalg.view +/// ``` +/// +/// 2. rank-reducing slice (from 2-D to 1-D): +/// +/// ```{.mlir} +/// %4 = linalg.slice %0[%1, %2] : !linalg.view, index, +/// !linalg.range, !linalg.view +/// ``` +/// +/// 3. rank-reducing slice (from 2-D to 0-D): +/// +/// ```{.mlir} +/// %4 = linalg.slice %0[%1, %2] : !linalg.view, index, index, +/// !linalg.view +/// ``` +class ViewOp; +class SliceOp : public mlir::Op { + enum { FirstIndexingOperand = 1 }; + +public: + using Op::Op; + + // Hooks to customize the behavior of this op. + static llvm::StringRef getOperationName() { return "linalg.slice"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *base, llvm::ArrayRef indexings); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + // Op-specific functionality. + unsigned getRank() { return getViewType().getRank(); } + mlir::Type getElementType() { return getViewType().getElementType(); } + ViewType getViewType() { return getType().cast(); } + Value *getBaseView() { return getOperand(0); } + ViewOp getBaseViewOp(); + ViewType getBaseViewType(); + unsigned getBaseViewRank() { return getBaseViewType().getRank(); } + // Get the underlying indexing at a given rank. + mlir::Value *getIndexing(unsigned rank) { + return *(getIndexings().begin() + rank); + } + // Get all the indexings in this view. + mlir::Operation::operand_range getIndexings() { + return {operand_begin() + SliceOp::FirstIndexingOperand, operand_end()}; + } + // Get the subset of indexings that are of RangeType. + SmallVector getRanges(); +}; + +/// The "linalg.view" op produces a linalg.view which is a multi-dimensional +/// range abstraction on top of an underlying linalg.buffer. This gives an +/// indexing structure to an otherwise non-indexable linalg.buffer. +/// +/// A "linalg.view" takes a buffer and a variadic number of ranges and produces +/// a `view` of the same elemental type as the buffer and of rank the number of +/// ranges: +/// +/// ```{.mlir} +/// %1 = linalg.buffer_alloc %0 : !linalg.buffer +/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range +/// %3 = linalg.view %1[%2, %2] : !linalg.view +/// ``` +class ViewOp : public mlir::Op { + enum { FirstIndexingOperand = 1 }; + +public: + using Op::Op; + + // Hooks to customize the behavior of this op. + static llvm::StringRef getOperationName() { return "linalg.view"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *buffer, + llvm::ArrayRef indexings); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + // Op-specific functionality. + unsigned getRank() { return getViewType().getRank(); } + mlir::Type getElementType() { return getViewType().getElementType(); } + ViewType getViewType() { return getType().cast(); } + mlir::Value *getSupportingBuffer() { return getOperand(0); } + // Get the underlying indexing at a given rank. + mlir::Value *getIndexing(unsigned rank) { + return *(getIndexings().begin() + rank); + } + // Get all the indexings in this view. + mlir::Operation::operand_range getIndexings() { + return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()}; + } +}; + } // namespace mlir #endif // MLIR_LINALG_LINALGOPS_H_ diff --git a/mlir/include/mlir/Linalg/LinalgTypes.h b/mlir/include/mlir/Linalg/LinalgTypes.h index b294bb3..fbb1cbf 100644 --- a/mlir/include/mlir/Linalg/LinalgTypes.h +++ b/mlir/include/mlir/Linalg/LinalgTypes.h @@ -42,7 +42,9 @@ public: void printType(Type type, llvm::raw_ostream &os) const override; }; -/// A BufferType represents a minimal range abstraction (min, max, step). +/// A BufferType represents a contiguous block of memory that can be allocated +/// and deallocated. A buffer cannot be indexed directly, a view must be +/// laid out on a buffer to give it indexing semantics. class BufferTypeStorage; class BufferType : public Type::TypeBase { public: @@ -58,6 +60,14 @@ public: }; /// A RangeType represents a minimal range abstraction (min, max, step). +/// It is constructed by calling the linalg.range op with three values index of +/// index type: +/// +/// ```{.mlir} +/// func @foo(%arg0 : index, %arg1 : index, %arg2 : index) { +/// %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range +/// } +/// ``` class RangeType : public Type::TypeBase { public: // Used for generic hooks in TypeBase. @@ -74,13 +84,13 @@ public: /// A ViewType represents a multi-dimensional range abstraction on top of an /// underlying storage type. It is parameterizable by the underlying element /// type and the rank of the view. -/// A new value of ViewType is constructed from a buffer with a base_view op and +/// A new value of ViewType is constructed from a buffer with a view op and /// passing it ranges: /// /// ```{.mlir} /// %1 = linalg.buffer_alloc %0 : !linalg.buffer /// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range -/// %3 = linalg.base_view %1[%2, %2] : !linalg.view +/// %3 = linalg.view %1[%2, %2] : !linalg.view /// ``` class ViewTypeStorage; class ViewType diff --git a/mlir/lib/Linalg/LinalgOps.cpp b/mlir/lib/Linalg/LinalgOps.cpp index 47aae67..2db7696 100644 --- a/mlir/lib/Linalg/LinalgOps.cpp +++ b/mlir/lib/Linalg/LinalgOps.cpp @@ -30,89 +30,6 @@ using namespace mlir; ////////////////////////////////////////////////////////////////////////////// -// BaseViewOp -////////////////////////////////////////////////////////////////////////////// -void mlir::BaseViewOp::build(Builder *b, OperationState *result, Value *buffer, - ArrayRef indexings) { - BufferType bufferType = buffer->getType().cast(); - result->addOperands({buffer}); - result->addOperands(indexings); - assert( - std::none_of(indexings.begin(), indexings.end(), - [](Value *v) { return !v->getType().isa(); }) && - "linalg.base_view takes only arguments of type linalg.range"); - - Type elementType = bufferType.getElementType(); - result->addTypes( - {ViewType::get(b->getContext(), elementType, indexings.size())}); -} - -LogicalResult mlir::BaseViewOp::verify() { - if (llvm::empty(getOperands())) - return emitOpError( - "requires at least a buffer operand followed by indexings"); - auto bufferType = getOperand(0)->getType().dyn_cast(); - if (!bufferType) - return emitOpError("first operand must be of BufferType"); - unsigned index = 0; - for (auto indexing : getIndexings()) { - if (!indexing->getType().isa()) { - return emitOpError(Twine(index) + "^th index must be of range type"); - } - ++index; - } - if (getViewType().getRank() != index) - return emitOpError( - "the rank of the base view must be the number of its indexings"); - return success(); -} - -bool mlir::BaseViewOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType bufferInfo; - SmallVector indexingsInfo; - Type type; - if (parser->parseOperand(bufferInfo) || - parser->parseOperandList(indexingsInfo, -1, - OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type)) - return true; - - ViewType viewType = type.dyn_cast(); - if (!viewType) - return parser->emitError(parser->getNameLoc(), "view type expected"); - if (viewType.getRank() != indexingsInfo.size()) - return parser->emitError(parser->getNameLoc(), - "expected" + Twine(viewType.getRank()) + - " range indexings"); - return parser->resolveOperand( - bufferInfo, - BufferType::get(type.getContext(), viewType.getElementType()), - result->operands) || - (!indexingsInfo.empty() && - parser->resolveOperands(indexingsInfo, - RangeType::get(type.getContext()), - result->operands)) || - parser->addTypeToList(viewType, result->types); -} - -// A BaseViewOp prints as: -// -// ```{.mlir} -// linalg.base_view %0[%1, %2] : !linalg.view -// ``` -// -// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each -// holding a range. -void mlir::BaseViewOp::print(OpAsmPrinter *p) { - *p << getOperationName() << " " << *getSupportingBuffer() << "["; - interleave( - getIndexings().begin(), getIndexings().end(), - [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; }); - *p << "] : " << getType(); -} - -////////////////////////////////////////////////////////////////////////////// // BufferAllocOp ////////////////////////////////////////////////////////////////////////////// void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type, @@ -122,9 +39,8 @@ void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type, } mlir::LogicalResult mlir::BufferAllocOp::verify() { - if (!size() || !size()->getType().isa() || - !size()->getType().cast().isInteger(64)) - return emitOpError("first operand should be of type i64"); + if (!size() || !size()->getType().isa()) + return emitOpError("first operand should be of type index"); if (!VectorType::isValidElementType(getElementType()) && !getElementType().isa()) return emitOpError("unsupported buffer element type"); @@ -143,14 +59,14 @@ void mlir::BufferAllocOp::print(OpAsmPrinter *p) { bool mlir::BufferAllocOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType sizeInfo; BufferType bufferType; - auto int64Ty = parser->getBuilder().getIntegerType(64); + auto indexTy = parser->getBuilder().getIndexType(); if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType)) return true; if (bufferType.getElementType() != parser->getBuilder().getF32Type()) return parser->emitError( parser->getNameLoc(), "Only buffer supported until mlir::Parser pieces are exposed"); - return parser->resolveOperands(sizeInfo, int64Ty, result->operands) || + return parser->resolveOperands(sizeInfo, indexTy, result->operands) || parser->addTypeToList(bufferType, result->types); } @@ -183,7 +99,6 @@ bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) { return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) || parser->resolveOperands(sizeInfo, bufferType, result->operands); } - ////////////////////////////////////////////////////////////////////////////// // RangeOp ////////////////////////////////////////////////////////////////////////////// @@ -224,3 +139,218 @@ bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) { parser->resolveOperands(rangeInfo, affineIntTy, result->operands) || parser->addTypeToList(type, result->types); } + +////////////////////////////////////////////////////////////////////////////// +// SliceOp +////////////////////////////////////////////////////////////////////////////// +void mlir::SliceOp::build(Builder *b, OperationState *result, Value *base, + ArrayRef indexings) { + result->addOperands({base}); + result->addOperands(indexings); + + ViewType viewType = base->getType().cast(); + unsigned rank = viewType.getRank(); + for (auto *i : indexings) + if (!i->getType().isa()) + rank--; + Type elementType = viewType.getElementType(); + result->addTypes( + {ViewType::get(b->getContext(), elementType, indexings.size())}); +} + +LogicalResult mlir::SliceOp::verify() { + if (llvm::empty(getOperands())) + return emitOpError( + "requires at least a view operand followed by 'rank' indices"); + if (!getOperand(0)->getDefiningOp()->isa()) + return emitOpError( + "requires at least a view operand followed by 'rank' indices"); + + auto viewOp = getOperand(0)->getDefiningOp()->dyn_cast(); + if (!viewOp) + return emitOpError("first operand must come from a ViewOp"); + unsigned rank = getBaseViewRank(); + if (llvm::size(getIndexings()) != rank) { + return emitOpError("requires at least a view operand followed by " + + Twine(rank) + " indexings"); + } + unsigned index = 0; + for (auto indexing : getIndexings()) { + if (!indexing->getType().isa() && + !indexing->getType().isa()) { + return emitOpError(Twine(index) + + "^th index must be of range or index type"); + } + if (indexing->getType().isa()) + --rank; + ++index; + } + if (getRank() != rank) { + return emitOpError("the rank of the view must be the number of its range " + "indices: " + + Twine(rank)); + } + return success(); +} + +bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType baseInfo; + SmallVector indexingsInfo; + SmallVector types; + if (parser->parseOperand(baseInfo) || + parser->parseOperandList(indexingsInfo, -1, + OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonTypeList(types)) + return true; + + if (types.size() != 2 + indexingsInfo.size()) + return parser->emitError(parser->getNameLoc(), + "unexpected number of types "); + ViewType baseViewType = types[0].dyn_cast(); + if (!baseViewType) + return parser->emitError(parser->getNameLoc(), + "view type expected for first type"); + if (indexingsInfo.size() != baseViewType.getRank()) + return parser->emitError(parser->getNameLoc(), + "expected " + Twine(baseViewType.getRank()) + + " indexings"); + ViewType viewType = types.back().dyn_cast(); + if (!viewType) + return parser->emitError(parser->getNameLoc(), "view type expected"); + + ArrayRef indexingTypes = + ArrayRef(types).drop_front(1).drop_back(1); + if (indexingTypes.size() != baseViewType.getRank()) + return parser->emitError(parser->getNameLoc(), + "expected " + Twine(baseViewType.getRank()) + + " indexing types"); + return parser->resolveOperand(baseInfo, baseViewType, result->operands) || + (!indexingsInfo.empty() && + parser->resolveOperands(indexingsInfo, indexingTypes, + indexingsInfo.front().location, + result->operands)) || + parser->addTypeToList(viewType, result->types); +} + +// A SliceOp prints as: +// +// ```{.mlir} +// linalg.slice %0[%1, %2] : +// !linalg.view, [indexing-types], !linalg.view +// ``` +// +// Where %0 is an ssa-value holding a view created from a buffer, %1 and %2 are +// ssa-value each holding a range. +void mlir::SliceOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *getBaseView() << "["; + interleave( + getIndexings().begin(), getIndexings().end(), + [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; }); + *p << "] : " << getBaseViewType(); + for (auto indexing : getIndexings()) { + *p << ", " << indexing->getType(); + } + *p << ", " << getType(); +} + +ViewOp mlir::SliceOp::getBaseViewOp() { + return getOperand(0)->getDefiningOp()->cast(); +} + +ViewType mlir::SliceOp::getBaseViewType() { + return getBaseViewOp().getType().cast(); +} + +SmallVector mlir::SliceOp::getRanges() { + llvm::SmallVector res; + for (auto *operand : getIndexings()) { + if (!operand->getType().isa()) { + res.push_back(operand); + } + } + return res; +} + +////////////////////////////////////////////////////////////////////////////// +// ViewOp +////////////////////////////////////////////////////////////////////////////// +void mlir::ViewOp::build(Builder *b, OperationState *result, Value *buffer, + ArrayRef indexings) { + BufferType bufferType = buffer->getType().cast(); + result->addOperands({buffer}); + result->addOperands(indexings); + assert( + std::none_of(indexings.begin(), indexings.end(), + [](Value *v) { return !v->getType().isa(); }) && + "linalg.view takes only arguments of type linalg.range"); + + Type elementType = bufferType.getElementType(); + result->addTypes( + {ViewType::get(b->getContext(), elementType, indexings.size())}); +} + +LogicalResult mlir::ViewOp::verify() { + if (llvm::empty(getOperands())) + return emitOpError( + "requires at least a buffer operand followed by indexings"); + auto bufferType = getOperand(0)->getType().dyn_cast(); + if (!bufferType) + return emitOpError("first operand must be of BufferType"); + unsigned index = 0; + for (auto indexing : getIndexings()) { + if (!indexing->getType().isa()) { + return emitOpError(Twine(index) + "^th index must be of range type"); + } + ++index; + } + if (getViewType().getRank() != index) + return emitOpError( + "the rank of the view must be the number of its indexings"); + return success(); +} + +bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType bufferInfo; + SmallVector indexingsInfo; + Type type; + if (parser->parseOperand(bufferInfo) || + parser->parseOperandList(indexingsInfo, -1, + OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type)) + return true; + + ViewType viewType = type.dyn_cast(); + if (!viewType) + return parser->emitError(parser->getNameLoc(), "view type expected"); + if (viewType.getRank() != indexingsInfo.size()) + return parser->emitError(parser->getNameLoc(), + "expected" + Twine(viewType.getRank()) + + " range indexings"); + return parser->resolveOperand( + bufferInfo, + BufferType::get(type.getContext(), viewType.getElementType()), + result->operands) || + (!indexingsInfo.empty() && + parser->resolveOperands(indexingsInfo, + RangeType::get(type.getContext()), + result->operands)) || + parser->addTypeToList(viewType, result->types); +} + +// A ViewOp prints as: +// +// ```{.mlir} +// linalg.view %0[%1, %2] : !linalg.view +// ``` +// +// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each +// holding a range. +void mlir::ViewOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *getSupportingBuffer() << "["; + interleave( + getIndexings().begin(), getIndexings().end(), + [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; }); + *p << "] : " << getType(); +} diff --git a/mlir/lib/Linalg/LinalgTypes.cpp b/mlir/lib/Linalg/LinalgTypes.cpp index e164b0b..fa08f75 100644 --- a/mlir/lib/Linalg/LinalgTypes.cpp +++ b/mlir/lib/Linalg/LinalgTypes.cpp @@ -30,7 +30,7 @@ using namespace mlir; mlir::LinalgDialect::LinalgDialect(MLIRContext *context) : Dialect("linalg", context) { addTypes(); - addOperations(); + addOperations(); } struct mlir::BufferTypeStorage : public mlir::TypeStorage { diff --git a/mlir/test/Linalg/roundtrip.mlir b/mlir/test/Linalg/roundtrip.mlir index eab2aa3..4327e5d 100644 --- a/mlir/test/Linalg/roundtrip.mlir +++ b/mlir/test/Linalg/roundtrip.mlir @@ -7,28 +7,36 @@ func @range(%arg0: index, %arg1: index, %arg2: index) { // CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) { // CHECK-NEXT: %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range -func @buffer(%arg0: i64, %arg1: i64) { - %0 = muli %arg0, %arg0 : i64 +func @buffer(%arg0: index, %arg1: index) { + %0 = muli %arg0, %arg0 : index %1 = linalg.buffer_alloc %0 : !linalg.buffer linalg.buffer_dealloc %1 : !linalg.buffer return } -// CHECK-LABEL: func @buffer(%arg0: i64, %arg1: i64) { -// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64 +// CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) { +// CHECK-NEXT: %0 = muli %arg0, %arg0 : index // CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer // CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer -func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) { - %0 = muli %arg0, %arg0 : i64 +func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { + %0 = muli %arg0, %arg0 : index %1 = linalg.buffer_alloc %0 : !linalg.buffer %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range - %3 = linalg.base_view %1[%2, %2] : !linalg.view + %3 = linalg.view %1[%2, %2] : !linalg.view + %4 = linalg.slice %3[%2, %2] : !linalg.view, !linalg.range, !linalg.range, !linalg.view + %5 = linalg.slice %3[%2, %arg2] : !linalg.view, !linalg.range, index, !linalg.view + %6 = linalg.slice %3[%arg2, %2] : !linalg.view, index, !linalg.range, !linalg.view + %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view, index, index, !linalg.view linalg.buffer_dealloc %1 : !linalg.buffer return } -// CHECK-LABEL: func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) { -// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64 +// CHECK-LABEL: func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { +// CHECK-NEXT: %0 = muli %arg0, %arg0 : index // CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer // CHECK-NEXT: %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range -// CHECK-NEXT: %3 = linalg.base_view %1[%2, %2] : !linalg.view +// CHECK-NEXT: %3 = linalg.view %1[%2, %2] : !linalg.view +// CHECK-NEXT: %4 = linalg.slice %3[%2, %2] : !linalg.view, !linalg.range, !linalg.range, !linalg.view +// CHECK-NEXT: %5 = linalg.slice %3[%2, %arg2] : !linalg.view, !linalg.range, index, !linalg.view +// CHECK-NEXT: %6 = linalg.slice %3[%arg2, %2] : !linalg.view, index, !linalg.range, !linalg.view +// CHECK-NEXT: %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view, index, index, !linalg.view // CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer \ No newline at end of file -- 2.7.4