[mlir] add complex type to getZeroAttr
authorAart Bik <ajcbik@google.com>
Thu, 7 Jul 2022 20:16:49 +0000 (13:16 -0700)
committerAart Bik <ajcbik@google.com>
Thu, 7 Jul 2022 23:58:59 +0000 (16:58 -0700)
Fixes issue encountered with <sparse> complex constant
https://github.com/llvm/llvm-project/issues/56428

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D129325

mlir/lib/IR/BuiltinAttributes.cpp
mlir/test/Integration/Dialect/LLVMIR/CPU/test-complex-sparse-constant.mlir [new file with mode: 0644]

index 35926522e080570644a133a4270b119af9efd59d..80b91ca5bb05040fd409ce2c8b252220929feb58 100644 (file)
@@ -1589,6 +1589,18 @@ Attribute SparseElementsAttr::getZeroAttr() const {
   if (eltType.isa<FloatType>())
     return FloatAttr::get(eltType, 0);
 
+  // Handle complex elements.
+  if (auto complexTy = eltType.dyn_cast<ComplexType>()) {
+    auto eltType = complexTy.getElementType();
+    Attribute zero;
+    if (eltType.isa<FloatType>())
+      zero = FloatAttr::get(eltType, 0);
+    else // must be integer
+      zero = IntegerAttr::get(eltType, 0);
+    return ArrayAttr::get(complexTy.getContext(),
+                          ArrayRef<Attribute>{zero, zero});
+  }
+
   // Handle string type.
   if (getValues().isa<DenseStringElementsAttr>())
     return StringAttr::get("", eltType);
diff --git a/mlir/test/Integration/Dialect/LLVMIR/CPU/test-complex-sparse-constant.mlir b/mlir/test/Integration/Dialect/LLVMIR/CPU/test-complex-sparse-constant.mlir
new file mode 100644 (file)
index 0000000..a42c64d
--- /dev/null
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s --convert-memref-to-llvm | \
+// RUN:   mlir-cpu-runner -e entry -entry-point-result=void
+
+//
+// Code should not crash on the complex32 sparse constant.
+//
+module attributes {llvm.data_layout = ""} {
+  memref.global "private" constant @"__constant_32xcomplex<f32>_0" : memref<32xcomplex<f32>> =
+     sparse<[[1], [28], [31]],
+            [(1.000000e+00,0.000000e+00), (2.000000e+00,0.000000e+00), (3.000000e+00,0.000000e+00)]
+           > {alignment = 128 : i64}
+  llvm.func @entry() {
+     %0 = memref.get_global @"__constant_32xcomplex<f32>_0" : memref<32xcomplex<f32>>
+     llvm.return
+  }
+}