namespace mlir {
-/// A `BaseViewOp` produces a `ViewType` which is a multi-dimensional range
-/// abstraction on top of an underlying linalg.buffer. A BaseViewOp gives a
-/// buffer an indexing structure.
-///
-/// A new value of ViewType is constructed from a buffer with a base_view op and
-/// ranges:
+/// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
+/// upon which a base view can be laid out to give it indexing semantics.
+/// "buffer_alloc" takes a single argument, the size of the buffer to allocate
+/// (in number of elements).
///
/// ```{.mlir}
-/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
-/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
-/// %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
+/// %0 = linalg.buffer_alloc %arg0 : !linalg.buffer<f32>
/// ```
-class BaseViewOp : public mlir::Op<BaseViewOp, mlir::OpTrait::VariadicOperands,
- mlir::OpTrait::OneResult,
- mlir::OpTrait::HasNoSideEffect> {
- enum { FirstIndexingOperand = 1 };
-
-public:
- using Op::Op;
-
- // Hooks to customize the behavior of this op.
- static llvm::StringRef getOperationName() { return "linalg.base_view"; }
- static void build(mlir::Builder *b, mlir::OperationState *result,
- mlir::Value *buffer,
- llvm::ArrayRef<mlir::Value *> indexings);
- mlir::LogicalResult verify();
- static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
- void print(mlir::OpAsmPrinter *p);
-
- // Op-specific functionality.
- unsigned getRank() { return getViewType().getRank(); }
- mlir::Type getElementType() { return getViewType().getElementType(); }
- ViewType getViewType() { return getType().cast<ViewType>(); }
- mlir::Value *getSupportingBuffer() { return getOperand(0); }
- // Get the underlying indexing at a given rank.
- mlir::Value *getIndexing(unsigned rank) {
- return *(getIndexings().begin() + rank);
- }
- // Get all the indexings in this view.
- mlir::Operation::operand_range getIndexings() {
- return {operand_begin() + BaseViewOp::FirstIndexingOperand, operand_end()};
- }
-};
-
-/// A BufferAllocOp is used to create a 1-D !linalg.buffer upon which a base
-/// view can be laid out. The size argument is an `i64` (and not an index), so
-/// that we can
class BufferAllocOp
: public Op<BufferAllocOp, OpTrait::OneOperand, OpTrait::OneResult> {
public:
Type getElementType() { return getBufferType().getElementType(); }
};
-/// A BufferDeallocOp is used to free a !linalg.buffer.
+/// The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
+///
+/// ```{.mlir}
+/// linalg.buffer_dealloc %0 : !linalg.buffer<f32>
+/// ```
class BufferDeallocOp
: public Op<BufferDeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
public:
}
};
-/// A RangeOp is used to create a value of RangeType from 3 values of type index
+/// The "linalg.range" op creates a linalg.range from 3 values of type `index`
/// that represent the min, max and step values of the range.
+///
+/// ```{.mlir}
+/// %3 = linalg.range %0:%1:%2 : !linalg.range
+/// ```
class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
Value *step() { return getOperand(2); }
};
+/// The "linalg.slice" op produces a linalg.view which is a subview of a given
+/// base view. This allows defining a subregion within the underlying buffer to
+/// operate on only a subset of the buffer.
+///
+/// A "linalg.slice" op takes a base view and a variadic number of indexings and
+/// produces a linalg.view of the same elemental type as the buffer. An indexing
+/// is either:
+/// 1. a linalg.range, in which case it does not reduce the rank of the parent
+/// view.
+/// 2. an index, in which case it reduces the rank of the parent view by one.
+///
+/// The parent view must be a base view (i.e. either a function argument or has
+/// been produced by a linalg.view op). In other words, chains of
+/// linalg.slice operations cannot be constructed in the IR. This defines away
+/// problems related to keeping track of which dimensions of the base view have
+/// been rank-reduced.
+///
+/// Examples:
+/// 1. rank-preserving slice:
+///
+/// ```{.mlir}
+/// %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, !linalg.range,
+/// !linalg.range, !linalg.view<?x?xf32>
+/// ```
+///
+/// 2. rank-reducing slice (from 2-D to 1-D):
+///
+/// ```{.mlir}
+/// %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index,
+/// !linalg.range, !linalg.view<?xf32>
+/// ```
+///
+/// 3. rank-reducing slice (from 2-D to 0-D):
+///
+/// ```{.mlir}
+/// %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index, index,
+/// !linalg.view<f32>
+/// ```
+class ViewOp;
+class SliceOp : public mlir::Op<SliceOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::OneResult,
+ mlir::OpTrait::HasNoSideEffect> {
+ enum { FirstIndexingOperand = 1 };
+
+public:
+ using Op::Op;
+
+ // Hooks to customize the behavior of this op.
+ static llvm::StringRef getOperationName() { return "linalg.slice"; }
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ mlir::Value *base, llvm::ArrayRef<mlir::Value *> indexings);
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+ // Op-specific functionality.
+ unsigned getRank() { return getViewType().getRank(); }
+ mlir::Type getElementType() { return getViewType().getElementType(); }
+ ViewType getViewType() { return getType().cast<ViewType>(); }
+ Value *getBaseView() { return getOperand(0); }
+ ViewOp getBaseViewOp();
+ ViewType getBaseViewType();
+ unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
+ // Get the underlying indexing at a given rank.
+ mlir::Value *getIndexing(unsigned rank) {
+ return *(getIndexings().begin() + rank);
+ }
+ // Get all the indexings in this view.
+ mlir::Operation::operand_range getIndexings() {
+ return {operand_begin() + SliceOp::FirstIndexingOperand, operand_end()};
+ }
+ // Get the subset of indexings that are of RangeType.
+ SmallVector<Value *, 8> getRanges();
+};
+
+/// The "linalg.view" op produces a linalg.view which is a multi-dimensional
+/// range abstraction on top of an underlying linalg.buffer. This gives an
+/// indexing structure to an otherwise non-indexable linalg.buffer.
+///
+/// A "linalg.view" takes a buffer and a variadic number of ranges and produces
+/// a `view` of the same elemental type as the buffer and of rank the number of
+/// ranges:
+///
+/// ```{.mlir}
+/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
+/// %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
+/// ```
+class ViewOp : public mlir::Op<ViewOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::OneResult,
+ mlir::OpTrait::HasNoSideEffect> {
+ enum { FirstIndexingOperand = 1 };
+
+public:
+ using Op::Op;
+
+ // Hooks to customize the behavior of this op.
+ static llvm::StringRef getOperationName() { return "linalg.view"; }
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ mlir::Value *buffer,
+ llvm::ArrayRef<mlir::Value *> indexings);
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+ // Op-specific functionality.
+ unsigned getRank() { return getViewType().getRank(); }
+ mlir::Type getElementType() { return getViewType().getElementType(); }
+ ViewType getViewType() { return getType().cast<ViewType>(); }
+ mlir::Value *getSupportingBuffer() { return getOperand(0); }
+ // Get the underlying indexing at a given rank.
+ mlir::Value *getIndexing(unsigned rank) {
+ return *(getIndexings().begin() + rank);
+ }
+ // Get all the indexings in this view.
+ mlir::Operation::operand_range getIndexings() {
+ return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
+ }
+};
+
} // namespace mlir
#endif // MLIR_LINALG_LINALGOPS_H_
void printType(Type type, llvm::raw_ostream &os) const override;
};
-/// A BufferType represents a minimal range abstraction (min, max, step).
+/// A BufferType represents a contiguous block of memory that can be allocated
+/// and deallocated. A buffer cannot be indexed directly, a view must be
+/// laid out on a buffer to give it indexing semantics.
class BufferTypeStorage;
class BufferType : public Type::TypeBase<BufferType, Type, BufferTypeStorage> {
public:
};
/// A RangeType represents a minimal range abstraction (min, max, step).
+/// It is constructed by calling the linalg.range op with three values index of
+/// index type:
+///
+/// ```{.mlir}
+/// func @foo(%arg0 : index, %arg1 : index, %arg2 : index) {
+/// %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
+/// }
+/// ```
class RangeType : public Type::TypeBase<RangeType, Type> {
public:
// Used for generic hooks in TypeBase.
/// A ViewType represents a multi-dimensional range abstraction on top of an
/// underlying storage type. It is parameterizable by the underlying element
/// type and the rank of the view.
-/// A new value of ViewType is constructed from a buffer with a base_view op and
+/// A new value of ViewType is constructed from a buffer with a view op and
/// passing it ranges:
///
/// ```{.mlir}
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
-/// %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
+/// %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
/// ```
class ViewTypeStorage;
class ViewType
using namespace mlir;
//////////////////////////////////////////////////////////////////////////////
-// BaseViewOp
-//////////////////////////////////////////////////////////////////////////////
-void mlir::BaseViewOp::build(Builder *b, OperationState *result, Value *buffer,
- ArrayRef<Value *> indexings) {
- BufferType bufferType = buffer->getType().cast<BufferType>();
- result->addOperands({buffer});
- result->addOperands(indexings);
- assert(
- std::none_of(indexings.begin(), indexings.end(),
- [](Value *v) { return !v->getType().isa<RangeType>(); }) &&
- "linalg.base_view takes only arguments of type linalg.range");
-
- Type elementType = bufferType.getElementType();
- result->addTypes(
- {ViewType::get(b->getContext(), elementType, indexings.size())});
-}
-
-LogicalResult mlir::BaseViewOp::verify() {
- if (llvm::empty(getOperands()))
- return emitOpError(
- "requires at least a buffer operand followed by indexings");
- auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
- if (!bufferType)
- return emitOpError("first operand must be of BufferType");
- unsigned index = 0;
- for (auto indexing : getIndexings()) {
- if (!indexing->getType().isa<RangeType>()) {
- return emitOpError(Twine(index) + "^th index must be of range type");
- }
- ++index;
- }
- if (getViewType().getRank() != index)
- return emitOpError(
- "the rank of the base view must be the number of its indexings");
- return success();
-}
-
-bool mlir::BaseViewOp::parse(OpAsmParser *parser, OperationState *result) {
- OpAsmParser::OperandType bufferInfo;
- SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
- Type type;
- if (parser->parseOperand(bufferInfo) ||
- parser->parseOperandList(indexingsInfo, -1,
- OpAsmParser::Delimiter::Square) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(type))
- return true;
-
- ViewType viewType = type.dyn_cast<ViewType>();
- if (!viewType)
- return parser->emitError(parser->getNameLoc(), "view type expected");
- if (viewType.getRank() != indexingsInfo.size())
- return parser->emitError(parser->getNameLoc(),
- "expected" + Twine(viewType.getRank()) +
- " range indexings");
- return parser->resolveOperand(
- bufferInfo,
- BufferType::get(type.getContext(), viewType.getElementType()),
- result->operands) ||
- (!indexingsInfo.empty() &&
- parser->resolveOperands(indexingsInfo,
- RangeType::get(type.getContext()),
- result->operands)) ||
- parser->addTypeToList(viewType, result->types);
-}
-
-// A BaseViewOp prints as:
-//
-// ```{.mlir}
-// linalg.base_view %0[%1, %2] : !linalg.view<?x?xf32>
-// ```
-//
-// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
-// holding a range.
-void mlir::BaseViewOp::print(OpAsmPrinter *p) {
- *p << getOperationName() << " " << *getSupportingBuffer() << "[";
- interleave(
- getIndexings().begin(), getIndexings().end(),
- [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
- *p << "] : " << getType();
-}
-
-//////////////////////////////////////////////////////////////////////////////
// BufferAllocOp
//////////////////////////////////////////////////////////////////////////////
void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type,
}
mlir::LogicalResult mlir::BufferAllocOp::verify() {
- if (!size() || !size()->getType().isa<IntegerType>() ||
- !size()->getType().cast<IntegerType>().isInteger(64))
- return emitOpError("first operand should be of type i64");
+ if (!size() || !size()->getType().isa<IndexType>())
+ return emitOpError("first operand should be of type index");
if (!VectorType::isValidElementType(getElementType()) &&
!getElementType().isa<VectorType>())
return emitOpError("unsupported buffer element type");
bool mlir::BufferAllocOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType sizeInfo;
BufferType bufferType;
- auto int64Ty = parser->getBuilder().getIntegerType(64);
+ auto indexTy = parser->getBuilder().getIndexType();
if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
return true;
if (bufferType.getElementType() != parser->getBuilder().getF32Type())
return parser->emitError(
parser->getNameLoc(),
"Only buffer<f32> supported until mlir::Parser pieces are exposed");
- return parser->resolveOperands(sizeInfo, int64Ty, result->operands) ||
+ return parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
parser->addTypeToList(bufferType, result->types);
}
return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
parser->resolveOperands(sizeInfo, bufferType, result->operands);
}
-
//////////////////////////////////////////////////////////////////////////////
// RangeOp
//////////////////////////////////////////////////////////////////////////////
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type, result->types);
}
+
+//////////////////////////////////////////////////////////////////////////////
+// SliceOp
+//////////////////////////////////////////////////////////////////////////////
+void mlir::SliceOp::build(Builder *b, OperationState *result, Value *base,
+ ArrayRef<Value *> indexings) {
+ result->addOperands({base});
+ result->addOperands(indexings);
+
+ ViewType viewType = base->getType().cast<ViewType>();
+ unsigned rank = viewType.getRank();
+ for (auto *i : indexings)
+ if (!i->getType().isa<RangeType>())
+ rank--;
+ Type elementType = viewType.getElementType();
+ result->addTypes(
+ {ViewType::get(b->getContext(), elementType, indexings.size())});
+}
+
+LogicalResult mlir::SliceOp::verify() {
+ if (llvm::empty(getOperands()))
+ return emitOpError(
+ "requires at least a view operand followed by 'rank' indices");
+ if (!getOperand(0)->getDefiningOp()->isa<ViewOp>())
+ return emitOpError(
+ "requires at least a view operand followed by 'rank' indices");
+
+ auto viewOp = getOperand(0)->getDefiningOp()->dyn_cast<ViewOp>();
+ if (!viewOp)
+ return emitOpError("first operand must come from a ViewOp");
+ unsigned rank = getBaseViewRank();
+ if (llvm::size(getIndexings()) != rank) {
+ return emitOpError("requires at least a view operand followed by " +
+ Twine(rank) + " indexings");
+ }
+ unsigned index = 0;
+ for (auto indexing : getIndexings()) {
+ if (!indexing->getType().isa<RangeType>() &&
+ !indexing->getType().isa<IndexType>()) {
+ return emitOpError(Twine(index) +
+ "^th index must be of range or index type");
+ }
+ if (indexing->getType().isa<IndexType>())
+ --rank;
+ ++index;
+ }
+ if (getRank() != rank) {
+ return emitOpError("the rank of the view must be the number of its range "
+ "indices: " +
+ Twine(rank));
+ }
+ return success();
+}
+
+bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
+ OpAsmParser::OperandType baseInfo;
+ SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
+ SmallVector<Type, 8> types;
+ if (parser->parseOperand(baseInfo) ||
+ parser->parseOperandList(indexingsInfo, -1,
+ OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonTypeList(types))
+ return true;
+
+ if (types.size() != 2 + indexingsInfo.size())
+ return parser->emitError(parser->getNameLoc(),
+ "unexpected number of types ");
+ ViewType baseViewType = types[0].dyn_cast<ViewType>();
+ if (!baseViewType)
+ return parser->emitError(parser->getNameLoc(),
+ "view type expected for first type");
+ if (indexingsInfo.size() != baseViewType.getRank())
+ return parser->emitError(parser->getNameLoc(),
+ "expected " + Twine(baseViewType.getRank()) +
+ " indexings");
+ ViewType viewType = types.back().dyn_cast<ViewType>();
+ if (!viewType)
+ return parser->emitError(parser->getNameLoc(), "view type expected");
+
+ ArrayRef<Type> indexingTypes =
+ ArrayRef<Type>(types).drop_front(1).drop_back(1);
+ if (indexingTypes.size() != baseViewType.getRank())
+ return parser->emitError(parser->getNameLoc(),
+ "expected " + Twine(baseViewType.getRank()) +
+ " indexing types");
+ return parser->resolveOperand(baseInfo, baseViewType, result->operands) ||
+ (!indexingsInfo.empty() &&
+ parser->resolveOperands(indexingsInfo, indexingTypes,
+ indexingsInfo.front().location,
+ result->operands)) ||
+ parser->addTypeToList(viewType, result->types);
+}
+
+// A SliceOp prints as:
+//
+// ```{.mlir}
+// linalg.slice %0[%1, %2] :
+// !linalg.view<?x?xf32>, [indexing-types], !linalg.view<?x?xf32>
+// ```
+//
+// Where %0 is an ssa-value holding a view created from a buffer, %1 and %2 are
+// ssa-value each holding a range.
+void mlir::SliceOp::print(OpAsmPrinter *p) {
+ *p << getOperationName() << " " << *getBaseView() << "[";
+ interleave(
+ getIndexings().begin(), getIndexings().end(),
+ [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
+ *p << "] : " << getBaseViewType();
+ for (auto indexing : getIndexings()) {
+ *p << ", " << indexing->getType();
+ }
+ *p << ", " << getType();
+}
+
+ViewOp mlir::SliceOp::getBaseViewOp() {
+ return getOperand(0)->getDefiningOp()->cast<ViewOp>();
+}
+
+ViewType mlir::SliceOp::getBaseViewType() {
+ return getBaseViewOp().getType().cast<ViewType>();
+}
+
+SmallVector<Value *, 8> mlir::SliceOp::getRanges() {
+ llvm::SmallVector<Value *, 8> res;
+ for (auto *operand : getIndexings()) {
+ if (!operand->getType().isa<IndexType>()) {
+ res.push_back(operand);
+ }
+ }
+ return res;
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// ViewOp
+//////////////////////////////////////////////////////////////////////////////
+void mlir::ViewOp::build(Builder *b, OperationState *result, Value *buffer,
+ ArrayRef<Value *> indexings) {
+ BufferType bufferType = buffer->getType().cast<BufferType>();
+ result->addOperands({buffer});
+ result->addOperands(indexings);
+ assert(
+ std::none_of(indexings.begin(), indexings.end(),
+ [](Value *v) { return !v->getType().isa<RangeType>(); }) &&
+ "linalg.view takes only arguments of type linalg.range");
+
+ Type elementType = bufferType.getElementType();
+ result->addTypes(
+ {ViewType::get(b->getContext(), elementType, indexings.size())});
+}
+
+LogicalResult mlir::ViewOp::verify() {
+ if (llvm::empty(getOperands()))
+ return emitOpError(
+ "requires at least a buffer operand followed by indexings");
+ auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
+ if (!bufferType)
+ return emitOpError("first operand must be of BufferType");
+ unsigned index = 0;
+ for (auto indexing : getIndexings()) {
+ if (!indexing->getType().isa<RangeType>()) {
+ return emitOpError(Twine(index) + "^th index must be of range type");
+ }
+ ++index;
+ }
+ if (getViewType().getRank() != index)
+ return emitOpError(
+ "the rank of the view must be the number of its indexings");
+ return success();
+}
+
+bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
+ OpAsmParser::OperandType bufferInfo;
+ SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
+ Type type;
+ if (parser->parseOperand(bufferInfo) ||
+ parser->parseOperandList(indexingsInfo, -1,
+ OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type))
+ return true;
+
+ ViewType viewType = type.dyn_cast<ViewType>();
+ if (!viewType)
+ return parser->emitError(parser->getNameLoc(), "view type expected");
+ if (viewType.getRank() != indexingsInfo.size())
+ return parser->emitError(parser->getNameLoc(),
+ "expected" + Twine(viewType.getRank()) +
+ " range indexings");
+ return parser->resolveOperand(
+ bufferInfo,
+ BufferType::get(type.getContext(), viewType.getElementType()),
+ result->operands) ||
+ (!indexingsInfo.empty() &&
+ parser->resolveOperands(indexingsInfo,
+ RangeType::get(type.getContext()),
+ result->operands)) ||
+ parser->addTypeToList(viewType, result->types);
+}
+
+// A ViewOp prints as:
+//
+// ```{.mlir}
+// linalg.view %0[%1, %2] : !linalg.view<?x?xf32>
+// ```
+//
+// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
+// holding a range.
+void mlir::ViewOp::print(OpAsmPrinter *p) {
+ *p << getOperationName() << " " << *getSupportingBuffer() << "[";
+ interleave(
+ getIndexings().begin(), getIndexings().end(),
+ [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
+ *p << "] : " << getType();
+}
mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect("linalg", context) {
addTypes<BufferType, RangeType, ViewType>();
- addOperations<BaseViewOp, BufferAllocOp, BufferDeallocOp, RangeOp>();
+ addOperations<BufferAllocOp, BufferDeallocOp, RangeOp, SliceOp, ViewOp>();
}
struct mlir::BufferTypeStorage : public mlir::TypeStorage {
// CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) {
// CHECK-NEXT: %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
-func @buffer(%arg0: i64, %arg1: i64) {
- %0 = muli %arg0, %arg0 : i64
+func @buffer(%arg0: index, %arg1: index) {
+ %0 = muli %arg0, %arg0 : index
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
linalg.buffer_dealloc %1 : !linalg.buffer<f32>
return
}
-// CHECK-LABEL: func @buffer(%arg0: i64, %arg1: i64) {
-// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64
+// CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
+// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
-func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) {
- %0 = muli %arg0, %arg0 : i64
+func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+ %0 = muli %arg0, %arg0 : index
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
- %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
+ %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
+ %4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+ %5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
+ %6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
+ %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
linalg.buffer_dealloc %1 : !linalg.buffer<f32>
return
}
-// CHECK-LABEL: func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) {
-// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64
+// CHECK-LABEL: func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
// CHECK-NEXT: %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
-// CHECK-NEXT: %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
+// CHECK-NEXT: %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
+// CHECK-NEXT: %4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+// CHECK-NEXT: %5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
+// CHECK-NEXT: %6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
+// CHECK-NEXT: %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
\ No newline at end of file