/// structs, but does not in uniquing of identified structs.
class LLVMStructType
: public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
- DataLayoutTypeInterface::Trait> {
+ DataLayoutTypeInterface::Trait,
+ TypeTrait::IsMutable> {
public:
/// Inherit base constructors.
using Base::Base;
/// In the above, expressing recursive struct types is accomplished by giving a
/// recursive struct a unique identified and using that identifier in the struct
/// definition for recursive references.
-class StructType : public Type::TypeBase<StructType, CompositeType,
- detail::StructTypeStorage> {
+class StructType
+ : public Type::TypeBase<StructType, CompositeType,
+ detail::StructTypeStorage, TypeTrait::IsMutable> {
public:
using Base::Base;
friend InterfaceBase;
};
+//===----------------------------------------------------------------------===//
+// Core AttributeTrait
+//===----------------------------------------------------------------------===//
+
+/// This trait is used to determine if an attribute is mutable or not. It is
+/// attached on an attribute if the corresponding ImplType defines a `mutate`
+/// function with proper signature.
+namespace AttributeTrait {
+template <typename ConcreteType>
+using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+} // namespace AttributeTrait
+
} // namespace mlir.
namespace llvm {
}
};
+namespace StorageUserTrait {
+/// This trait is used to determine if a storage user, like Type, is mutable
+/// or not. A storage user is mutable if ImplType of the derived class defines
+/// a `mutate` function with a proper signature. Note that this trait is not
+/// supposed to be used publicly. Users should use alias names like
+/// `TypeTrait::IsMutable` instead.
+template <typename ConcreteType>
+struct IsMutable : public StorageUserTraitBase<ConcreteType, IsMutable> {};
+} // namespace StorageUserTrait
+
//===----------------------------------------------------------------------===//
// StorageUserBase
//===----------------------------------------------------------------------===//
/// Mutate the current storage instance. This will not change the unique key.
/// The arguments are forwarded to 'ConcreteT::mutate'.
template <typename... Args> LogicalResult mutate(Args &&...args) {
+ static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>,
+ ConcreteT>::value,
+ "The `mutate` function expects mutable trait "
+ "(e.g. TypeTrait::IsMutable) to be attached on parent.");
return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
std::forward<Args>(args)...);
}
};
//===----------------------------------------------------------------------===//
+// Core TypeTrait
+//===----------------------------------------------------------------------===//
+
+/// This trait is used to determine if a type is mutable or not. It is attached
+/// on a type if the corresponding ImplType defines a `mutate` function with
+/// a proper signature.
+namespace TypeTrait {
+template <typename ConcreteType>
+using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+} // namespace TypeTrait
+
+//===----------------------------------------------------------------------===//
// Type Utils
//===----------------------------------------------------------------------===//
#include "mlir/IR/SubElementInterfaces.h"
+#include "llvm/ADT/DenseSet.h"
+
using namespace mlir;
template <typename InterfaceT>
static void walkSubElementsImpl(InterfaceT interface,
function_ref<void(Attribute)> walkAttrsFn,
- function_ref<void(Type)> walkTypesFn) {
+ function_ref<void(Type)> walkTypesFn,
+ DenseSet<Attribute> &visitedAttrs,
+ DenseSet<Type> &visitedTypes) {
interface.walkImmediateSubElements(
[&](Attribute attr) {
// Guard against potentially null inputs. This removes the need for the
if (!attr)
return;
+ // Avoid infinite recursion when visiting sub attributes later, if this
+ // is a mutable attribute.
+ if (LLVM_UNLIKELY(attr.hasTrait<AttributeTrait::IsMutable>())) {
+ if (!visitedAttrs.insert(attr).second)
+ return;
+ }
+
// Walk any sub elements first.
if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
- walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
+ walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
+ visitedTypes);
// Walk this attribute.
walkAttrsFn(attr);
if (!type)
return;
+ // Avoid infinite recursion when visiting sub types later, if this
+ // is a mutable type.
+ if (LLVM_UNLIKELY(type.hasTrait<TypeTrait::IsMutable>())) {
+ if (!visitedTypes.insert(type).second)
+ return;
+ }
+
// Walk any sub elements first.
if (auto interface = type.dyn_cast<SubElementTypeInterface>())
- walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
+ walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
+ visitedTypes);
// Walk this type.
walkTypesFn(type);
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
- walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
+ DenseSet<Attribute> visitedAttrs;
+ DenseSet<Type> visitedTypes;
+ walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
+ visitedTypes);
}
void SubElementTypeInterface::walkSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
- walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
+ DenseSet<Attribute> visitedAttrs;
+ DenseSet<Type> visitedTypes;
+ walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
+ visitedTypes);
}
//===----------------------------------------------------------------------===//
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
+// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
+
// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
// CHECK: !test.test_rec<a, test_rec<b, test_type>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<a, test_rec<b, test_type>>
// CHECK: !test.test_rec<c, test_rec<c>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<c, test_rec<c>>
+ // Make sure walkSubElementType, which is used to generate aliases, doesn't go
+ // into inifinite recursion.
+ // CHECK: !testrec
+ "test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
return
}
return AliasResult::FinalAlias;
}
}
+ if (auto recType = type.dyn_cast<TestRecursiveType>()) {
+ if (recType.getName() == "type_to_alias") {
+ // We only make alias for a specific recursive type.
+ os << "testrec";
+ return AliasResult::FinalAlias;
+ }
+ }
return AliasResult::NoAlias;
}
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
/// from type creation.
class TestRecursiveType
: public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
- TestRecursiveTypeStorage> {
+ TestRecursiveTypeStorage,
+ ::mlir::SubElementTypeInterface::Trait,
+ ::mlir::TypeTrait::IsMutable> {
public:
using Base::Base;
/// Body getter and setter.
::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); }
- ::mlir::Type getBody() { return getImpl()->body; }
+ ::mlir::Type getBody() const { return getImpl()->body; }
/// Name/key getter.
::llvm::StringRef getName() { return getImpl()->name; }
+
+ void walkImmediateSubElements(
+ ::llvm::function_ref<void(::mlir::Attribute)> walkAttrsFn,
+ ::llvm::function_ref<void(::mlir::Type)> walkTypesFn) const {
+ walkTypesFn(getBody());
+ }
};
} // namespace test
MLIRDialect)
add_subdirectory(Affine)
+add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
-
add_subdirectory(Quant)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
--- /dev/null
+add_mlir_unittest(MLIRLLVMIRTests
+ LLVMTypeTest.cpp
+)
+target_link_libraries(MLIRLLVMIRTests
+ PRIVATE
+ MLIRLLVMDialect
+ )
--- /dev/null
+//===- LLVMTestBase.h - Test fixure for LLVM dialect tests ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Test fixure for LLVM dialect tests.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
+#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "gtest/gtest.h"
+
+class LLVMIRTest : public ::testing::Test {
+protected:
+ LLVMIRTest() { context.getOrLoadDialect<mlir::LLVM::LLVMDialect>(); }
+
+ mlir::MLIRContext context;
+};
+
+#endif
--- /dev/null
+//===- LLVMTypeTest.cpp - Tests for LLVM types ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LLVMTestBase.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/SubElementInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+TEST_F(LLVMIRTest, IsStructTypeMutable) {
+ auto structTy = LLVMStructType::getIdentified(&context, "foo");
+ ASSERT_TRUE(bool(structTy));
+ ASSERT_TRUE(structTy.hasTrait<TypeTrait::IsMutable>());
+}