[mlir][sparse] make foreach operation support sparse tensor slices.
authorPeiming Liu <peiming@google.com>
Wed, 28 Dec 2022 01:42:53 +0000 (01:42 +0000)
committerPeiming Liu <peiming@google.com>
Wed, 8 Feb 2023 18:58:35 +0000 (18:58 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140713

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir [new file with mode: 0644]

index f61e1f5..fda7a5c 100644 (file)
@@ -534,7 +534,10 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
       enc.getContext(), dlts,
       AffineMap(), // dimOrdering (irrelavant to storage speicifer)
       AffineMap(), // highLvlOrdering (irrelavant to storage specifer)
-      enc.getPointerBitWidth(), enc.getIndexBitWidth());
+      enc.getPointerBitWidth(), enc.getIndexBitWidth(),
+      // FIXME: we should keep the slice information, for now it is okay as only
+      // constant can be used for slice
+      ArrayRef<SparseTensorDimSliceAttr>{} /*enc.getDimSlices()*/);
 }
 
 StorageSpecifierType
index 88981fc..df19b61 100644 (file)
@@ -42,6 +42,50 @@ static Value genIndexLoad(OpBuilder &builder, Location loc, Value ptr,
   return load;
 }
 
+// TODO: Support dynamic sized slice.
+static Value getSliceOffset(OpBuilder &builder, Location loc,
+                            SparseTensorEncodingAttr enc, unsigned lvl) {
+  return constantIndex(builder, loc, *enc.getStaticLvlSliceOffset(lvl));
+}
+
+static Value getSliceSize(OpBuilder &builder, Location loc,
+                          SparseTensorEncodingAttr enc, unsigned lvl) {
+  return constantIndex(builder, loc, *enc.getStaticLvlSliceSize(lvl));
+}
+
+static Value getSliceStride(OpBuilder &builder, Location loc,
+                            SparseTensorEncodingAttr enc, unsigned lvl) {
+  return constantIndex(builder, loc, *enc.getStaticLvlSliceStride(lvl));
+}
+
+// Converts a coordinate relative to the slice to the coordinate relative
+// to the underlying tensor.
+static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
+                          SparseTensorEncodingAttr enc, unsigned lvl) {
+
+  Value stride = getSliceStride(builder, loc, enc, lvl);
+  Value offset = getSliceOffset(builder, loc, enc, lvl);
+  // iv = iv * stride + offset
+  v = builder.create<arith::MulIOp>(loc, v, stride);
+  v = builder.create<arith::AddIOp>(loc, v, offset);
+  return v;
+}
+
+// Converts a coordinate relative to the underlying tensor to the coordinate
+// relative to the slice, returns a extra reminder value
+static std::pair<Value, Value> fromSliceCoord(OpBuilder &builder, Location loc,
+                                              Value v,
+                                              SparseTensorEncodingAttr enc,
+                                              unsigned lvl) {
+  Value stride = getSliceStride(builder, loc, enc, lvl);
+  Value offset = getSliceOffset(builder, loc, enc, lvl);
+  // iv = (iv - offset) / stride
+  v = builder.create<arith::SubIOp>(loc, v, offset);
+  Value rem = builder.create<arith::RemUIOp>(loc, v, stride);
+  v = builder.create<arith::DivUIOp>(loc, v, stride);
+  return std::make_pair(v, rem);
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse tensor loop emitter class implementations
 //===----------------------------------------------------------------------===//
@@ -50,6 +94,10 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
                               size_t dim, Value iv) {
   Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
   Value mul = builder.create<arith::MulIOp>(loc, highs[tid][dim], p);
+  if (isSparseSlices[tid]) {
+    auto enc = getSparseTensorEncoding(tensors[tid].getType());
+    iv = toSliceCoord(builder, loc, iv, enc, dim);
+  }
   Value add = builder.create<arith::AddIOp>(loc, mul, iv);
   return add;
 }
@@ -67,6 +115,7 @@ void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
   this->hasOutput = hasOutput;
   this->isSparseOut = isSparseOut;
   this->tensors.assign(tensors.begin(), tensors.end());
+  this->isSparseSlices.assign(tensors.size(), false);
   this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
   this->pidxs.assign(tensors.size(), std::vector<Value>());
   this->coord.assign(tensors.size(), std::vector<Value>());
@@ -87,10 +136,11 @@ void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
     auto enc = getSparseTensorEncoding(rtp);
     // We always treat sparse output tensor as dense so that we always iterate
     // it based on dim size.
