-//===- LLVMTypes.cpp - MLIR LLVM Dialect types ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/TypeSupport.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/TypeSize.h"
//===----------------------------------------------------------------------===//
bool LLVMPointerType::isValidElementType(Type type) {
- return isCompatibleType(type) ? !type.isa<LLVMVoidType, LLVMTokenType,
- LLVMMetadataType, LLVMLabelType>()
- : type.isa<PointerElementTypeInterface>();
+ return isCompatibleOuterType(type)
+ ? !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
+ LLVMLabelType>()
+ : type.isa<PointerElementTypeInterface>();
}
LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
// Utility functions.
//===----------------------------------------------------------------------===//
-bool mlir::LLVM::isCompatibleType(Type type) {
- // Only signless integers are compatible.
- if (auto intType = type.dyn_cast<IntegerType>())
- return intType.isSignless();
-
- // 1D vector types are compatible if their element types are.
- if (auto vecType = type.dyn_cast<VectorType>())
- return vecType.getRank() == 1 && isCompatibleType(vecType.getElementType());
-
+bool mlir::LLVM::isCompatibleOuterType(Type type) {
// clang-format off
- return type.isa<
+ if (type.isa<
BFloat16Type,
Float16Type,
Float32Type,
LLVMScalableVectorType,
LLVMVoidType,
LLVMX86MMXType
- >();
- // clang-format on
+ >()) {
+ // clang-format on
+ return true;
+ }
+
+ // Only signless integers are compatible.
+ if (auto intType = type.dyn_cast<IntegerType>())
+ return intType.isSignless();
+
+ // 1D vector types are compatible.
+ if (auto vecType = type.dyn_cast<VectorType>())
+ return vecType.getRank() == 1;
+
+ return false;
+}
+
+static bool isCompatibleImpl(Type type, SetVector<Type> &callstack) {
+ if (callstack.contains(type))
+ return true;
+
+ callstack.insert(type);
+ auto stackPopper = llvm::make_scope_exit([&] { callstack.pop_back(); });
+
+ auto isCompatible = [&](Type type) {
+ return isCompatibleImpl(type, callstack);
+ };
+
+ return llvm::TypeSwitch<Type, bool>(type)
+ .Case<LLVMStructType>([&](auto structType) {
+ return llvm::all_of(structType.getBody(), isCompatible);
+ })
+ .Case<LLVMFunctionType>([&](auto funcType) {
+ return isCompatible(funcType.getReturnType()) &&
+ llvm::all_of(funcType.getParams(), isCompatible);
+ })
+ .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
+ .Case<VectorType>([&](auto vecType) {
+ return vecType.getRank() == 1 && isCompatible(vecType.getElementType());
+ })
+ // clang-format off
+ .Case<
+ LLVMPointerType,
+ LLVMFixedVectorType,
+ LLVMScalableVectorType,
+ LLVMArrayType
+ >([&](auto containerType) {
+ return isCompatible(containerType.getElementType());
+ })
+ .Case<
+ BFloat16Type,
+ Float16Type,
+ Float32Type,
+ Float64Type,
+ Float80Type,
+ Float128Type,
+ LLVMLabelType,
+ LLVMMetadataType,
+ LLVMPPCFP128Type,
+ LLVMTokenType,
+ LLVMVoidType,
+ LLVMX86MMXType
+ >([](Type) { return true; })
+ // clang-format on
+ .Default([](Type) { return false; });
+}
+
+bool mlir::LLVM::isCompatibleType(Type type) {
+ SetVector<Type> callstack;
+ return isCompatibleImpl(type, callstack);
}
bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {