[mlir][spirv] Support convert complex types
authorLei Zhang <antiagainst@google.com>
Thu, 30 Mar 2023 15:48:24 +0000 (08:48 -0700)
committerLei Zhang <antiagainst@google.com>
Thu, 30 Mar 2023 15:48:31 +0000 (08:48 -0700)
Complex types are converted to a two-element vector type to contain
the real and imaginary numbers.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D147188

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

index af7436e..bf74ad5 100644 (file)
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 
 #include <functional>
+#include <optional>
 
 #define DEBUG_TYPE "mlir-spirv-conversion"
 
@@ -149,6 +152,13 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
     return bitWidth / 8;
   }
 
+  if (auto complexType = type.dyn_cast<ComplexType>()) {
+    auto elementSize = getTypeNumBytes(options, complexType.getElementType());
+    if (!elementSize)
+      return std::nullopt;
+    return 2 * *elementSize;
+  }
+
   if (auto vecType = type.dyn_cast<VectorType>()) {
     auto elementSize = getTypeNumBytes(options, vecType.getElementType());
     if (!elementSize)
@@ -299,6 +309,30 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   return nullptr;
 }
 
+static Type
+convertComplexType(const spirv::TargetEnv &targetEnv,
+                   const SPIRVConversionOptions &options, ComplexType type,
+                   std::optional<spirv::StorageClass> storageClass = {}) {
+  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
+  if (!scalarType) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot convert non-scalar element type\n");
+    return nullptr;
+  }
+
+  auto elementType =
+      convertScalarType(targetEnv, options, scalarType, storageClass);
+  if (!elementType)
+    return nullptr;
+  if (elementType != type.getElementType()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: complex type emulation unsupported\n");
+    return nullptr;
+  }
+
+  return VectorType::get(2, elementType);
+}
+
 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
 ///
 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
@@ -372,7 +406,6 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-
   if (!type.hasStaticShape()) {
     // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
     // to the element.
@@ -419,6 +452,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   if (auto vecType = elementType.dyn_cast<VectorType>()) {
     arrayElemType =
         convertVectorType(targetEnv, options, vecType, storageClass);
+  } else if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+    arrayElemType =
+        convertComplexType(targetEnv, options, complexType, storageClass);
   } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
     arrayElemType =
         convertScalarType(targetEnv, options, scalarType, storageClass);
@@ -443,7 +479,6 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-
   if (!type.hasStaticShape()) {
     // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
     // to the element.
@@ -500,6 +535,10 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
     return Type();
   });
 
+  addConversion([this](ComplexType complexType) {
+    return convertComplexType(this->targetEnv, this->options, complexType);
+  });
+
   addConversion([this](VectorType vectorType) {
     return convertVectorType(this->targetEnv, this->options, vectorType);
   });
index 7880547..fe6edc1 100644 (file)
@@ -197,6 +197,61 @@ func.func @bf16_type(%arg0: bf16) { return }
 // -----
 
 //===----------------------------------------------------------------------===//
+// Complex types
+//===----------------------------------------------------------------------===//
+
+// Check that capabilities for scalar types affects complex types too: having
+// special capabilities means keep vector types untouched.
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [Float64, StorageUniform16, StorageBuffer16BitAccess],
+    [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: func @complex_types
+// CHECK-SAME: vector<2xf32>
+// CHECK-SAME: vector<2xf64>
+func.func @complex_types(
+    %arg0: complex<f32>,
+    %arg2: complex<f64>
+) { return }
+
+// CHECK-LABEL: func @memref_complex_types_with_cap
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<4 x vector<2xf16>, stride=4> [0])>, StorageBuffer>
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<16 x vector<2xf16>, stride=4> [0])>, Uniform>
+func.func @memref_complex_types_with_cap(
+    %arg0: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>,
+    %arg1: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that capabilities for scalar types affects complex types too: no special
+// capabilities available means widening element types to 32-bit.
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// Emulation is unimplemented right now.
+// CHECK-LABEL: func @memref_complex_types_no_cap
+// CHECK-SAME: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>
+// CHECK-SAME: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+// NOEMU-LABEL: func @memref_complex_types_no_cap
+// NOEMU-SAME: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>
+// NOEMU-SAME: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+func.func @memref_complex_types_no_cap(
+    %arg0: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>,
+    %arg1: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+) { return }
+
+} // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // Vector types
 //===----------------------------------------------------------------------===//