#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
  // TODO: Maybe pick the bit width based on the input/output tensors of the
  // original operation (probably the widest among them) instead of using the
  // default value.
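+  // Note: a bit width of 0 selects the dialect's default width; it is used
+  // here when the source tensor carries no sparse encoding.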
+ unsigned pointerBitWidth = encSrc ? encSrc.getPointerBitWidth() : 0;
+ unsigned indexBitWidth = encSrc ? encSrc.getIndexBitWidth() : 0;
auto enc = SparseTensorEncodingAttr::get(
ctx, dims, AffineMap::getMultiDimIdentityMap(rank, ctx), AffineMap(),
- encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
+ pointerBitWidth, indexBitWidth);
return RankedTensorType::get(src.getShape(), src.getElementType(), enc);
}
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto rtp = op.getType().cast<RankedTensorType>();
- // TODO: Build the output shape if needed.
- assert(rtp.hasStaticShape());
- auto rank = rtp.getRank();
size_t conDim = op.getDimension().getZExtValue();
+ SmallVector<Value> dynSizes;
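+    // For a dynamically shaped output, compute each dynamic dimension size
+    // by summing the operands' sizes along that dimension.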
+ if (!rtp.hasStaticShape()) {
+ ArrayRef<int64_t> rShape = rtp.getShape();
+ for (const auto &d : llvm::enumerate(rShape)) {
+ if (d.value() == ShapedType::kDynamicSize) {
+          Value v =
+              createOrFoldDimOp(rewriter, loc, op.getOperand(0), d.index());
+ for (const auto &opnd : op.getOperands().drop_front()) {
+ Value t = createOrFoldDimOp(rewriter, loc, opnd, d.index());
+ v = rewriter.create<arith::AddIOp>(loc, v, t);
+ }
+ dynSizes.push_back(v);
+ }
+ }
+ }
+
// %t = concatenate %s1, %s2, %s3 {dim = 1}
// ==>
// %tmp = bufferization.alloc_tensor : unordered COO
// %t = sparse_tensor.cast %tmp
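+    // The unordered COO buffer accepts insertions in arbitrary order; each
+    // input is traversed with a sparse_tensor.foreach loop that inserts its
+    // nonzero values at shifted indices.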
auto cooTp = getUnorderedCOOFromType(rtp);
auto cooBuffer =
- rewriter.create<AllocTensorOp>(loc, cooTp, ValueRange()).getResult();
-
+ rewriter.create<AllocTensorOp>(loc, cooTp, dynSizes).getResult();
+ auto rank = rtp.getRank();
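+    // Running offset along the concatenation dimension; it is advanced after
+    // each input has been processed.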
Value offset = constantIndex(rewriter, loc, 0);
ForeachOp foreachOp;
for (Value input : op.getInputs()) {
- // Builds the indexing map.
-
// Build a for op for each input tensor to append new values into the
// output tensor.
foreachOp = rewriter.create<ForeachOp>(
idx = builder.create<arith::AddIOp>(loc, idx, offset);
indices.push_back(idx);
}
- auto t = builder.create<InsertOp>(loc, v, reduc.front(), indices);
- builder.create<sparse_tensor::YieldOp>(loc, t);
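+          // Insert only nonzero values: the then-branch yields the updated
+          // tensor, the else-branch yields the tensor unchanged.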
+          Value cond = genIsNonzero(builder, loc, v);
+          scf::IfOp ifOp = builder.create<scf::IfOp>(
+              loc, TypeRange(reduc.front().getType()), cond, /*else*/ true);
+          builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+          Value t = builder.create<InsertOp>(loc, v, reduc.front(), indices);
+          builder.create<scf::YieldOp>(loc, t);
+          builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+          builder.create<scf::YieldOp>(loc, reduc.front());
+          builder.setInsertionPointAfter(ifOp);
+          builder.create<sparse_tensor::YieldOp>(loc, ifOp.getResult(0));
});
// Accumulates the offset. Note that only static-shaped inputs are allowed
// by concatenate op verifier, which saves us from computing the offset
for (uint64_t i = 0; i < rank; i++) {
uint64_t orgDim = toOrigDim(encSrc, i);
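+      // The indices buffer is queried at level i of the source storage, not
+      // at the original dimension.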
xs[toStoredDim(encDst, orgDim)] = rewriter.create<ToIndicesOp>(
- loc, indTp, src, rewriter.getIndexAttr(orgDim));
+ loc, indTp, src, rewriter.getIndexAttr(i));
}
// Retrieve NNZ.