[mlir][sparse] support dynamic sparse tensor slices.
authorPeiming Liu <peiming@google.com>
Tue, 10 Jan 2023 22:35:49 +0000 (22:35 +0000)
committerPeiming Liu <peiming@google.com>
Fri, 10 Mar 2023 23:12:41 +0000 (23:12 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D141532

14 files changed:
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir [new file with mode: 0644]
mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 6411222..3bf1118 100644 (file)
@@ -570,7 +570,7 @@ Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
 /// We normalized sparse tensor encoding attribute by always using
 /// ordered/unique DLT such that "compressed-nu-no" and "compressed-nu" (as well
 /// as other variants) lead to the same storage specifier type, and stripping
-/// irrelevant fields that does not alter the sparse tensor memory layout.
+/// irrelevant fields that do not alter the sparse tensor memory layout.
 static SparseTensorEncodingAttr
 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
   SmallVector<DimLevelType> dlts;
@@ -582,13 +582,10 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
       AffineMap(), // dimOrdering (irrelavant to storage speicifer)
       AffineMap(), // highLvlOrdering (irrelavant to storage specifer)
       // Always use `index` for memSize and lvlSize instead of reusing
-      // `getPosWidth`/`getCrdWidth`.
-      // It allows us to reuse the same SSA value for different bitwidth,
-      // It also avoids casting between index/integer (returned by DimOp)
-      0, 0,
-      // FIXME: we should keep the slice information, for now it is okay as only
-      // constant can be used for slice
-      ArrayRef<SparseTensorDimSliceAttr>{} /*enc.getDimSlices()*/);
+      // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
+      // value for different bitwidth, it also avoids casting between index and
+      // integer (returned by DimOp)
+      0, 0, enc.getDimSlices());
 }
 
 StorageSpecifierType
@@ -620,11 +617,10 @@ static LogicalResult verifySparsifierGetterSetter(
   const auto enc = md.getType().getEncoding();
   const Level lvlRank = enc.getLvlRank();
 
-  // TODO:
-  //  if (mdKind == StorageSpecifierKind::DimOffset ||
-  //      mdKind == StorageSpecifierKind::DimStride)
-  //    if (!enc.isSlice())
-  //      return op->emitError("requested slice data on non-slice tensor");
+  if (mdKind == StorageSpecifierKind::DimOffset ||
+      mdKind == StorageSpecifierKind::DimStride)
+    if (!enc.isSlice())
+      return op->emitError("requested slice data on non-slice tensor");
 
   if (mdKind != StorageSpecifierKind::ValMemSize) {
     if (!lvl)
index 40836f4..bdd6020 100644 (file)
@@ -694,3 +694,23 @@ Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
                                    Value tensor) {
   return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
 }
+
+Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
+                                               Value tensor, Dimension dim) {
+  auto enc = getSparseTensorEncoding(tensor.getType());
+  assert(enc && enc.isSlice());
+  std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
+  if (offset.has_value())
+    return constantIndex(builder, loc, *offset);
+  return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
+}
+
+Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
+                                               Value tensor, Dimension dim) {
+  auto enc = getSparseTensorEncoding(tensor.getType());
+  assert(enc && enc.isSlice());
+  std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
+  if (stride.has_value())
+    return constantIndex(builder, loc, *stride);
+  return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
+}
index 6fa3093..6d6351c 100644 (file)
@@ -364,6 +364,15 @@ Value genToValues(OpBuilder &builder, Location loc, Value tensor);
 /// Generates code to retrieve the values size for the sparse tensor.
 Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
 
+/// Generates code to retrieve the slice offset for the sparse tensor slice,
+/// return a constant if the offset is statically known.
+Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
+                                Dimension dim);
+
+/// Generates code to retrieve the slice slice for the sparse tensor slice,
+/// return a constant if the offset is statically known.
+Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor,
+                                Dimension dim);
 } // namespace sparse_tensor
 } // namespace mlir
 
index a8474a1..f48520b 100644 (file)
@@ -43,29 +43,25 @@ static Value genIndexLoad(OpBuilder &builder, Location loc, Value mem,
   return load;
 }
 
-// TODO: Support dynamic sized slice.
-static Value getSliceOffset(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr enc, unsigned lvl) {
-  return constantIndex(builder, loc, *enc.getStaticLvlSliceOffset(lvl));
+static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
+                            unsigned lvl) {
+  auto enc = getSparseTensorEncoding(tensor.getType());
+  // FIXME: `toOrigDim` is deprecated
+  return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
 }
 
