using namespace mlir;
using namespace sparse_tensor;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
MLIRContext *ctx = tp.getContext();
auto enc = tp.getEncoding();
getSpecifierFields(tp));
}
-StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
- addConversion([](Type type) { return type; });
- addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); });
-}
+//===----------------------------------------------------------------------===//
+// Specifier struct builder.
+//===----------------------------------------------------------------------===//
constexpr uint64_t kDimSizePosInSpecifier = 0;
constexpr uint64_t kMemSizePosInSpecifier = 1;
loc, value, size, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// The sparse storage specifier type converter (defined in Passes.h).
+//===----------------------------------------------------------------------===//
+
+StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
+ addConversion([](Type type) { return type; });
+ addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); });
+}
+
+//===----------------------------------------------------------------------===//
+// Storage specifier conversion rules.
+//===----------------------------------------------------------------------===//
+
template <typename Base, typename SourceOp>
class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
public:
}
};
+//===----------------------------------------------------------------------===//
+// Public method for populating conversion rules.
+//===----------------------------------------------------------------------===//
+
void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+
#include "SparseTensorStorageLayout.h"
#include "CodegenUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace sparse_tensor;
+//===----------------------------------------------------------------------===//
+// Private helper methods.
+//===----------------------------------------------------------------------===//
+
static Value createIndexCast(OpBuilder &builder, Location loc, Value value,
Type to) {
if (value.getType() != to)
return success();
}
+//===----------------------------------------------------------------------===//
+// The sparse tensor type converter (defined in Passes.h).
+//===----------------------------------------------------------------------===//
+
SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
addConversion([](Type type) { return type; });
addConversion([&](RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
});
}
+//===----------------------------------------------------------------------===//
+// StorageLayout methods.
+//===----------------------------------------------------------------------===//
+
unsigned StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind,
std::optional<unsigned> dim) const {
unsigned fieldIdx = -1u;
return getMemRefFieldIndex(toFieldKind(kind), dim);
}
+//===----------------------------------------------------------------------===//
+// StorageTensorSpecifier methods.
+//===----------------------------------------------------------------------===//
+
Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc,
RankedTensorType rtp) {
return builder.create<StorageSpecifierInitOp>(
createIndexCast(builder, loc, v, getFieldType(kind, dim)));
}
+//===----------------------------------------------------------------------===//
+// Public methods.
+//===----------------------------------------------------------------------===//
+
constexpr uint64_t kDataFieldStartingIdx = 0;
void sparse_tensor::foreachFieldInSparseTensor(