The bulk of D142074 seems to have gotten overwritten due to some sort of merge conflict (as far as I can tell, there is no record of it having been reverted intentionally). So this commit redoes those changes. In addition to the original changes, this commit also:
* moves the definition of `getRankedTensorType` (from `Transforms/CodegenUtils.h` to `IR/SparseTensor.h`), so that it can be used by `IR/SparseTensorDialect.cpp`.
* adds `getMemRefType` as another abbreviation.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D142503
namespace mlir {
namespace sparse_tensor {
+/// Convenience method to abbreviate casting `getType()`.
+template <typename T>
+inline RankedTensorType getRankedTensorType(T t) {
+ return t.getType().template cast<RankedTensorType>();
+}
+
+/// Convenience method to abbreviate casting `getType()`.
+template <typename T>
+inline MemRefType getMemRefType(T t) {
+ return t.getType().template cast<MemRefType>();
+}
+
/// Convenience method to get a sparse encoding attribute from a type.
/// Returns null-attribute for any type without an encoding.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
using namespace mlir::sparse_tensor;
//===----------------------------------------------------------------------===//
+// Additional convenience methods.
+//===----------------------------------------------------------------------===//
+
+template <typename T>
+static inline int64_t getTypeRank(T t) {
+ return getRankedTensorType(t).getRank();
+}
+
+//===----------------------------------------------------------------------===//
// TensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
static LogicalResult isInBounds(uint64_t dim, Value tensor) {
- return success(dim <
- (uint64_t)tensor.getType().cast<RankedTensorType>().getRank());
+ return success(dim < static_cast<uint64_t>(getTypeRank(tensor)));
}
static LogicalResult isMatchingWidth(Value result, unsigned width) {
- const Type etp = result.getType().cast<MemRefType>().getElementType();
+ const Type etp = getMemRefType(result).getElementType();
return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}
}
LogicalResult NewOp::verify() {
- if (getExpandSymmetry() &&
- getResult().getType().cast<RankedTensorType>().getRank() != 2)
+ if (getExpandSymmetry() && getTypeRank(getResult()) != 2)
return emitOpError("expand_symmetry can only be used for 2D tensors");
return success();
}
}
LogicalResult ToValuesOp::verify() {
- RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
- MemRefType mtp = getResult().getType().cast<MemRefType>();
+ auto ttp = getRankedTensorType(getTensor());
+ auto mtp = getMemRefType(getResult());
if (ttp.getElementType() != mtp.getElementType())
return emitError("unexpected mismatch in element types");
return success();
}
LogicalResult ConcatenateOp::verify() {
- auto dstTp = getType().cast<RankedTensorType>();
+ auto dstTp = getRankedTensorType(*this);
uint64_t concatDim = getDimension().getZExtValue();
unsigned rank = dstTp.getRank();
concatDim));
for (size_t i = 0, e = getInputs().size(); i < e; i++) {
- Value input = getInputs()[i];
- auto inputRank = input.getType().cast<RankedTensorType>().getRank();
+ const auto inputRank = getTypeRank(getInputs()[i]);
if (inputRank != rank)
return emitError(
llvm::formatv("The input tensor ${0} has a different rank (rank={1}) "
}
for (unsigned i = 0; i < rank; i++) {
- auto dstDim = dstTp.getShape()[i];
+ const auto dstDim = dstTp.getShape()[i];
if (i == concatDim) {
if (!ShapedType::isDynamic(dstDim)) {
+ // If we reach here, all inputs should have static shapes.
unsigned sumDim = 0;
- for (auto src : getInputs()) {
- // If we reach here, all inputs should have static shapes.
- auto d = src.getType().cast<RankedTensorType>().getShape()[i];
- sumDim += d;
- }
+ for (auto src : getInputs())
+ sumDim += getRankedTensorType(src).getShape()[i];
// If all dimension are statically known, the sum of all the input
// dimensions should be equal to the output dimension.
if (sumDim != dstDim)
} else {
int64_t prev = dstDim;
for (auto src : getInputs()) {
- auto d = src.getType().cast<RankedTensorType>().getShape()[i];
+ const auto d = getRankedTensorType(src).getShape()[i];
if (!ShapedType::isDynamic(prev) && d != prev)
return emitError("All dimensions (expect for the concatenating one) "
"should be equal.");
}
LogicalResult InsertOp::verify() {
- RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
- if (ttp.getRank() != static_cast<int64_t>(getIndices().size()))
+ if (getTypeRank(getTensor()) != static_cast<int64_t>(getIndices().size()))
return emitOpError("incorrect number of indices");
return success();
}
}
LogicalResult CompressOp::verify() {
- RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
- if (ttp.getRank() != 1 + static_cast<int64_t>(getIndices().size()))
+ if (getTypeRank(getTensor()) != 1 + static_cast<int64_t>(getIndices().size()))
return emitOpError("incorrect number of indices");
return success();
}
// Builds foreach body.
if (!bodyBuilder)
return;
- auto rtp = tensor.getType().cast<RankedTensorType>();
+ auto rtp = getRankedTensorType(tensor);
int64_t rank = rtp.getRank();
SmallVector<Type> blockArgTypes;
}
LogicalResult ForeachOp::verify() {
- auto t = getTensor().getType().cast<RankedTensorType>();
+ auto t = getRankedTensorType(getTensor());
auto args = getBody()->getArguments();
if (static_cast<size_t>(t.getRank()) + 1 + getInitArgs().size() !=
auto n = getN().getDefiningOp<arith::ConstantIndexOp>();
- Type xtp = getXs().front().getType().cast<MemRefType>().getElementType();
+ Type xtp = getMemRefType(getXs().front()).getElementType();
auto checkTypes = [&](ValueRange operands,
bool checkEleType = true) -> LogicalResult {
for (Value opnd : operands) {
- MemRefType mtp = opnd.getType().cast<MemRefType>();
+ auto mtp = getMemRefType(opnd);
int64_t dim = mtp.getShape()[0];
// We can't check the size of dynamic dimension at compile-time, but all
// xs and ys should have a dimension not less than n at runtime.
}
auto checkDim = [&](Value v, uint64_t min, const char *message) {
- MemRefType tp = v.getType().cast<MemRefType>();
+ auto tp = getMemRefType(v);
int64_t dim = tp.getShape()[0];
if (!ShapedType::isDynamic(dim) && dim < (int64_t)min) {
emitError(llvm::formatv("{0} got {1} < {2}", message, dim, min));
idxBuffer = builder.create<memref::CastOp>(
loc, MemRefType::get({rank}, builder.getIndexType()), idxBuffer);
SmallVector<int64_t> shape(rank, ShapedType::kDynamic);
- Type elemTp = valuesBuffer.getType().cast<MemRefType>().getElementType();
+ Type elemTp = getMemRefType(valuesBuffer).getElementType();
return builder.create<memref::ReshapeOp>(loc, MemRefType::get(shape, elemTp),
valuesBuffer, idxBuffer);
}
// Misc code generators and utilities.
//===----------------------------------------------------------------------===//
-template <typename T>
-inline RankedTensorType getRankedTensorType(T t) {
- return t.getType().template cast<RankedTensorType>();
-}
-
/// Generates a 1-valued attribute of the given type. This supports
/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
/// for unsupported types we raise `llvm_unreachable` rather than
StringRef namePrefix, uint64_t nx,
uint64_t ny, bool isCoo,
ValueRange operands) {
- nameOstream
- << namePrefix << nx << "_"
- << operands[xStartIdx].getType().cast<MemRefType>().getElementType();
+ nameOstream << namePrefix << nx << "_"
+ << getMemRefType(operands[xStartIdx]).getElementType();
if (isCoo)
nameOstream << "_coo_" << ny;
uint64_t yBufferOffset = isCoo ? 1 : nx;
for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
- nameOstream << "_" << v.getType().cast<MemRefType>().getElementType();
+ nameOstream << "_" << getMemRefType(v).getElementType();
}
/// Looks up a function that is appropriate for the given operands being
// Convert `values` to have dynamic shape and append them to `operands`.
for (Value v : xys) {
- auto mtp = v.getType().cast<MemRefType>();
+ auto mtp = getMemRefType(v);
if (!mtp.isDynamicDim(0)) {
auto newMtp =
MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) {
Location loc = op.getLoc();
- auto srcTp = op.getSrc().getType().template cast<RankedTensorType>();
- auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+ auto srcTp = getRankedTensorType(op.getSrc());
+ auto dstTp = getRankedTensorType(op.getResult());
auto encSrc = getSparseTensorEncoding(srcTp);
auto encDst = getSparseTensorEncoding(dstTp);
if (!encDst || !encSrc)
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
- Type resType = op.getType();
- Type srcType = op.getSource().getType();
+ auto resType = getRankedTensorType(op);
+ auto srcType = getRankedTensorType(op.getSource());
auto encDst = getSparseTensorEncoding(resType);
auto encSrc = getSparseTensorEncoding(srcType);
Value src = adaptor.getOperands()[0];
// dst[elem.indices] = elem.value;
// }
// delete iter;
- RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
- RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
- unsigned rank = dstTensorTp.getRank();
- Type elemTp = dstTensorTp.getElementType();
+ const unsigned rank = resType.getRank();
+ const Type elemTp = resType.getElementType();
// Fabricate a no-permutation encoding for NewCallParams
// The pointer/index types must be those of `src`.
// The dimLevelTypes aren't actually used by Action::kToIterator.
SmallVector<DimLevelType>(rank, DimLevelType::Dense), AffineMap(),
AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
SmallVector<Value> dimSizes =
- getDimSizes(rewriter, loc, encSrc, srcTensorTp, src);
+ getDimSizes(rewriter, loc, encSrc, srcType, src);
Value iter = NewCallParams(rewriter, loc)
- .genBuffers(encDst, dimSizes, dstTensorTp)
+ .genBuffers(encDst, dimSizes, resType)
.genNewCall(Action::kToIterator, src);
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
Block *insertionBlock = rewriter.getInsertionBlock();
// TODO: Dense buffers should be allocated/deallocated via the callback
// in BufferizationOptions.
- Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, dimSizes);
+ Value dst = allocDenseTensor(rewriter, loc, resType, dimSizes);
SmallVector<Value> noArgs;
SmallVector<Type> noTypes;
auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
// index order. All values are passed by reference through stack
// allocated memrefs.
Location loc = op->getLoc();
- auto tp = op.getTensor().getType().cast<RankedTensorType>();
+ auto tp = getRankedTensorType(op.getTensor());
auto elemTp = tp.getElementType();
unsigned rank = tp.getRank();
auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
- RankedTensorType srcType =
- op.getTensor().getType().cast<RankedTensorType>();
+ auto srcType = getRankedTensorType(op.getTensor());
Type eltType = srcType.getElementType();
Type boolType = rewriter.getIntegerType(1);
Type idxType = rewriter.getIndexType();
Value added = adaptor.getAdded();
Value count = adaptor.getCount();
Value tensor = adaptor.getTensor();
- auto tp = op.getTensor().getType().cast<RankedTensorType>();
+ auto tp = getRankedTensorType(op.getTensor());
Type elemTp = tp.getElementType();
unsigned rank = tp.getRank();
auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
// a[ adjustForOffset(elem.indices) ] = elem.value
// return a
Location loc = op.getLoc();
- auto dstTp = op.getType().cast<RankedTensorType>();
+ auto dstTp = getRankedTensorType(op);
auto encDst = getSparseTensorEncoding(dstTp);
Type elemTp = dstTp.getElementType();
uint64_t concatDim = op.getDimension().getZExtValue();
for (auto it : llvm::zip(op.getInputs(), adaptor.getInputs())) {
Value orignalOp = std::get<0>(it); // Input (with encoding) from Op
Value adaptedOp = std::get<1>(it); // Input (type converted) from adaptor
- RankedTensorType srcTp = orignalOp.getType().cast<RankedTensorType>();
+ auto srcTp = getRankedTensorType(orignalOp);
auto encSrc = getSparseTensorEncoding(srcTp);
if (encSrc) {
genSparseCOOIterationLoop(
/// Constructs vector type from pointer.
static VectorType vectorType(VL vl, Value ptr) {
- return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
+ return vectorType(vl, getMemRefType(ptr).getElementType());
}
/// Constructs vector iteration mask.