-static Value getSliceSize(OpBuilder &builder, Location loc,
-                          SparseTensorEncodingAttr enc, unsigned lvl) {
-  return constantIndex(builder, loc, *enc.getStaticLvlSliceSize(lvl));
-}
-
-static Value getSliceStride(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr enc, unsigned lvl) {
-  return constantIndex(builder, loc, *enc.getStaticLvlSliceStride(lvl));
+static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
+                            unsigned lvl) {
+  auto enc = getSparseTensorEncoding(tensor.getType());
+  // FIXME: `toOrigDim` is deprecated
+  return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
 }
 
 // Converts a coordinate relative to the slice to the coordinate relative
 // to the underlying tensor.
 static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
-                          SparseTensorEncodingAttr enc, unsigned lvl) {
-
-  Value stride = getSliceStride(builder, loc, enc, lvl);
-  Value offset = getSliceOffset(builder, loc, enc, lvl);
+                          Value offset, Value stride, Value tensor,
+                          unsigned lvl) {
   // iv = iv * stride + offset
   v = builder.create<arith::MulIOp>(loc, v, stride);
   v = builder.create<arith::AddIOp>(loc, v, offset);
@@ -75,40 +71,58 @@ static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
 // Converts a coordinate relative to the underlying tensor to the coordinate
 // relative to the slice, returns a extra reminder value
 static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
-                                            Value v,
-                                            SparseTensorEncodingAttr enc,
+                                            Value iv, Value offset,
+                                            Value stride, Value tensor,
                                             unsigned lvl) {
-  Value stride = getSliceStride(builder, loc, enc, lvl);
-  Value offset = getSliceOffset(builder, loc, enc, lvl);
   // iv = (iv - offset) / stride
-  v = builder.create<arith::SubIOp>(loc, v, offset);
-  Value rem = builder.create<arith::RemUIOp>(loc, v, stride);
-  v = builder.create<arith::DivUIOp>(loc, v, stride);
-  return std::make_pair(v, rem);
+  iv = builder.create<arith::SubIOp>(loc, iv, offset);
+  Value rem = builder.create<arith::RemUIOp>(loc, iv, stride);
+  iv = builder.create<arith::DivUIOp>(loc, iv, stride);
+  return std::make_pair(iv, rem);
 }
 
-static std::pair<Value, Value>
-genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
-                       SparseTensorEncodingAttr enc, unsigned lvl) {
-  std::pair<Value, Value> trans = fromSliceCrd(builder, loc, crd, enc, lvl);
-  // First, crd >= offset (TODO: seems unsigned >= 0 won't be folded, skip
-  // the check if the offset is zero).
-  auto geOffset =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, crd,
-                                    getSliceOffset(builder, loc, enc, lvl));
+std::pair<Value, Value>
+LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
+                                    unsigned tid, unsigned lvl) {
+  assert(isSparseSlices[tid]);
+  Value slice = tensors[tid];
+  Value offset = sliceOffsets[tid][lvl];
+  Value stride = sliceStrides[tid][lvl];
+  auto enc = getSparseTensorEncoding(slice.getType());
+
+  std::pair<Value, Value> transformedCrd =
+      fromSliceCrd(builder, loc, crd, offset, stride, slice, lvl);
+
+  SmallVector<Value, 3> conds; // at most 3 conditions
+
+  // First, coord >= offset (skip the check if offset is known to be 0).
+  if (auto staticOffset = enc.getStaticLvlSliceOffset(lvl);
+      !(staticOffset.has_value() && *staticOffset == 0)) {
+    auto geOffset = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::uge, crd, offset);
+    conds.push_back(geOffset);
+  }
+
   // Second, coord_in_slice < length
-  auto ltLength =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, trans.first,
-                                    getSliceSize(builder, loc, enc, lvl));
-
-  // Third, rem == 0; confirmed that (a % 1) will be folded to 0
-  auto fitStride =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, trans.second,
-                                    constantIndex(builder, loc, 0));
-
-  auto pred = builder.create<arith::AndIOp>(loc, geOffset, ltLength);
-  pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
-  return {trans.first, pred};
+  auto ltLength = builder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::ult, transformedCrd.first, lvlSizes[tid][lvl]);
+  conds.push_back(ltLength);
+
+  // Third, rem == 0 (skip the check if stride is known to be 1).
+  if (auto staticStride = enc.getStaticLvlSliceStride(lvl);
+      !(staticStride.has_value() && *staticStride == 1)) {
+    auto fitStride = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, transformedCrd.second,
+        constantIndex(builder, loc, 0));
+    conds.push_back(fitStride);
+  }
+
+  // Must meet all condition to be a valid coordinate in slice.
+  auto pred = conds.front();
+  for (auto cond : ValueRange(conds).drop_front())
+    pred = builder.create<arith::AndIOp>(loc, pred, cond);
+
+  return {transformedCrd.first, pred};
 }
 
 //===----------------------------------------------------------------------===//
