From d926b3307e2f90d8b7f770e6f40fbbe5dcc1f887 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Thu, 7 Jul 2022 13:16:49 -0700 Subject: [PATCH] [mlir] add complex type to getZeroAttr Fixes issue encountered with complex constant https://github.com/llvm/llvm-project/issues/56428 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D129325 --- mlir/lib/IR/BuiltinAttributes.cpp | 12 ++++++++++++ .../Dialect/LLVMIR/CPU/test-complex-sparse-constant.mlir | 16 ++++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 mlir/test/Integration/Dialect/LLVMIR/CPU/test-complex-sparse-constant.mlir diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 3592652..80b91ca 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -1589,6 +1589,18 @@ Attribute SparseElementsAttr::getZeroAttr() const { if (eltType.isa()) return FloatAttr::get(eltType, 0); + // Handle complex elements. + if (auto complexTy = eltType.dyn_cast()) { + auto eltType = complexTy.getElementType(); + Attribute zero; + if (eltType.isa()) + zero = FloatAttr::get(eltType, 0); + else // must be integer + zero = IntegerAttr::get(eltType, 0); + return ArrayAttr::get(complexTy.getContext(), + ArrayRef{zero, zero}); + } + // Handle string type. if (getValues().isa()) return StringAttr::get("", eltType); diff --git a/mlir/test/Integration/Dialect/LLVMIR/CPU/test-complex-sparse-constant.mlir b/mlir/test/Integration/Dialect/LLVMIR/CPU/test-complex-sparse-constant.mlir new file mode 100644 index 0000000..a42c64d --- /dev/null +++ b/mlir/test/Integration/Dialect/LLVMIR/CPU/test-complex-sparse-constant.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s --convert-memref-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void + +// +// Code should not crash on the complex32 sparse constant. +// +module attributes {llvm.data_layout = ""} { + memref.global "private" constant @"__constant_32xcomplex_0" : memref<32xcomplex> = + sparse<[[1], [28], [31]], + [(1.000000e+00,0.000000e+00), (2.000000e+00,0.000000e+00), (3.000000e+00,0.000000e+00)] + > {alignment = 128 : i64} + llvm.func @entry() { + %0 = memref.get_global @"__constant_32xcomplex_0" : memref<32xcomplex> + llvm.return + } +} -- 2.7.4