// Helper methods.
//===----------------------------------------------------------------------===//
-static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
+static SmallVector<Type, 4> getSpecifierFields(StorageSpecifierType tp) {
MLIRContext *ctx = tp.getContext();
auto enc = tp.getEncoding();
const Level lvlRank = enc.getLvlRank();
- SmallVector<Type, 2> result;
+ SmallVector<Type, 4> result;
// TODO: how can we get the lowering type for index type in the later pipeline
// to be consistent? LLVM::StructureType does not allow index fields.
auto sizeType = IntegerType::get(tp.getContext(), 64);
getNumDataFieldsFromEncoding(enc));
result.push_back(lvlSizes);
result.push_back(memSizes);
+
+ if (enc.isSlice()) {
+ // Extra fields are required for the slice information.
+ auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
+ auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
+
+ result.push_back(dimOffset);
+ result.push_back(dimStride);
+ }
+
return result;
}
constexpr uint64_t kLvlSizePosInSpecifier = 0;
constexpr uint64_t kMemSizePosInSpecifier = 1;
+constexpr uint64_t kDimOffsetPosInSpecifier = 2;
+constexpr uint64_t kDimStridePosInSpecifier = 3;
class SpecifierStructBuilder : public StructBuilder {
private:
Value extractField(OpBuilder &builder, Location loc,
- ArrayRef<int64_t> indices) {
+ ArrayRef<int64_t> indices) const {
return genCast(builder, loc,
builder.create<LLVM::ExtractValueOp>(loc, value, indices),
builder.getIndexType());
assert(value);
}
- // Undef value for level-sizes, all zero values for memory-sizes.
- static Value getInitValue(OpBuilder &builder, Location loc, Type structType);
+ // Undef value for dimension sizes, all zero value for memory sizes.
+ static Value getInitValue(OpBuilder &builder, Location loc, Type structType,
+ Value source);
- Value lvlSize(OpBuilder &builder, Location loc, Level lvl);
+ Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const;
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size);
- Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx);
+ Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const;
+ void setDimOffset(OpBuilder &builder, Location loc, Dimension dim,
+ Value size);
+
+ Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const;
+ void setDimStride(OpBuilder &builder, Location loc, Dimension dim,
+ Value size);
+
+ Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const;
void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx,
Value size);
+
+ Value memSizeArray(OpBuilder &builder, Location loc) const;
+ void setMemSizeArray(OpBuilder &builder, Location loc, Value array);
};
Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
- Type structType) {
+ Type structType, Value source) {
Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
SpecifierStructBuilder md(metaData);
- auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
- .getBody()[kMemSizePosInSpecifier]
- .cast<LLVM::LLVMArrayType>();
+ if (!source) {
+ auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
+ .getBody()[kMemSizePosInSpecifier]
+ .cast<LLVM::LLVMArrayType>();
+
+ Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
+ // Fill memSizes array with zero.
+ for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
+ md.setMemSize(builder, loc, i, zero);
+ } else {
+ // We copy non-slice information (memory sizes array) from source
+ SpecifierStructBuilder sourceMd(source);
+ md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc));
+ }
+ return md;
+}
- Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
- // Fill memSizes array with zero.
- for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
- md.setMemSize(builder, loc, i, zero);
+/// Builds IR extracting the pos-th offset from the descriptor.
+Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
+ Dimension dim) const {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, value,
+ ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
+}
- return md;
+/// Builds IR inserting the pos-th offset into the descriptor.
+void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
+ Dimension dim, Value size) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, value, size,
+ ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
}
/// Builds IR extracting the `lvl`-th level-size from the descriptor.
Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc,
- Level lvl) {
+ Level lvl) const {
// This static_cast makes the narrowing of `lvl` explicit, as required
// by the braces notation for the ctor.
return extractField(
size);
}
-/// Builds IR extracting the `fidx`-th memory-size from the descriptor.
+/// Builds IR extracting the pos-th stride from the descriptor.
+Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc,
+ Dimension dim) const {
+ return extractField(
+ builder, loc,
+ ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)});
+}
+
+/// Builds IR inserting the pos-th stride into the descriptor.
+void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc,
+ Dimension dim, Value size) {
+ insertField(
+ builder, loc,
+ ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)},
+ size);
+}
+
+/// Builds IR extracting the pos-th memory size into the descriptor.
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
- FieldIndex fidx) {
- return extractField(builder, loc,
- ArrayRef<int64_t>{kMemSizePosInSpecifier, fidx});
+ FieldIndex fidx) const {
+ return extractField(
+ builder, loc,
+ ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)});
}
/// Builds IR inserting the `fidx`-th memory-size into the descriptor.
void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
FieldIndex fidx, Value size) {
- insertField(builder, loc, ArrayRef<int64_t>{kMemSizePosInSpecifier, fidx},
- size);
+ insertField(
+ builder, loc,
+ ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)},
+ size);
+}
+
+/// Builds IR extracting the memory size array from the descriptor.
+Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder,
+ Location loc) const {
+ return builder.create<LLVM::ExtractValueOp>(loc, value,
+ kMemSizePosInSpecifier);
+}
+
+/// Builds IR inserting the memory size array into the descriptor.
+void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc,
+ Value array) {
+ value = builder.create<LLVM::InsertValueOp>(loc, value, array,
+ kMemSizePosInSpecifier);
}
} // namespace
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SpecifierStructBuilder spec(adaptor.getSpecifier());
- Value v;
- if (op.getSpecifierKind() == StorageSpecifierKind::LvlSize) {
- assert(op.getLevel().has_value());
- v = Base::onLvlSize(rewriter, op, spec, op.getLevel().value());
- } else {
+ switch (op.getSpecifierKind()) {
+ case StorageSpecifierKind::LvlSize: {
+ Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel()));
+ rewriter.replaceOp(op, v);
+ return success();
+ }
+ case StorageSpecifierKind::DimOffset: {
+ Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel()));
+ rewriter.replaceOp(op, v);
+ return success();
+ }
+ case StorageSpecifierKind::DimStride: {
+ Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel()));
+ rewriter.replaceOp(op, v);
+ return success();
+ }
+ case StorageSpecifierKind::CrdMemSize:
+ case StorageSpecifierKind::PosMemSize:
+ case StorageSpecifierKind::ValMemSize: {
auto enc = op.getSpecifier().getType().getEncoding();
StorageLayout layout(enc);
- FieldIndex fidx =
- layout.getMemRefFieldIndex(op.getSpecifierKind(), op.getLevel());
- v = Base::onMemSize(rewriter, op, spec, fidx);
+ std::optional<unsigned> lvl;
+ if (op.getLevel())
+ lvl = (*op.getLevel());
+ unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), lvl);
+ Value v = Base::onMemSize(rewriter, op, spec, idx);
+ rewriter.replaceOp(op, v);
+ return success();
}
-
- rewriter.replaceOp(op, v);
- return success();
+ }
+ llvm_unreachable("unrecognized specifer kind");
}
};
: public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
SetStorageSpecifierOp> {
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
+
static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op,
SpecifierStructBuilder &spec, Level lvl) {
spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue());
return spec;
}
+ static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op,
+ SpecifierStructBuilder &spec, Dimension d) {
+ spec.setDimOffset(builder, op.getLoc(), d, op.getValue());
+ return spec;
+ }
+
+ static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op,
+ SpecifierStructBuilder &spec, Dimension d) {
+ spec.setDimStride(builder, op.getLoc(), d, op.getValue());
+ return spec;
+ }
+
static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
SpecifierStructBuilder &spec, FieldIndex fidx) {
spec.setMemSize(builder, op.getLoc(), fidx, op.getValue());
: public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
GetStorageSpecifierOp> {
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
+
static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op,
SpecifierStructBuilder &spec, Level lvl) {
return spec.lvlSize(builder, op.getLoc(), lvl);
}
+
+ static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op,
+ const SpecifierStructBuilder &spec, Dimension d) {
+ return spec.dimOffset(builder, op.getLoc(), d);
+ }
+
+ static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op,
+ const SpecifierStructBuilder &spec, Dimension d) {
+ return spec.dimStride(builder, op.getLoc(), d);
+ }
+
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
SpecifierStructBuilder &spec, FieldIndex fidx) {
return spec.memSize(builder, op.getLoc(), fidx);
matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
- rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue(
- rewriter, op.getLoc(), llvmType));
+ rewriter.replaceOp(
+ op, SpecifierStructBuilder::getInitValue(
+ rewriter, op.getLoc(), llvmType, adaptor.getSource()));
return success();
}
};