@@ -119,10 +133,9 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
                               size_t dim, Value iv) {
   Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
   Value mul = builder.create<arith::MulIOp>(loc, highs[tid][dim], p);
-  if (isSparseSlices[tid]) {
-    auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    iv = toSliceCoord(builder, loc, iv, enc, dim);
-  }
+  if (isSparseSlices[tid])
+    iv = toSliceCoord(builder, loc, iv, sliceOffsets[tid][dim],
+                      sliceStrides[tid][dim], tensors[tid], dim);
   Value add = builder.create<arith::AddIOp>(loc, mul, iv);
   return add;
 }
@@ -204,6 +217,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->isSparseOut = isSparseOut;
   this->tensors.assign(ts.begin(), ts.end());
   this->isSparseSlices.assign(tensors.size(), false);
+  this->sliceOffsets.assign(tensors.size(), std::vector<Value>());
+  this->sliceStrides.assign(tensors.size(), std::vector<Value>());
   this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
   this->pidxs.assign(tensors.size(), std::vector<Value>());
   this->segHi.assign(tensors.size(), std::vector<Value>());
@@ -246,6 +261,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
       dimTypes[tid].assign(rank, DimLevelType::Dense);
 
     // Initialize using empty value.
+    sliceOffsets[tid].assign(rank, Value());
+    sliceStrides[tid].assign(rank, Value());
     pidxs[tid].assign(rank, Value());
     segHi[tid].assign(rank, Value());
     coord[tid].assign(rank, Value());
@@ -300,11 +317,17 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
         assert(isDenseDLT(dlt));
       }
 
-      // Find upper bound in current dimension.
       // FIXME: `toOrigDim` is deprecated
-      const Dimension d = toOrigDim(enc, l);
-      lvlSizes[t][l] = highs[t][l] =
-          mlir::linalg::createOrFoldDimOp(builder, loc, tensor, d);
+      // Since we do not have HigherOrdering now, we can always rely on the 1:1
+      // mapping from level to dimension to retrieve the level size.
+      Value lvlSz = mlir::linalg::createOrFoldDimOp(builder, loc, tensor,
+                                                    toOrigDim(enc, l));
+      // Find upper bound in current dimension.
+      highs[t][l] = lvlSizes[t][l] = lvlSz;
+      if (isSparseSlices[t]) {
+        sliceOffsets[t][l] = genSliceOffset(builder, loc, tensors[t], l);
+        sliceStrides[t][l] = genSliceStride(builder, loc, tensors[t], l);
+      }
     }
 
     // Perform the required bufferization. Dense inputs materialize
@@ -405,7 +428,6 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
     isSparseInput = isSparseInput || isSparse;
   }
 
-  auto enc = getSparseTensorEncoding(tensors[tid].getType());
   const auto reassoc = getCollapseReassociation(tid, dim);
   // TODO: support dynamic slices.
   // Uses the first dimension here to build the loop bound (which is also the
@@ -468,7 +490,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
     for (Value red : reduc)
       types.push_back(red.getType());
 
-    auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, enc, dim);
+    auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, tid, dim);
     bool hasReduc = !types.empty();
     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
                                                /*else*/ hasReduc);