-    if (enc && !(isOutputTensor(tid) && isSparseOut))
+    if (enc && !(isOutputTensor(tid) && isSparseOut)) {
+      isSparseSlices[tid] = enc.isSlice();
       for (auto dimTp : enc.getDimLevelType())
         dimTypes[tid].push_back(dimTp);
-    else
+    else
       dimTypes[tid].assign(rank, DimLevelType::Dense);
 
     // Initialize using empty value.
@@ -218,7 +268,6 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
     ArrayRef<size_t> dims, MutableArrayRef<Value> reduc, bool isParallel) {
   // TODO: support multiple return on parallel for?
   assert(!isParallel || reduc.size() <= 1);
-
   bool isSparseInput = false;
   size_t tid = tids.front(), dim = dims.front();
   for (auto [t, d] : llvm::zip(tids, dims)) {
@@ -239,10 +288,13 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
     isSparseInput = isSparseInput || isSparse;
   }
 
+  auto enc = getSparseTensorEncoding(tensors[tid].getType());
+  // TODO: support dynamic slices.
   Value step = constantIndex(builder, loc, 1);
   Value lo = isSparseInput ? pidxs[tid][dim]      // current offset
-                           : loopSeqStack.back(); // univeral tid
+                           : loopSeqStack.back(); // universal index
   Value hi = highs[tid][dim];
+
   Operation *loop = nullptr;
   Value iv;
   if (isParallel) {
@@ -275,15 +327,64 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
   }
   assert(loop && iv);
 
+  Value c;
   if (isSparseInput) {
     pidxs[tid][dim] = iv;
     // Generating a load on the indices array yields the coordinate.
     Value ptr = idxBuffer[tid][dim];
-    coord[tid][dim] = genIndexLoad(builder, loc, ptr, iv);
+    c = genIndexLoad(builder, loc, ptr, iv);
   } else {
     // Dense tensor, the coordinates is the inducation variable.
-    coord[tid][dim] = iv;
+    c = iv;
   }
+
+  if (isSparseSlices[tid] && isSparseInput) {
+    // For sparse level slices, we need to filter out invalid coordinates that
+    // are not included in the slice.
+    std::pair<Value, Value> trans = fromSliceCoord(builder, loc, c, enc, dim);
+    SmallVector<Type> types;
+    for (Value red : reduc)
+      types.push_back(red.getType());
+
+    // First, coord >= offset (TODO: seems unsigned >= 0 won't be folded, skip
+    // the check if the offset is zero).
+    auto geOff =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, c,
+                                      getSliceOffset(builder, loc, enc, dim));
+    // Second, coords < length
+    auto ltLen = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, trans.first,
+        getSliceSize(builder, loc, enc, dim));
+
+    // Third, rem == 0; confirmed that (a % 1) will be folded to 0
+    auto fitStride = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, trans.second,
+        constantIndex(builder, loc, 0));
+
+    auto pred = builder.create<arith::AndIOp>(loc, geOff, ltLen);
+    pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
+    bool hasReduc = !types.empty();
+    scf::IfOp ifOp =
+        builder.create<scf::IfOp>(loc, types, pred, /*else*/ hasReduc);
+    if (hasReduc) {
+      // scf.for (a) -> v
+      //  %s = scf.if (a) -> v
+      //    user-generated code.
+      //  else
+      //    yield a
+      //  yield %s
+      builder.create<scf::YieldOp>(loc, ifOp.getResults());
+      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      // On mismatch.
+      builder.create<scf::YieldOp>(loc, reduc);
+    }
+    // Set the insertion point to matched branch.
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    c = trans.first;
+  }
+
+  assert(c);
+  coord[tid][dim] = c;
   // NOTE: we can also prepare for next dim here in advance
   // Push the loop into stack
   loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
index a1db60c..832281a 100644 (file)
@@ -259,22 +259,25 @@ private:
   std::vector<std::vector<Value>> idxBuffer; // to_indices
   std::vector<Value> valBuffer;              // to_value
 
-  // Loop Stack, stores the information of all the nested loops that are
-  // alive.
+  /// Whether the sparse input is a slice.
+  std::vector<bool> isSparseSlices;
+
+  /// Loop Stack, stores the information of all the nested loops that are
+  /// alive.
   std::vector<LoopLevelInfo> loopStack;
 
-  // Loop Sequence Stack, stores the unversial index for the current loop
-  // sequence.
+  /// Loop Sequence Stack, stores the unversial index for the current loop
+  /// sequence.
   std::vector<Value> loopSeqStack;
 
-  // Maps AffineDimExpr to the index of the loop in loopStack.
-  // TODO: We should probably use a callback function here to make it more
-  // general.
+  /// Maps AffineDimExpr to the index of the loop in loopStack.
+  /// TODO: We should probably use a callback function here to make it more
+  /// general.
   std::vector<unsigned> sparsiferLoopLvlMap;
 
-  // TODO: not yet used, it should track the current level for each tensor
-  // to help eliminate `dim` paramters from above APIs.
-  // std::vector<size_t> curLv;
+  /// TODO: not yet used, it should track the current level for each tensor
+  /// to help eliminate `dim` paramters from above APIs.
+  /// std::vector<size_t> curLv;
 };
 
 } // namespace sparse_tensor
index 11348e0..b2541c3 100644 (file)
@@ -1010,6 +1010,40 @@ public:
   }
 };
 
