[mlir][sparse] Add option enable-buffer-initialization to the sparse-tensor-codegen...
author    bixia1 <bixia@google.com>
Thu, 10 Nov 2022 00:16:03 +0000 (16:16 -0800)
committer bixia1 <bixia@google.com>
Thu, 10 Nov 2022 15:23:33 +0000 (07:23 -0800)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137733
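
When enabled, the codegen pass zero-initializes the pointer, index, and value
buffers it allocates for a sparse tensor (via linalg.fill of the type's zero
attribute), mirroring the existing enable-buffer-initialization option of the
sparse-buffer-rewrite pass.

Example usage (an illustrative sketch, not part of this change; the
buildPipeline helper and pass-manager setup are assumptions, only
createSparseTensorCodegenPass and the pass option come from this patch):

    #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
    #include "mlir/Pass/PassManager.h"

    // Programmatic use: request zero-initialized buffers from the codegen pass.
    void buildPipeline(mlir::PassManager &pm) {
      pm.addPass(
          mlir::createSparseTensorCodegenPass(/*enableBufferInitialization=*/true));
    }

From mlir-opt, the option is exercised as in the RUN line of the new test:

    mlir-opt --sparse-tensor-codegen=enable-buffer-initialization=true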

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir [new file with mode: 0644]

index 8d704dc..19ff2eb 100644 (file)
@@ -130,9 +130,11 @@ public:
 
 /// Sets up sparse tensor conversion rules.
 void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
-                                         RewritePatternSet &patterns);
+                                         RewritePatternSet &patterns,
+                                         bool enableBufferInitialization);
 
-std::unique_ptr<Pass> createSparseTensorCodegenPass();
+std::unique_ptr<Pass>
+createSparseTensorCodegenPass(bool enableBufferInitialization = false);
 
 //===----------------------------------------------------------------------===//
 // The SparseTensorRewriting pass.
index b7c4baa..74784af 100644 (file)
@@ -181,6 +181,10 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
     "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
   ];
+  let options = [
+    Option<"enableBufferInitialization", "enable-buffer-initialization", "bool",
+           "false", "Enable zero-initialization of the memory buffers">,
+  ];
 }
 
 def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
index 82d060e..478bac5 100644 (file)
@@ -64,7 +64,8 @@ void mlir::sparse_tensor::buildSparseCompiler(
     pm.addPass(createSparseTensorConversionPass(
         options.sparseTensorConversionOptions()));
   else
-    pm.addPass(createSparseTensorCodegenPass());
+    pm.addPass(
+        createSparseTensorCodegenPass(options.enableBufferInitialization));
   pm.addPass(createSparseBufferRewritePass(options.enableBufferInitialization));
   pm.addPass(createDenseBufferizationPass(
       getBufferizationOptions(/*analysisOnly=*/false)));
index c161fa5..499dc4c 100644 (file)
@@ -287,9 +287,15 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
 
 /// Creates allocation operation.
 static Value createAllocation(OpBuilder &builder, Location loc, Type type,
-                              Value sz) {
+                              Value sz, bool enableInit) {
   auto memType = MemRefType::get({ShapedType::kDynamicSize}, type);
-  return builder.create<memref::AllocOp>(loc, memType, sz);
+  Value buffer = builder.create<memref::AllocOp>(loc, memType, sz);
+  if (enableInit) {
+    Value fillValue =
+        builder.create<arith::ConstantOp>(loc, type, builder.getZeroAttr(type));
+    builder.create<linalg::FillOp>(loc, fillValue, buffer);
+  }
+  return buffer;
 }
 
 /// Creates allocation for each field in sparse tensor type. Note that
@@ -300,7 +306,7 @@ static Value createAllocation(OpBuilder &builder, Location loc, Type type,
 ///       on the required capacities (see heuristic variable).
 ///
 static void createAllocFields(OpBuilder &builder, Location loc, Type type,
-                              ValueRange dynSizes,
+                              ValueRange dynSizes, bool enableInit,
                               SmallVectorImpl<Value> &fields) {
   auto enc = getSparseTensorEncoding(type);
   assert(enc);
@@ -334,16 +340,20 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
     if (isCompressedDim(rtp, r)) {
-      fields.push_back(createAllocation(builder, loc, ptrType, heuristic));
-      fields.push_back(createAllocation(builder, loc, idxType, heuristic));
+      fields.push_back(
+          createAllocation(builder, loc, ptrType, heuristic, enableInit));
+      fields.push_back(
+          createAllocation(builder, loc, idxType, heuristic, enableInit));
     } else if (isSingletonDim(rtp, r)) {
-      fields.push_back(createAllocation(builder, loc, idxType, heuristic));
+      fields.push_back(
+          createAllocation(builder, loc, idxType, heuristic, enableInit));
     } else {
       assert(isDenseDim(rtp, r)); // no fields
     }
   }
   // The values array.
-  fields.push_back(createAllocation(builder, loc, eltType, heuristic));
+  fields.push_back(
+      createAllocation(builder, loc, eltType, heuristic, enableInit));
   assert(fields.size() == lastField);
   // Initialize the storage scheme to an empty tensor. Initialized memSizes
   // to all zeros, sets the dimSizes to known values and gives all pointer
@@ -685,6 +695,10 @@ class SparseTensorAllocConverter
     : public OpConversionPattern<bufferization::AllocTensorOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
+  SparseTensorAllocConverter(TypeConverter &typeConverter, MLIRContext *context,
+                             bool enableInit)
+      : OpConversionPattern(typeConverter, context),
+        enableBufferInitialization(enableInit) {}
   LogicalResult
   matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -698,11 +712,15 @@ public:
     // Construct allocation for each field.
     Location loc = op.getLoc();
     SmallVector<Value, 8> fields;
-    createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
+    createAllocFields(rewriter, loc, resType, adaptor.getOperands(),
+                      enableBufferInitialization, fields);
     // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
     return success();
   }
+
+private:
+  bool enableBufferInitialization;
 };
 
 /// Sparse codegen rule for the dealloc operator.
@@ -1014,8 +1032,9 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 
 /// Populates the given patterns list with conversion rules required for
 /// the sparsification of linear algebra operations.
-void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
-                                               RewritePatternSet &patterns) {
+void mlir::populateSparseTensorCodegenPatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    bool enableBufferInitialization) {
   patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
                SparseCastConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter, SparseTensorLoadConverter,
@@ -1024,4 +1043,6 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                SparseToIndicesConverter, SparseToValuesConverter,
                SparseConvertConverter, SparseNumberOfEntriesConverter>(
       typeConverter, patterns.getContext());
+  patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
+                                           enableBufferInitialization);
 }
index 8bc132f..f74eb5f 100644 (file)
@@ -161,6 +161,9 @@ struct SparseTensorCodegenPass
 
   SparseTensorCodegenPass() = default;
   SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
+  SparseTensorCodegenPass(bool enableInit) {
+    enableBufferInitialization = enableInit;
+  }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
@@ -203,7 +206,8 @@ struct SparseTensorCodegenPass
                                                                    converter);
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
-    populateSparseTensorCodegenPatterns(converter, patterns);
+    populateSparseTensorCodegenPatterns(converter, patterns,
+                                        enableBufferInitialization);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -278,8 +282,9 @@ std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
   return std::make_unique<SparseTensorConversionPass>(options);
 }
 
-std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
-  return std::make_unique<SparseTensorCodegenPass>();
+std::unique_ptr<Pass>
+mlir::createSparseTensorCodegenPass(bool enableBufferInitialization) {
+  return std::make_unique<SparseTensorCodegenPass>(enableBufferInitialization);
 }
 
 std::unique_ptr<Pass>
diff --git a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
new file mode 100644 (file)
index 0000000..043cdbb
--- /dev/null
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s --sparse-tensor-codegen=enable-buffer-initialization=true  --canonicalize --cse | FileCheck %s
+
+#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
+// CHECK-LABEL: func @sparse_alloc_sparse_vector(
+//  CHECK-SAME: %[[A:.*]]: index) ->
+//  CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[F0:.*]] = arith.constant 0.{{0*}}e+00 : f64
+//       CHECK: %[[T0:.*]] = memref.alloc() : memref<1xindex>
+//       CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex>
+//       CHECK: %[[T2:.*]] = memref.alloc() : memref<16xindex>
+//       CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<16xindex> to memref<?xindex>
+//       CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T2]] : memref<16xindex>)
+//       CHECK: %[[T4:.*]] = memref.alloc() : memref<16xindex>
+//       CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<16xindex> to memref<?xindex>
+//       CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T4]] : memref<16xindex>)
+//       CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64>
+//       CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref<?xf64>
+//       CHECK: linalg.fill ins(%[[F0]] : f64) outs(%[[T6]] : memref<16xf64>)
+//       CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>)
+//       CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<1xindex>
+//       CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]]
+//       CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[T1]], %[[P0]]
+//       CHECK: return %[[T0]], %[[T1]], %[[P1]], %[[T5]], %[[T7]] :
+func.func @sparse_alloc_sparse_vector(%arg0: index) -> tensor<?xf64, #SV> {
+  %0 = bufferization.alloc_tensor(%arg0) : tensor<?xf64, #SV>
+  %1 = sparse_tensor.load %0 : tensor<?xf64, #SV>
+  return %1 : tensor<?xf64, #SV>
+}