@@ -660,11 +682,8 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
         isSingletonDLT(dimTypes[tid][dim])) {
       coord[tid][dim] = genSparseCrd(builder, loc, tid, dim);
       if (isSparseSlices[tid]) {
-        Value load =
-            genIndexLoad(builder, loc, crdBuffer[tid][dim], pidxs[tid][dim]);
-        auto enc = getSparseTensorEncoding(tensors[tid].getType());
         auto [trans, pred] =
-            genSliceLegitPredicate(builder, loc, load, enc, dim);
+            genSliceLegitPredicate(builder, loc, coord[tid][dim], tid, dim);
         slicesPreds.emplace_back(pred, i);
         // Updates to the relative coordinate to the slice.
         coord[tid][dim] = trans;
@@ -679,7 +698,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
     // Generates a list of if statments
     //  pidx = in_slice ? pidx : pidx + 1
     // TODO: instead of always picking pidx + 1, we should set pidx = high to
-    // break to loop the coordinates is larger than the slice size.
+    // break to loop if the coordinates is larger than the slice size.
     for (auto [pred, idx] : slicesPreds) {
       Value nextPidx = builder.create<arith::AddIOp>(
           loc, yields[idx], constantIndex(builder, loc, 1));
index 1f6ee15..8bc5da0 100644 (file)
@@ -202,6 +202,13 @@ private:
   Value genSparseCrd(OpBuilder &builder, Location loc, size_t tid,
                      size_t dstLvl);
 
+  /// Generates a predicate to determine whether the tranformed coordinates are
+  /// in the given slice.
+  /// Returns std::pair<Transformed coordinates, Predicate>
+  std::pair<Value, Value> genSliceLegitPredicate(OpBuilder &builder,
+                                                 Location loc, Value crd,
+                                                 unsigned tid, unsigned lvl);
+
   bool isOutputTensor(size_t tid) {
     return hasOutput && tid == tensors.size() - 1;
   }
@@ -278,6 +285,9 @@ private:
 
   /// Whether the sparse input is a slice.
   std::vector<bool> isSparseSlices;
+  /// Values related to slices.
+  std::vector<std::vector<Value>> sliceOffsets;
+  std::vector<std::vector<Value>> sliceStrides;
 
   /// Loop Stack, stores the information of all the nested loops that are
   /// alive.
index f3a6adb..0c68c4d 100644 (file)
@@ -130,17 +130,18 @@ Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
 /// Builds IR extracting the pos-th offset from the descriptor.
 Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
                                         Dimension dim) const {
-  return builder.create<LLVM::ExtractValueOp>(
-      loc, value,
-      ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
+  return extractField(
+      builder, loc,
+      ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)});
 }
 
 /// Builds IR inserting the pos-th offset into the descriptor.
 void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
                                           Dimension dim, Value size) {
-  value = builder.create<LLVM::InsertValueOp>(
-      loc, value, size,
-      ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
+  insertField(
+      builder, loc,
+      ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)},
+      size);
 }
 
 /// Builds IR extracting the `lvl`-th level-size from the descriptor.
index 80f2996..71c78d9 100644 (file)
@@ -18,6 +18,9 @@
 #include "CodegenUtils.h"
 #include "SparseTensorStorageLayout.h"
 
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -28,7 +31,6 @@
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/FormatVariadic.h"
 
 #include <optional>
 
@@ -697,6 +699,23 @@ public:
   }
 };
 
+template <typename Op, StorageSpecifierKind kind>
+class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
+public:
+  using OpConversionPattern<Op>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Simply lowers to specifer.get <field> operation.
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
+    auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
+                                    op.getDim().getZExtValue());
+
+    rewriter.replaceOp(op, v);
+    return success();
+  }
+};
+
 /// Sparse codegen rule for trivial tensor casts.
 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
 public:
@@ -1099,13 +1118,15 @@ public:
   }
 };
 
-class SparseExtractSliceCoverter
+class SparseExtractSliceConverter
     : public OpConversionPattern<tensor::ExtractSliceOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
     auto srcEnc = getSparseTensorEncoding(op.getSourceType());
     auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
     if (!srcEnc && !dstEnc)
@@ -1119,16 +1140,43 @@ public:
     assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
     assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth());
 
-    // TODO: support dynamic slices.
-    for (int i = 0, e = op.getSourceType().getRank(); i < e; i++) {
-      assert(op.getStaticStrides()[i] == dstEnc.getStaticDimSliceStride(i));
-      assert(op.getStaticOffsets()[i] == dstEnc.getStaticDimSliceOffset(i));
-      assert(op.getStaticSizes()[i] == dstEnc.getStaticDimSliceSize(i));
+    SmallVector<Value> fields;
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
+
+    auto newSpec = rewriter.create<StorageSpecifierInitOp>(
+        loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
+    desc.setSpecifier(newSpec);
+
+    // Fills in slice information.
+    for (const auto &it : llvm::enumerate(llvm::zip(
+             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()))) {
+      Dimension dim = it.index();
+      auto [offset, size, stride] = it.value();
+
+      Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
+      Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
+      Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
+      // TODO: We could probably only set dynamic value here. But it would
+      // requires us to fill the hole when casting a static slice to dynamic
+      // slice.
+      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
+                             dim, offsetV);
+
+      // FIXME: we need to distinguish level sizes and dimension size for slices
+      // here. Maybe we should store slice level sizes in a different array
+      // instead of reusing it.
+      assert(srcEnc.hasIdDimOrdering());
+      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
+                             sizeV);
+      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
+                             dim, strideV);
     }
 
