From d41028610b5372669adcb9b7091fae5250f0a4a8 Mon Sep 17 00:00:00 2001 From: Min-Yih Hsu Date: Fri, 20 May 2022 21:52:49 -0700 Subject: [PATCH] [mlir] Prevent SubElementInterface from going into infinite recursion Since only mutable types and attributes can go into infinite recursion inside SubElementInterface::walkSubElement, and there are only a few of them (mutable types and attributes), we introduce new traits for Type and Attribute: TypeTrait::IsMutable and AttributeTrait::IsMutable, respectively. They indicate whether a type or attribute is mutable. Such traits are required if the ImplType defines a `mutate` function. Then, inside SubElementInterface, we use a set to record visited mutable types and attributes that have been visited before. Differential Revision: https://reviews.llvm.org/D127537 --- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 3 ++- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 5 ++-- mlir/include/mlir/IR/Attributes.h | 12 +++++++++ mlir/include/mlir/IR/StorageUniquerSupport.h | 14 ++++++++++ mlir/include/mlir/IR/Types.h | 12 +++++++++ mlir/lib/IR/SubElementInterfaces.cpp | 36 +++++++++++++++++++++---- mlir/test/IR/recursive-type.mlir | 6 +++++ mlir/test/lib/Dialect/Test/TestDialect.cpp | 7 +++++ mlir/test/lib/Dialect/Test/TestTypes.h | 13 +++++++-- mlir/unittests/Dialect/CMakeLists.txt | 2 +- mlir/unittests/Dialect/LLVMIR/CMakeLists.txt | 7 +++++ mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h | 27 +++++++++++++++++++ mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp | 20 ++++++++++++++ 13 files changed, 153 insertions(+), 11 deletions(-) create mode 100644 mlir/unittests/Dialect/LLVMIR/CMakeLists.txt create mode 100644 mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h create mode 100644 mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index c3a244f..50537f6 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -264,7 +264,8 @@ public: /// structs, but does not in uniquing of identified structs. class LLVMStructType : public Type::TypeBase { + DataLayoutTypeInterface::Trait, + TypeTrait::IsMutable> { public: /// Inherit base constructors. using Base::Base; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 0d4471e..40a4acf 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -275,8 +275,9 @@ public: /// In the above, expressing recursive struct types is accomplished by giving a /// recursive struct a unique identified and using that identifier in the struct /// definition for recursive references. -class StructType : public Type::TypeBase { +class StructType + : public Type::TypeBase { public: using Base::Base; diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index da5b78e..2195057 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -231,6 +231,18 @@ private: friend InterfaceBase; }; +//===----------------------------------------------------------------------===// +// Core AttributeTrait +//===----------------------------------------------------------------------===// + +/// This trait is used to determine if an attribute is mutable or not. It is +/// attached on an attribute if the corresponding ImplType defines a `mutate` +/// function with proper signature. +namespace AttributeTrait { +template +using IsMutable = detail::StorageUserTrait::IsMutable; +} // namespace AttributeTrait + } // namespace mlir. namespace llvm { diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h index 6d854f6..7aaa688 100644 --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -53,6 +53,16 @@ protected: } }; +namespace StorageUserTrait { +/// This trait is used to determine if a storage user, like Type, is mutable +/// or not. A storage user is mutable if ImplType of the derived class defines +/// a `mutate` function with a proper signature. Note that this trait is not +/// supposed to be used publicly. Users should use alias names like +/// `TypeTrait::IsMutable` instead. +template +struct IsMutable : public StorageUserTraitBase {}; +} // namespace StorageUserTrait + //===----------------------------------------------------------------------===// // StorageUserBase //===----------------------------------------------------------------------===// @@ -173,6 +183,10 @@ protected: /// Mutate the current storage instance. This will not change the unique key. /// The arguments are forwarded to 'ConcreteT::mutate'. template LogicalResult mutate(Args &&...args) { + static_assert(std::is_base_of, + ConcreteT>::value, + "The `mutate` function expects mutable trait " + "(e.g. TypeTrait::IsMutable) to be attached on parent."); return UniquerT::template mutate(this->getContext(), getImpl(), std::forward(args)...); } diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index f2a6966..5af7389 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -223,6 +223,18 @@ private: }; //===----------------------------------------------------------------------===// +// Core TypeTrait +//===----------------------------------------------------------------------===// + +/// This trait is used to determine if a type is mutable or not. It is attached +/// on a type if the corresponding ImplType defines a `mutate` function with +/// a proper signature. +namespace TypeTrait { +template +using IsMutable = detail::StorageUserTrait::IsMutable; +} // namespace TypeTrait + +//===----------------------------------------------------------------------===// // Type Utils //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp index 0a4875c..4059b99 100644 --- a/mlir/lib/IR/SubElementInterfaces.cpp +++ b/mlir/lib/IR/SubElementInterfaces.cpp @@ -8,12 +8,16 @@ #include "mlir/IR/SubElementInterfaces.h" +#include "llvm/ADT/DenseSet.h" + using namespace mlir; template static void walkSubElementsImpl(InterfaceT interface, function_ref walkAttrsFn, - function_ref walkTypesFn) { + function_ref walkTypesFn, + DenseSet &visitedAttrs, + DenseSet &visitedTypes) { interface.walkImmediateSubElements( [&](Attribute attr) { // Guard against potentially null inputs. This removes the need for the @@ -21,9 +25,17 @@ static void walkSubElementsImpl(InterfaceT interface, if (!attr) return; + // Avoid infinite recursion when visiting sub attributes later, if this + // is a mutable attribute. + if (LLVM_UNLIKELY(attr.hasTrait())) { + if (!visitedAttrs.insert(attr).second) + return; + } + // Walk any sub elements first. if (auto interface = attr.dyn_cast()) - walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn); + walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); // Walk this attribute. walkAttrsFn(attr); @@ -34,9 +46,17 @@ static void walkSubElementsImpl(InterfaceT interface, if (!type) return; + // Avoid infinite recursion when visiting sub types later, if this + // is a mutable type. + if (LLVM_UNLIKELY(type.hasTrait())) { + if (!visitedTypes.insert(type).second) + return; + } + // Walk any sub elements first. if (auto interface = type.dyn_cast()) - walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn); + walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); // Walk this type. walkTypesFn(type); @@ -47,14 +67,20 @@ void SubElementAttrInterface::walkSubElements( function_ref walkAttrsFn, function_ref walkTypesFn) { assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); - walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn); + DenseSet visitedAttrs; + DenseSet visitedTypes; + walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); } void SubElementTypeInterface::walkSubElements( function_ref walkAttrsFn, function_ref walkTypesFn) { assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); - walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn); + DenseSet visitedAttrs; + DenseSet visitedTypes; + walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir index e66d7fd..bc9b2cd 100644 --- a/mlir/test/IR/recursive-type.mlir +++ b/mlir/test/IR/recursive-type.mlir @@ -1,11 +1,17 @@ // RUN: mlir-opt %s -test-recursive-types | FileCheck %s +// CHECK: !testrec = !test.test_rec> + // CHECK-LABEL: @roundtrip func.func @roundtrip() { // CHECK: !test.test_rec> "test.dummy_op_for_roundtrip"() : () -> !test.test_rec> // CHECK: !test.test_rec> "test.dummy_op_for_roundtrip"() : () -> !test.test_rec> + // Make sure walkSubElementType, which is used to generate aliases, doesn't go + // into inifinite recursion. + // CHECK: !testrec + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec> return } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 3585d43..653f0e1 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -160,6 +160,13 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { return AliasResult::FinalAlias; } } + if (auto recType = type.dyn_cast()) { + if (recType.getName() == "type_to_alias") { + // We only make alias for a specific recursive type. + os << "testrec"; + return AliasResult::FinalAlias; + } + } return AliasResult::NoAlias; } diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h index bd6421d..8772efd 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -21,6 +21,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/SubElementInterfaces.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" @@ -130,7 +131,9 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage { /// from type creation. class TestRecursiveType : public ::mlir::Type::TypeBase { + TestRecursiveTypeStorage, + ::mlir::SubElementTypeInterface::Trait, + ::mlir::TypeTrait::IsMutable> { public: using Base::Base; @@ -141,10 +144,16 @@ public: /// Body getter and setter. ::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); } - ::mlir::Type getBody() { return getImpl()->body; } + ::mlir::Type getBody() const { return getImpl()->body; } /// Name/key getter. ::llvm::StringRef getName() { return getImpl()->name; } + + void walkImmediateSubElements( + ::llvm::function_ref walkAttrsFn, + ::llvm::function_ref walkTypesFn) const { + walkTypesFn(getBody()); + } }; } // namespace test diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index ec89b14..f8e5e46 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -7,8 +7,8 @@ target_link_libraries(MLIRDialectTests MLIRDialect) add_subdirectory(Affine) +add_subdirectory(LLVMIR) add_subdirectory(MemRef) - add_subdirectory(Quant) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) diff --git a/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt new file mode 100644 index 0000000..92af185 --- /dev/null +++ b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(MLIRLLVMIRTests + LLVMTypeTest.cpp +) +target_link_libraries(MLIRLLVMIRTests + PRIVATE + MLIRLLVMDialect + ) diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h b/mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h new file mode 100644 index 0000000..1badc44 --- /dev/null +++ b/mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h @@ -0,0 +1,27 @@ +//===- LLVMTestBase.h - Test fixure for LLVM dialect tests ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Test fixure for LLVM dialect tests. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H +#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/MLIRContext.h" +#include "gtest/gtest.h" + +class LLVMIRTest : public ::testing::Test { +protected: + LLVMIRTest() { context.getOrLoadDialect(); } + + mlir::MLIRContext context; +}; + +#endif diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp new file mode 100644 index 0000000..9c0ea4f --- /dev/null +++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp @@ -0,0 +1,20 @@ +//===- LLVMTypeTest.cpp - Tests for LLVM types ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "LLVMTestBase.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/SubElementInterfaces.h" + +using namespace mlir; +using namespace mlir::LLVM; + +TEST_F(LLVMIRTest, IsStructTypeMutable) { + auto structTy = LLVMStructType::getIdentified(&context, "foo"); + ASSERT_TRUE(bool(structTy)); + ASSERT_TRUE(structTy.hasTrait()); +} -- 2.7.4