[mlir][sparse] Fix problems in creating complex zero for initialization.
authorbixia1 <bixia@google.com>
Thu, 8 Dec 2022 00:14:22 +0000 (16:14 -0800)
committerbixia1 <bixia@google.com>
Thu, 8 Dec 2022 15:49:27 +0000 (07:49 -0800)
Reviewed By: aartbik, wrengr

Differential Revision: https://reviews.llvm.org/D139591

mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir

index 0592009..ded1e65 100644 (file)
@@ -823,8 +823,7 @@ public:
           rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
       if (enableBufferInitialization) {
         Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
-        Value fillValue = rewriter.create<arith::ConstantOp>(
-            loc, value.getType(), rewriter.getZeroAttr(value.getType()));
+        Value fillValue = constantZero(rewriter, loc, value.getType());
         Value subBuffer = rewriter.create<memref::SubViewOp>(
             loc, newBuffer, /*offset=*/ValueRange{newSize},
             /*size=*/ValueRange{fillSize},
index e059bd3..4c190bc 100644 (file)
@@ -235,8 +235,7 @@ static Value createAllocation(OpBuilder &builder, Location loc,
   Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
   Type elemType = memRefType.getElementType();
   if (enableInit) {
-    Value fillValue = builder.create<arith::ConstantOp>(
-        loc, elemType, builder.getZeroAttr(elemType));
+    Value fillValue = constantZero(builder, loc, elemType);
     builder.create<linalg::FillOp>(loc, fillValue, buffer);
   }
   return buffer;
index 06f57eb..6845aa0 100644 (file)
@@ -216,9 +216,9 @@ struct SparseTensorCodegenPass
     // The following operations and dialects may be introduced by the
     // codegen rules, and are therefore marked as legal.
     target.addLegalOp<linalg::FillOp>();
-    target.addLegalDialect<arith::ArithDialect,
-                           bufferization::BufferizationDialect,
-                           memref::MemRefDialect, scf::SCFDialect>();
+    target.addLegalDialect<
+        arith::ArithDialect, bufferization::BufferizationDialect,
+        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
     target.addLegalOp<UnrealizedConversionCastOp>();
     // Populate with rules and apply rewriting rules.
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
index 80947a2..6ee26b9 100644 (file)
@@ -8,7 +8,7 @@
 // RUN: %{command}
 //
 // Do the same run, but now with direct IR generation.
-// REDEFINE: %{option} = enable-runtime-library=false
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
 // RUN: %{command}
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>