-    // TODO: create a new specifer for slices (need to encode slice metadata).
-    // It does not matter now because only constant offset/stride are allowed.
-    rewriter.replaceOp(op, adaptor.getSource());
+    // NOTE: we can not generate tuples directly from descriptor here, as the
+    // descriptor is holding the original type, yet we want the slice type
+    // here (they shared every memref but with an updated specifier).
+    rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
+                                    desc.getFields()));
     return success();
   }
 };
@@ -1449,13 +1497,18 @@ void mlir::populateSparseTensorCodegenPatterns(
   patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
                SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
                SparseCastConverter, SparseTensorDeallocConverter,
-               SparseExtractSliceCoverter, SparseTensorLoadConverter,
+               SparseExtractSliceConverter, SparseTensorLoadConverter,
                SparseExpandConverter, SparseCompressConverter,
-               SparseInsertConverter, SparseToPositionsConverter,
-               SparseToCoordinatesConverter, SparseToCoordinatesBufferConverter,
-               SparseToValuesConverter, SparseConvertConverter,
-               SparseNewOpConverter, SparseNumberOfEntriesConverter>(
-      typeConverter, patterns.getContext());
+               SparseInsertConverter,
+               SparseSliceGetterOpConverter<ToSliceOffsetOp,
+                                            StorageSpecifierKind::DimOffset>,
+               SparseSliceGetterOpConverter<ToSliceStrideOp,
+                                            StorageSpecifierKind::DimStride>,
+               SparseToPositionsConverter, SparseToCoordinatesConverter,
+               SparseToCoordinatesBufferConverter, SparseToValuesConverter,
+               SparseConvertConverter, SparseNewOpConverter,
+               SparseNumberOfEntriesConverter>(typeConverter,
+                                               patterns.getContext());
   patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
                                            enableBufferInitialization);
 }
index 69cc3af..788ad28 100644 (file)
@@ -403,6 +403,8 @@ public:
     fields[fidx] = v;
   }
 
+  void setSpecifier(Value newSpec) { fields.back() = newSpec; }
+
   void setSpecifierField(OpBuilder &builder, Location loc,
                          StorageSpecifierKind kind, std::optional<Level> lvl,
                          Value v) {
index caf994c..6c0f13a 100644 (file)
@@ -259,16 +259,16 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>)
   return %0 : index
 }
 
