[mlir] support complex type in DenseElementsAttr::get.
authorXiang Li <python3kgae@outlook.com>
Sun, 12 Feb 2023 17:19:17 +0000 (12:19 -0500)
committerXiang Li <python3kgae@outlook.com>
Mon, 13 Feb 2023 14:35:19 +0000 (09:35 -0500)
Fixes #60662 https://github.com/llvm/llvm-project/issues/60662

Allow ComplexType when create DenseElementsAttr.
Also allow build ConstantOp for integer complex.

Differential Revision: https://reviews.llvm.org/D143848

mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir

index c71e2e0..28e121b 100644 (file)
@@ -32,10 +32,16 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
     if (!complexTy || arrAttr.size() != 2)
       return false;
     auto complexEltTy = complexTy.getElementType();
-    auto re = arrAttr[0].dyn_cast<FloatAttr>();
-    auto im = arrAttr[1].dyn_cast<FloatAttr>();
-    return re && im && re.getType() == complexEltTy &&
-           im.getType() == complexEltTy;
+    if (auto fre = arrAttr[0].dyn_cast<FloatAttr>()) {
+      auto im = arrAttr[1].dyn_cast<FloatAttr>();
+      return im && fre.getType() == complexEltTy &&
+             im.getType() == complexEltTy;
+    }
+    if (auto ire = arrAttr[0].dyn_cast<IntegerAttr>()) {
+      auto im = arrAttr[1].dyn_cast<IntegerAttr>();
+      return im && ire.getType() == complexEltTy &&
+             im.getType() == complexEltTy;
+    }
   }
   return false;
 }
index ff4aa65..b99ec22 100644 (file)
@@ -891,9 +891,44 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<Attribute> values) {
   assert(hasSameElementsOrSplat(type, values));
 
+  Type eltType = type.getElementType();
+
+  // Take care complex type case first.
+  if (auto complexType = eltType.dyn_cast<ComplexType>()) {
+    if (complexType.getElementType().isIntOrIndex()) {
+      SmallVector<std::complex<APInt>> complexValues;
+      complexValues.reserve(values.size());
+      for (Attribute attr : values) {
+        assert(attr.isa<ArrayAttr>() &&
+               "expected ArrayAttr for complex");
+        auto arrayAttr = attr.cast<ArrayAttr>();
+        assert(arrayAttr.size() == 2 && "expected 2 element for complex");
+        auto attr0 = arrayAttr[0];
+        auto attr1 = arrayAttr[1];
+        complexValues.push_back(
+            std::complex<APInt>(attr0.cast<IntegerAttr>().getValue(),
+                                attr1.cast<IntegerAttr>().getValue()));
+      }
+      return DenseElementsAttr::get(type, complexValues);
+    }
+    // Must be float.
+    SmallVector<std::complex<APFloat>> complexValues;
+    complexValues.reserve(values.size());
+    for (Attribute attr : values) {
+      assert(attr.isa<ArrayAttr>() && "expected ArrayAttr for complex");
+      auto arrayAttr = attr.cast<ArrayAttr>();
+      assert(arrayAttr.size() == 2 && "expected 2 element for complex");
+      auto attr0 = arrayAttr[0];
+      auto attr1 = arrayAttr[1];
+      complexValues.push_back(
+          std::complex<APFloat>(attr0.cast<FloatAttr>().getValue(),
+                                attr1.cast<FloatAttr>().getValue()));
+    }
+    return DenseElementsAttr::get(type, complexValues);
+  }
+
   // If the element type is not based on int/float/index, assume it is a string
   // type.
-  Type eltType = type.getElementType();
   if (!eltType.isIntOrIndexOrFloat()) {
     SmallVector<StringRef, 8> stringValues;
     stringValues.reserve(values.size());
index 6b72804..2c0b871 100644 (file)
@@ -207,6 +207,34 @@ func.func @extract_from_tensor.from_elements_3d()
 
 // -----
 
+// CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
+// CHECK-NEXT:  %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
+// CHECK-NEXT:  return %cst : tensor<3xcomplex<i32>>
+func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
+  %c1 = arith.constant dense<(1, 2)> : tensor<complex<i32>>
+  %complex1 = tensor.extract %c1[] : tensor<complex<i32>>
+  %c2 = arith.constant dense<(3, 2)> : tensor<complex<i32>>
+  %complex2 = tensor.extract %c2[] : tensor<complex<i32>>
+  %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<i32>>
+  return %tensor : tensor<3xcomplex<i32>>
+}
+
+// -----
+
+// CHECK-LABEL:  func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
+// CHECK-NEXT:   %cst = arith.constant dense<[(1.200000e+00,2.300000e+00), (3.200000e+00,2.100000e+00), (1.200000e+00,2.300000e+00)]> : tensor<3xcomplex<f32>>
+// CHECK-NEXT:   return %cst : tensor<3xcomplex<f32>>
+func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
+  %c1 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
+  %complex1 = tensor.extract %c1[] : tensor<complex<f32>>
+  %c2 = arith.constant dense<(3.2, 2.1)> : tensor<complex<f32>>
+  %complex2 = tensor.extract %c2[] : tensor<complex<f32>>
+  %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<f32>>
+  return %tensor : tensor<3xcomplex<f32>>
+}
+
+// -----
+
 // Ensure the optimization doesn't segfault from bad constants
 // CHECK-LABEL: func @extract_negative_from_tensor.from_elements
 func.func @extract_negative_from_tensor.from_elements(%element : index) -> index {