+class SparseExtractSliceCoverter
+    : public OpConversionPattern<tensor::ExtractSliceOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcEnc = getSparseTensorEncoding(op.getSourceType());
+    auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
+    if (!srcEnc && !dstEnc)
+      return failure();
+
+    // TODO: We should check these in ExtractSliceOp::verify.
+    assert(srcEnc && dstEnc && dstEnc.isSlice());
+    assert(srcEnc.getDimLevelType() == dstEnc.getDimLevelType());
+    assert(srcEnc.getDimOrdering() == dstEnc.getDimOrdering());
+    assert(srcEnc.getHigherOrdering() == dstEnc.getHigherOrdering());
+    assert(srcEnc.getPointerBitWidth() == dstEnc.getPointerBitWidth());
+    assert(srcEnc.getIndexBitWidth() == dstEnc.getIndexBitWidth());
+
+    // TODO: support dynamic slices.
+    for (int i = 0, e = op.getSourceType().getRank(); i < e; i++) {
+      assert(op.getStaticStrides()[i] == dstEnc.getStaticDimSliceStride(i));
+      assert(op.getStaticOffsets()[i] == dstEnc.getStaticDimSliceOffset(i));
+      assert(op.getStaticSizes()[i] == dstEnc.getStaticDimSliceSize(i));
+    }
+
+    // TODO: create a new specifer for slices (need to encode slice metadata).
+    // It does not matter now because only constant offset/stride are allowed.
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+
 /// Sparse codegen rule for number of entries operator.
 class SparseNumberOfEntriesConverter
     : public OpConversionPattern<NumberOfEntriesOp> {
@@ -1133,13 +1167,13 @@ void mlir::populateSparseTensorCodegenPatterns(
     bool enableBufferInitialization) {
   patterns.add<SparsePackOpConverter, SparseReturnConverter,
                SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
-               SparseTensorDeallocConverter, SparseTensorLoadConverter,
-               SparseExpandConverter, SparseCompressConverter,
-               SparseInsertConverter, SparseToPointersConverter,
-               SparseToIndicesConverter, SparseToIndicesBufferConverter,
-               SparseToValuesConverter, SparseConvertConverter,
-               SparseNumberOfEntriesConverter>(typeConverter,
-                                               patterns.getContext());
+               SparseTensorDeallocConverter, SparseExtractSliceCoverter,
+               SparseTensorLoadConverter, SparseExpandConverter,
+               SparseCompressConverter, SparseInsertConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToIndicesBufferConverter, SparseToValuesConverter,
+               SparseConvertConverter, SparseNumberOfEntriesConverter>(
+      typeConverter, patterns.getContext());
   patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
                                            enableBufferInitialization);
 }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
new file mode 100644 (file)
index 0000000..560c64f
--- /dev/null
@@ -0,0 +1,94 @@
+// DEFINE: %{option} = enable-runtime-library=false
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void  \
+// DEFINE:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+
+// TODO: support slices on lib path
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ]
+}>
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+module {
+  func.func @foreach_print_non_slice(%A: tensor<4x4xf64, #CSR>) {
+    sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR> do {
+    ^bb0(%1: index, %2: index, %v: f64) :
+      vector.print %1: index
+      vector.print %2: index
+      vector.print %v: f64
+    }
+    return
+  }
+
+  func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) {
+    sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR_SLICE> do {
+    ^bb0(%1: index, %2: index, %v: f64) :
+      vector.print %1: index
+      vector.print %2: index
+      vector.print %v: f64
+    }
+    return
+  }
+
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %sa = arith.constant dense<[
+        [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ],
+        [ 0.0, 0.0, 0.1, 0.0, 0.0, 2.1, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 3.1, 0.0, 0.0, 0.0 ],
+        [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 3.3, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ]
+    ]> : tensor<8x8xf64>
+
+    %tmp = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
+    %a = tensor.extract_slice %tmp[1, 1][4, 4][1, 2] : tensor<8x8xf64, #CSR> to
+                                                       tensor<4x4xf64, #CSR_SLICE>
+    // Foreach on sparse tensor slices directly
+    //
+    // CHECK: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 2.3
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 3
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 3
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 2.1
+    //
+    call @foreach_print_slice(%a) : (tensor<4x4xf64, #CSR_SLICE>) -> ()
+
+    // FIXME: investigate why a tensor copy is inserted for this slice
+//    %dense = tensor.extract_slice %sa[1, 1][4, 4][1, 2] : tensor<8x8xf64> to
+//                                                          tensor<4x4xf64>
+//    %b = sparse_tensor.convert %dense : tensor<4x4xf64> to tensor<4x4xf64, #CSR>
+//    // Foreach on sparse tensor instead of slice they should yield the same result.
+//    //
+//    // C_HECK-NEXT: 1
+//    // C_HECK-NEXT: 0
+//    // C_HECK-NEXT: 2.3
+//    // C_HECK-NEXT: 2
+//    // C_HECK-NEXT: 3
+//    // C_HECK-NEXT: 1
+//    // C_HECK-NEXT: 3
+//    // C_HECK-NEXT: 2
+//    // C_HECK-NEXT: 2.1
+//    //
+//    call @foreach_print_non_slice(%b) : (tensor<4x4xf64, #CSR>) -> ()
+//    bufferization.dealloc_tensor %b : tensor<4x4xf64, #CSR>
+
+    bufferization.dealloc_tensor %tmp : tensor<8x8xf64, #CSR>
+    return
+  }
+}