-//// -----
-//
-//#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
-//
-//func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
-//  // _e_xpected-error@+1 {{requested slice data on non-slice tensor}}
-//  %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0
-//       : !sparse_tensor.storage_specifier<#SparseVector> to i64
-//  return %0 : i64
-//}
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
+  // expected-error@+1 {{requested slice data on non-slice tensor}}
+  %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0
+       : !sparse_tensor.storage_specifier<#SparseVector>
+  return %0 : index
+}
 
 // -----
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir b/mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir
new file mode 100644 (file)
index 0000000..745b0a8
--- /dev/null
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --sparse-tensor-codegen --cse |  FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ]
+}>
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (0, 4, 1), (0, 8, 1) ]
+}>
+
+// CHECK-LABEL:   func.func @sparse_slice(
+// CHECK-SAME:                            %[[VAL_0:.*0]]: memref<?xindex>,
+// CHECK-SAME:                            %[[VAL_1:.*1]]: memref<?xindex>,
+// CHECK-SAME:                            %[[VAL_2:.*2]]: memref<?xf64>,
+// CHECK-SAME:                            %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>)
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.storage_specifier.init with %[[VAL_3]]
+// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.storage_specifier.set %[[VAL_4]]  dim_offset at 0 with %[[VAL_5]]
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.storage_specifier.set %[[VAL_8]]  lvl_sz at 0 with %[[VAL_6]]
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.storage_specifier.set %[[VAL_9]]  dim_stride at 0 with %[[VAL_7]]
+// CHECK:           %[[VAL_11:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]]  dim_offset at 1 with %[[VAL_5]]
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]]  lvl_sz at 1 with %[[VAL_11]]
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]]  dim_stride at 1 with %[[VAL_7]]
+// CHECK:           return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_14]]
+func.func @sparse_slice(%t1 : tensor<8x8xf64, #CSR>) -> tensor<4x8xf64, #CSR_SLICE> {
+  %a1 = tensor.extract_slice %t1[0, 0][4, 8][1, 1] : tensor<8x8xf64, #CSR> to
+                                                     tensor<4x8xf64, #CSR_SLICE>
+  return %a1 : tensor<4x8xf64, #CSR_SLICE>
+}
index ad203ab..8d72e94 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=true" | FileCheck %s
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=true" --canonicalize | FileCheck %s
 
 // CHECK-LABEL: func.func @sparse_foreach_constant
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
@@ -27,3 +27,115 @@ func.func @sparse_foreach_constant() -> () {
   }
   return
 }
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  slice = [ (0, 4, 1), (2, 4, 1) ]
+}>
+
+#CSR_SLICE_DYN = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  slice = [ (?, ?, ?), (?, ?, ?) ]
+}>
+
+
+// CHECK-LABEL:   func.func @foreach_print_slice_dyn(
+// CHECK-SAME:                                       %[[VAL_0:.*]]: tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.slice.offset %[[VAL_0]] at 0 : tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.slice.stride %[[VAL_0]] at 0 : tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_10:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.slice.offset %[[VAL_0]] at 1 : tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.slice.stride %[[VAL_0]] at 1 : tensor<?x?xf64,
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64,
+// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_2]] {
+// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK:             %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_19:.*]] = arith.remui %[[VAL_18]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_20:.*]] = arith.divui %[[VAL_18]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_17]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_23:.*]] = arith.cmpi eq, %[[VAL_19]], %[[VAL_1]] : index
+// CHECK:             %[[VAL_24:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1
+// CHECK:             %[[VAL_25:.*]] = arith.andi %[[VAL_24]], %[[VAL_23]] : i1
+// CHECK:             scf.if %[[VAL_25]] {
+// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK:               %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_2]] : index
+// CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref<?xindex>
+// CHECK:               scf.for %[[VAL_29:.*]] = %[[VAL_26]] to %[[VAL_28]] step %[[VAL_2]] {
+// CHECK:                 %[[VAL_30:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_29]]] : memref<?xindex>
+// CHECK:                 %[[VAL_31:.*]] = arith.subi %[[VAL_30]], %[[VAL_11]] : index
+// CHECK:                 %[[VAL_32:.*]] = arith.remui %[[VAL_31]], %[[VAL_12]] : index
+// CHECK:                 %[[VAL_33:.*]] = arith.divui %[[VAL_31]], %[[VAL_12]] : index
+// CHECK:                 %[[VAL_34:.*]] = arith.cmpi uge, %[[VAL_30]], %[[VAL_11]] : index
+// CHECK:                 %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_33]], %[[VAL_10]] : index
+// CHECK:                 %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_1]] : index
+// CHECK:                 %[[VAL_37:.*]] = arith.andi %[[VAL_34]], %[[VAL_35]] : i1
+// CHECK:                 %[[VAL_38:.*]] = arith.andi %[[VAL_37]], %[[VAL_36]] : i1
+// CHECK:                 scf.if %[[VAL_38]] {
+// CHECK:                   %[[VAL_39:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_29]]] : memref<?xf64>
+// CHECK:                   "test.use"(%[[VAL_39]]) : (f64) -> ()
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+//
+func.func @foreach_print_slice_dyn(%A: tensor<?x?xf64, #CSR_SLICE_DYN>) {
+  sparse_tensor.foreach in %A : tensor<?x?xf64, #CSR_SLICE_DYN> do {
+  ^bb0(%1: index, %2: index, %v: f64) :
+    "test.use" (%v) : (f64) -> ()
+  }
+  return
+}
+
+// CHECK-LABEL:   func.func @foreach_print_slice(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<4x4xf64,
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64,
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64,
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64,
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64,
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf64,
+// CHECK-DAG:       %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK:             %[[VAL_14:.*]] = arith.cmpi ult, %[[VAL_13]], %[[VAL_1]] : index
+// CHECK:             scf.if %[[VAL_14]] {
+// CHECK:               %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK:               %[[VAL_16:.*]] = arith.addi %[[VAL_12]], %[[VAL_4]] : index
+// CHECK:               %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK:               scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_4]] {
+// CHECK:                 %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:                 %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_2]] : index
+// CHECK:                 %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_19]], %[[VAL_2]] : index
+// CHECK:                 %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_1]] : index
+// CHECK:                 %[[VAL_23:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1
+// CHECK:                 scf.if %[[VAL_23]] {
+// CHECK:                   %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK:                   "test.use"(%[[VAL_24]]) : (f64) -> ()
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+//
+func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) {
+  sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR_SLICE> do {
+  ^bb0(%1: index, %2: index, %v: f64) :
+    "test.use" (%v) : (f64) -> ()
+  }
+  return
+}
\ No newline at end of file
index 548818c..a3e9426 100644 (file)
@@ -2,7 +2,7 @@
 // DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
 // DEFINE: mlir-cpu-runner \
 // DEFINE:  -e entry -entry-point-result=void  \
