#ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
#define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
+#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
/// type.
class LLVMArrayType
: public Type::TypeBase<LLVMArrayType, Type, detail::LLVMTypeAndSizeStorage,
- DataLayoutTypeInterface::Trait> {
+ DataLayoutTypeInterface::Trait,
+ SubElementTypeInterface::Trait> {
public:
/// Inherit base constructors.
using Base::Base;
unsigned getPreferredAlignment(const DataLayout &dataLayout,
DataLayoutEntryListRef params) const;
+
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
};
//===----------------------------------------------------------------------===//
/// LLVM dialect function type. It consists of a single return type (unlike MLIR
/// which can have multiple), a list of parameter types and can optionally be
/// variadic.
-class LLVMFunctionType
- : public Type::TypeBase<LLVMFunctionType, Type,
- detail::LLVMFunctionTypeStorage> {
+class LLVMFunctionType : public Type::TypeBase<LLVMFunctionType, Type,
+ detail::LLVMFunctionTypeStorage,
+ SubElementTypeInterface::Trait> {
public:
/// Inherit base constructors.
using Base::Base;
LLVMFunctionType clone(TypeRange inputs, TypeRange results) const;
/// Returns the result type of the function.
- Type getReturnType();
+ Type getReturnType() const;
/// Returns the result type of the function as an ArrayRef, enabling better
/// integration with generic MLIR utilities.
- ArrayRef<Type> getReturnTypes();
+ ArrayRef<Type> getReturnTypes() const;
/// Returns the number of arguments to the function.
unsigned getNumParams();
Type getParamType(unsigned i);
/// Returns a list of argument types of the function.
- ArrayRef<Type> getParams();
+ ArrayRef<Type> getParams() const;
ArrayRef<Type> params() { return getParams(); }
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type result, ArrayRef<Type> arguments, bool);
+
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
};
//===----------------------------------------------------------------------===//
/// object in memory. Pointers may be opaque or parameterized by the element
/// type. Both opaque and non-opaque pointers are additionally parameterized by
/// the address space.
-class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
- detail::LLVMPointerTypeStorage,
- DataLayoutTypeInterface::Trait> {
+class LLVMPointerType
+ : public Type::TypeBase<
+ LLVMPointerType, Type, detail::LLVMPointerTypeStorage,
+ DataLayoutTypeInterface::Trait, SubElementTypeInterface::Trait> {
public:
/// Inherit base constructors.
using Base::Base;
DataLayoutEntryListRef newLayout) const;
LogicalResult verifyEntries(DataLayoutEntryListRef entries,
Location loc) const;
+
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
};
//===----------------------------------------------------------------------===//
class LLVMStructType
: public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
DataLayoutTypeInterface::Trait,
+ SubElementTypeInterface::Trait,
TypeTrait::IsMutable> {
public:
/// Inherit base constructors.
LogicalResult verifyEntries(DataLayoutEntryListRef entries,
Location loc) const;
+
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
};
//===----------------------------------------------------------------------===//
/// length that can be processed as one.
class LLVMFixedVectorType
: public Type::TypeBase<LLVMFixedVectorType, Type,
- detail::LLVMTypeAndSizeStorage> {
+ detail::LLVMTypeAndSizeStorage,
+ SubElementTypeInterface::Trait> {
public:
/// Inherit base constructor.
using Base::Base;
static bool isValidElementType(Type type);
/// Returns the element type of the vector.
- Type getElementType();
+ Type getElementType() const;
/// Returns the number of elements in the fixed vector.
unsigned getNumElements();
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements);
+
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
};
//===----------------------------------------------------------------------===//
/// elements can be processed as one in SIMD context.
class LLVMScalableVectorType
: public Type::TypeBase<LLVMScalableVectorType, Type,
- detail::LLVMTypeAndSizeStorage> {
+ detail::LLVMTypeAndSizeStorage,
+ SubElementTypeInterface::Trait> {
public:
/// Inherit base constructor.
using Base::Base;
static bool isValidElementType(Type type);
/// Returns the element type of the vector.
- Type getElementType();
+ Type getElementType() const;
/// Returns the scaling factor of the number of elements in the vector. The
/// vector contains at least the resulting number of elements, or any non-zero
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned minNumElements);
+
+ void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const;
};
//===----------------------------------------------------------------------===//
return dataLayout.getTypePreferredAlignment(getElementType());
}
+void LLVMArrayType::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ walkTypesFn(getElementType());
+}
+
//===----------------------------------------------------------------------===//
// Function type.
//===----------------------------------------------------------------------===//
return get(results[0], llvm::to_vector(inputs), isVarArg());
}
-Type LLVMFunctionType::getReturnType() { return getImpl()->getReturnType(); }
-ArrayRef<Type> LLVMFunctionType::getReturnTypes() {
+Type LLVMFunctionType::getReturnType() const {
+ return getImpl()->getReturnType();
+}
+ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
return getImpl()->getReturnType();
}
bool LLVMFunctionType::isVarArg() const { return getImpl()->isVariadic(); }
-ArrayRef<Type> LLVMFunctionType::getParams() {
+ArrayRef<Type> LLVMFunctionType::getParams() const {
return getImpl()->getArgumentTypes();
}
return success();
}
+void LLVMFunctionType::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ for (Type type : llvm::concat<const Type>(getReturnTypes(), getParams()))
+ walkTypesFn(type);
+}
+
//===----------------------------------------------------------------------===//
// Pointer type.
//===----------------------------------------------------------------------===//
return success();
}
+void LLVMPointerType::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ walkTypesFn(getElementType());
+}
+
//===----------------------------------------------------------------------===//
// Struct type.
//===----------------------------------------------------------------------===//
return mlir::success();
}
+void LLVMStructType::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ for (Type type : getBody())
+ walkTypesFn(type);
+}
+
//===----------------------------------------------------------------------===//
// Vector types.
//===----------------------------------------------------------------------===//
numElements);
}
-Type LLVMFixedVectorType::getElementType() {
+Type LLVMFixedVectorType::getElementType() const {
return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
}
emitError, elementType, numElements);
}
+void LLVMFixedVectorType::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ walkTypesFn(getElementType());
+}
+
//===----------------------------------------------------------------------===//
// LLVMScalableVectorType.
//===----------------------------------------------------------------------===//
minNumElements);
}
-Type LLVMScalableVectorType::getElementType() {
+Type LLVMScalableVectorType::getElementType() const {
return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
}
emitError, elementType, numElements);
}
+void LLVMScalableVectorType::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ walkTypesFn(getElementType());
+}
+
//===----------------------------------------------------------------------===//
// Utility functions.
//===----------------------------------------------------------------------===//
ASSERT_TRUE(bool(structTy));
ASSERT_TRUE(structTy.hasTrait<TypeTrait::IsMutable>());
}
+
+TEST_F(LLVMIRTest, MutualReferencedSubElementTypes) {
+ auto fooStructTy = LLVMStructType::getIdentified(&context, "foo");
+ ASSERT_TRUE(bool(fooStructTy));
+ auto barStructTy = LLVMStructType::getIdentified(&context, "bar");
+ ASSERT_TRUE(bool(barStructTy));
+
+ // Created two structs that are referencing each other.
+ Type fooBody[] = {LLVMPointerType::get(barStructTy)};
+ ASSERT_TRUE(succeeded(fooStructTy.setBody(fooBody, /*packed=*/false)));
+ Type barBody[] = {LLVMPointerType::get(fooStructTy)};
+ ASSERT_TRUE(succeeded(barStructTy.setBody(barBody, /*packed=*/false)));
+
+ auto subElementInterface = fooStructTy.dyn_cast<SubElementTypeInterface>();
+ ASSERT_TRUE(bool(subElementInterface));
+ // Test if walkSubElements goes into infinite loops.
+ SmallVector<Type, 4> subElementTypes;
+ subElementInterface.walkSubElements(
+ [](Attribute attr) {},
+ [&](Type type) { subElementTypes.push_back(type); });
+ // We don't record LLVMPointerType (because it's immutable), thus
+ // !llvm.ptr<struct<"bar",...>> will be visited twice.
+ ASSERT_EQ(subElementTypes.size(), 5U);
+
+ // !llvm.ptr<struct<"bar",...>>
+ ASSERT_TRUE(subElementTypes[0].isa<LLVMPointerType>());
+
+ // !llvm.struct<"foo",...>
+ auto structType = subElementTypes[1].dyn_cast<LLVMStructType>();
+ ASSERT_TRUE(bool(structType));
+ ASSERT_TRUE(structType.getName().equals("foo"));
+
+ // !llvm.ptr<struct<"foo",...>>
+ ASSERT_TRUE(subElementTypes[2].isa<LLVMPointerType>());
+
+ // !llvm.struct<"bar",...>
+ structType = subElementTypes[3].dyn_cast<LLVMStructType>();
+ ASSERT_TRUE(bool(structType));
+ ASSERT_TRUE(structType.getName().equals("bar"));
+
+ // !llvm.ptr<struct<"bar",...>>
+ ASSERT_TRUE(subElementTypes[4].isa<LLVMPointerType>());
+}