#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
}
};
+/// Sparse codegen rule for the expand op.
+class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
+    Type eltType = srcType.getElementType();
+    Type boolType = rewriter.getIntegerType(1);
+    Type idxType = rewriter.getIndexType();
+    // All initialization should be done on entry of the loop nest.
+    // NOTE(review): getDefiningOp() returns null when the tensor is a block
+    // argument (e.g. a function parameter) — presumably the expand op is only
+    // ever emitted on tensors with a defining op; TODO confirm.
+    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
+    // Determine the size for access expansion (always the innermost stored
+    // dimension size, translated back to original dimension). Note that we
+    // recursively rewrite the new DimOp on the **original** tensor.
+    // NOTE(review): `enc` is dereferenced unconditionally below, so the
+    // source tensor is assumed to carry a sparse encoding — confirm callers.
+    auto enc = getSparseTensorEncoding(srcType);
+    unsigned innerDim = srcType.getRank() - 1;
+    if (AffineMap p = enc.getDimOrdering())
+      innerDim = p.getDimPosition(innerDim);
+    Value sz = rewriter.create<tensor::DimOp>(loc, op.getTensor(), innerDim);
+    // Generate a memref for `sz` elements of type `t`.
+    auto genAlloc = [&](Type t) {
+      auto memTp = MemRefType::get({ShapedType::kDynamicSize}, t);
+      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
+    };
+    // Allocate temporary buffers for values, filled-switch, and indices.
+    // We do not use stack buffers for this, since the expanded size may
+    // be rather large (as it envelops a single expanded dense dimension).
+    Value values = genAlloc(eltType);
+    Value filled = genAlloc(boolType);
+    Value indices = genAlloc(idxType);
+    Value zero = constantZero(rewriter, loc, idxType);
+    // Reset the values/filled-switch to all-zero/false. Note that this
+    // introduces an O(N) operation into the computation, but this reset
+    // operation is amortized over the innermost loops for the access
+    // pattern expansion. As noted in the operation doc, we would like
+    // to amortize this setup cost even between kernels.
+    rewriter.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(rewriter, loc, eltType)},
+        ValueRange{values});
+    rewriter.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(rewriter, loc, boolType)},
+        ValueRange{filled});
+    // Replace expansion op with these buffers and initial index.
+    // The four results are: values buffer, filled-switch buffer, buffer of
+    // added indices, and the running count (starting at zero).
+    assert(op.getNumResults() == 4);
+    rewriter.replaceOp(op, {values, filled, indices, zero});
+    return success();
+  }
+};
+
/// Sparse codegen rule for pointer accesses.
class SparseToPointersConverter
: public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
+// NOTE(review): this hunk only inserts SparseExpandConverter into the
+// registered pattern list (keeping alphabetical-ish ordering); every
+// previously registered converter is retained.
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                               RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
-              SparseCastConverter, SparseTensorAllocConverter,
-              SparseTensorDeallocConverter, SparseToPointersConverter,
-              SparseToIndicesConverter, SparseToValuesConverter,
-              SparseTensorLoadConverter>(typeConverter, patterns.getContext());
+              SparseCastConverter, SparseExpandConverter,
+              SparseTensorAllocConverter, SparseTensorDeallocConverter,
+              SparseToPointersConverter, SparseToIndicesConverter,
+              SparseToValuesConverter, SparseTensorLoadConverter>(
+      typeConverter, patterns.getContext());
}
%1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
return %1 : tensor<10x20x30xf64, #Dense3D>
}
+
+// Expansion of a statically shaped 4x8 CSR matrix: with row-major (CSR)
+// storage the innermost stored dimension is the column dimension (size 8),
+// so all three scratch buffers are allocated with static size 8 and then
+// cast to the dynamic memref types the expand op yields.
+// CHECK-LABEL: func.func @sparse_expansion1()
+//       CHECK: %[[A:.*]] = memref.alloc() : memref<8xf64>
+//       CHECK: %[[B:.*]] = memref.alloc() : memref<8xi1>
+//       CHECK: %[[C:.*]] = memref.alloc() : memref<8xindex>
+//       CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<8xindex> to memref<?xindex>
+//   CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<8xf64>)
+//   CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<8xi1>)
+//       CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion1() -> memref<?xindex> {
+  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSR>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return %added : memref<?xindex>
+}
+
+// Expansion of a statically shaped 4x8 CSC matrix: CSC's dimension ordering
+// makes the innermost stored dimension translate back to the row dimension
+// (size 4), so the scratch buffers are statically sized to 4 here.
+// CHECK-LABEL: func.func @sparse_expansion2()
+//       CHECK: %[[A:.*]] = memref.alloc() : memref<4xf64>
+//       CHECK: %[[B:.*]] = memref.alloc() : memref<4xi1>
+//       CHECK: %[[C:.*]] = memref.alloc() : memref<4xindex>
+//       CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<4xindex> to memref<?xindex>
+//   CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<4xf64>)
+//   CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<4xi1>)
+//       CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion2() -> memref<?xindex> {
+  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSC>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<4x8xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return %added : memref<?xindex>
+}
+
+// Expansion of a dynamically shaped CSC matrix: the buffer size is the
+// dynamic size of the dimension CSC stores innermost (%arg0, stored at
+// slot 1 of the tensor's dimension-size storage and read back from there),
+// so the three scratch buffers are allocated with that dynamic size.
+// CHECK-LABEL: func.func @sparse_expansion3(
+//  CHECK-SAME: %[[D0:.*]]: index,
+//  CHECK-SAME: %{{.*}}: index) -> memref<?xindex> {
+//       CHECK: %[[C1:.*]] = arith.constant 1 : index
+//       CHECK: %[[S0:.*]] = memref.alloc() : memref<2xindex>
+//       CHECK: memref.store %[[D0]], %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
+//       CHECK: %[[D1:.*]] = memref.load %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
+//       CHECK: %[[V:.*]] = memref.alloc(%[[D1]]) : memref<?xf64>
+//       CHECK: %[[B:.*]] = memref.alloc(%[[D1]]) : memref<?xi1>
+//       CHECK: %[[D:.*]] = memref.alloc(%[[D1]]) : memref<?xindex>
+//       CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref<?xf64>)
+//       CHECK: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
+//       CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSC>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return %added : memref<?xindex>
+}