-// DEFINE:  -shared-libs=%mlir_c_runner_utils | \
+// DEFINE:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
 // DEFINE: FileCheck %s
 //
 // RUN: %{command}
   slice = [ (1, 4, 1), (1, 4, 2) ]
 }>
 
+#CSR_SLICE_DYN = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (?, ?, ?), (?, ?, ?) ]
+}>
+
+
 module {
   func.func @foreach_print_non_slice(%A: tensor<4x4xf64, #CSR>) {
     sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR> do {
@@ -39,8 +45,22 @@ module {
     return
   }
 
+  func.func @foreach_print_slice_dyn(%A: tensor<?x?xf64, #CSR_SLICE_DYN>) {
+    sparse_tensor.foreach in %A : tensor<?x?xf64, #CSR_SLICE_DYN> do {
+    ^bb0(%1: index, %2: index, %v: f64) :
+      vector.print %1: index
+      vector.print %2: index
+      vector.print %v: f64
+    }
+    return
+  }
+
   func.func @entry() {
     %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+
     %sa = arith.constant dense<[
         [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
         [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
@@ -52,6 +72,7 @@ module {
         [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ]
     ]> : tensor<8x8xf64>
 
+
     %tmp = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
     %a = tensor.extract_slice %tmp[1, 1][4, 4][1, 2] : tensor<8x8xf64, #CSR> to
                                                        tensor<4x4xf64, #CSR_SLICE>
@@ -72,7 +93,7 @@ module {
     %dense = tensor.extract_slice %sa[1, 1][4, 4][1, 2] : tensor<8x8xf64> to
                                                           tensor<4x4xf64>
     %b = sparse_tensor.convert %dense : tensor<4x4xf64> to tensor<4x4xf64, #CSR>
-    // Foreach on sparse tensor instead of slice should yield the same result.
+    // Foreach on sparse tensor instead of slice they should yield the same result.
     //
     // CHECK-NEXT: 1
     // CHECK-NEXT: 0
@@ -86,8 +107,28 @@ module {
     //
     call @foreach_print_non_slice(%b) : (tensor<4x4xf64, #CSR>) -> ()
 
-    bufferization.dealloc_tensor %b : tensor<4x4xf64, #CSR>
+    // The same slice, but with dynamic encoding.
+    // TODO: Investigates why reusing the same %tmp above would cause bufferization
+    // errors.
+    %tmp1 = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
+    %a_dyn = tensor.extract_slice %tmp1[%c1, %c1][%c4, %c4][%c1, %c2] :
+          tensor<8x8xf64, #CSR> to tensor<?x?xf64, #CSR_SLICE_DYN>
+    //
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 2.3
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 3
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 3
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 2.1
+    //
+    call @foreach_print_slice_dyn(%a_dyn) : (tensor<?x?xf64, #CSR_SLICE_DYN>) -> ()
+
     bufferization.dealloc_tensor %tmp : tensor<8x8xf64, #CSR>
+    bufferization.dealloc_tensor %tmp1 : tensor<8x8xf64, #CSR>
+    bufferization.dealloc_tensor %b : tensor<4x4xf64, #CSR>
     return
   }
 }
index 8f77b3d..ffa6ad8 100644 (file)
   slice = [ (0, 4, 2), (1, 4, 1) ]
 }>
 
+#CSR_SLICE_dyn = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (?, 4, ?), (?, 4, ?) ]
+}>
+
+#DCSR_SLICE_dyn = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  slice = [ (?, 4, ?), (?, 4, ?) ]
+}>
+
+
 module {
   func.func private @printMemrefF64(%ptr : tensor<*xf64>)
   func.func private @printMemref1dF64(%ptr : memref<?xf64>) attributes { llvm.emit_c_interface }
 
+  //
+  // Computes C = A x B with all matrices dynamic sparse slice (SpMSpM) in CSR and DCSR
+  //
+  func.func @matmul_dyn(%A: tensor<4x4xf64, #CSR_SLICE_dyn>,
+                        %B: tensor<4x4xf64, #DCSR_SLICE_dyn>) -> tensor<4x4xf64, #CSR> {
+    %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
+    %D = linalg.matmul
+      ins(%A, %B: tensor<4x4xf64, #CSR_SLICE_dyn>, tensor<4x4xf64, #DCSR_SLICE_dyn>)
+         outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+    return %D: tensor<4x4xf64, #CSR>
+  }
 
   //
   // Computes C = A x B with one matrix CSR sparse slices and the other DSCR sparse slice.
@@ -83,7 +105,9 @@ module {
   // Main driver.
   //
   func.func @entry() {
-    %c0 = arith.constant 0 : index
+    %c_0 = arith.constant 0 : index
+    %c_1 = arith.constant 1 : index
+    %c_2 = arith.constant 2 : index
     %f0 = arith.constant 0.0 : f64
 
     %sa = arith.constant dense<[
@@ -158,11 +182,27 @@ module {
     %4 = call @matmul1(%s2, %s1)
        : (tensor<4x4xf64, #CSR_SLICE_1>,
           tensor<4x4xf64, #DCSR_SLICE_1>) -> tensor<4x4xf64, #CSR>
-
     %c4 = sparse_tensor.convert %4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
     %c4u = tensor.cast %c4 : tensor<4x4xf64> to tensor<*xf64>
     call @printMemrefF64(%c4u) : (tensor<*xf64>) -> ()
 
+    // slice x slice (same as above, but with dynamic stride information)
+    //
+    // CHECK:      [2.3,   0,   0,   0],
+    // CHECK-NEXT: [6.9,   0,   0,   0],
+    // CHECK-NEXT: [0,   0,   0,   0],
+    // CHECK-NEXT: [12.6,   0,   0,   0]]
+    //
+    %s1_dyn = tensor.extract_slice %tmp[%c_0, %c_1][4, 4][%c_2, %c_1] : tensor<8x8xf64, #DCSR> to tensor<4x4xf64, #DCSR_SLICE_dyn>
+    %s2_dyn = tensor.extract_slice %b1[%c_0, %c_0][4, 4][%c_2, %c_1] : tensor<8x4xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE_dyn>
+    %dyn_4 = call @matmul_dyn(%s2_dyn, %s1_dyn)
+       : (tensor<4x4xf64, #CSR_SLICE_dyn>,
+          tensor<4x4xf64, #DCSR_SLICE_dyn>) -> tensor<4x4xf64, #CSR>
+
+    %c4_dyn = sparse_tensor.convert %dyn_4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
+    %c4u_dyn = tensor.cast %c4_dyn : tensor<4x4xf64> to tensor<*xf64>
+    call @printMemrefF64(%c4u_dyn) : (tensor<*xf64>) -> ()
+
     // sparse slices should generate the same result as dense slices
     //
     // CHECK:      [2.3,   0,   0,   0],
@@ -179,7 +219,7 @@ module {
     %du = tensor.cast %r : tensor<4x4xf64> to tensor<*xf64>
     call @printMemrefF64(%du) : (tensor<*xf64>) -> ()
 
-    // Releases resources.
+    // Releases resources (we do not need to deallocate slices).
     bufferization.dealloc_tensor %b1 : tensor<8x4xf64, #CSR>
     bufferization.dealloc_tensor %t1 : tensor<8x8xf64, #CSR>
     bufferization.dealloc_tensor %b  : tensor<8x4xf64, #DCSR>
@@ -187,6 +227,7 @@ module {
     bufferization.dealloc_tensor %4  : tensor<4x4xf64, #CSR>
     bufferization.dealloc_tensor %3  : tensor<4x4xf64, #CSR>
     bufferization.dealloc_tensor %2  : tensor<4x4xf64, #DCSR>
+    bufferization.dealloc_tensor %dyn_4 : tensor<4x4xf64, #CSR>
 
     return
   }
index 3f98d82..a3f58b0 100644 (file)
@@ -2223,6 +2223,7 @@ cc_library(
     deps = [
         ":AffineDialect",
         ":ArithDialect",
+        ":ArithUtils",
         ":BufferizationDialect",
         ":BufferizationTransforms",
         ":ComplexDialect",