From 73440ca9f878f9c4150b339bdd56b234d9167ee9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Markus=20B=C3=B6ck?= Date: Wed, 6 Jul 2022 12:27:44 +0200 Subject: [PATCH] [mlir] Define proper DenseMapInfo for Interfaces Prior to this patch, using any kind of interface (op interface, attr interface, type interface) as the key of a llvm::DenseSet, llvm::DenseMap or other related containers, leads to invalid pointer dereferences, despite compiling. The gist of the problem is that a llvm::DenseMapInfo specialization for the base type (aka one of Operation*, Type or Attribute) are selected when using an interface as a key, which uses getFromOpaquePointer with invalid pointer addresses to construct instances for the empty key and tombstone key values. The interface is then constructed with this invalid base type and an attempt is made to lookup the implementation in the interface map, which then dereferences the invalid pointer address. (For more details and the exact call chain involved see the GitHub issue below) The current workaround is to use the more generic base type (eg. instead of DenseSet, DenseSet), but this is strictly worse from a code perspective (doesn't enforce the invariant, code is less self documenting, having to insert casts etc). This patch fixes that issue by defining a DenseMapInfo specialization of Interface subclasses which uses a new constructor to construct an instance without querying a concept. That allows getEmptyKey and getTombstoneKey to construct an interface with invalid pointer values. Fixes https://github.com/llvm/llvm-project/issues/54908 Differential Revision: https://reviews.llvm.org/D129038 --- mlir/include/mlir/IR/Attributes.h | 3 +- mlir/include/mlir/IR/OpDefinition.h | 5 +- mlir/include/mlir/IR/Types.h | 3 +- mlir/include/mlir/Support/InterfaceSupport.h | 37 +++++++++++++ mlir/unittests/IR/CMakeLists.txt | 1 + mlir/unittests/IR/InterfaceTest.cpp | 77 ++++++++++++++++++++++++++++ 6 files changed, 122 insertions(+), 4 deletions(-) create mode 100644 mlir/unittests/IR/InterfaceTest.cpp diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 2195057..6c420f1 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -266,7 +266,8 @@ template <> struct DenseMapInfo { }; template struct DenseMapInfo< - T, std::enable_if_t::value>> + T, std::enable_if_t::value && + !mlir::detail::IsInterface::value>> : public DenseMapInfo { static T getEmptyKey() { const void *pointer = llvm::DenseMapInfo::getEmptyKey(); diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 42eccde..c98993a 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1963,8 +1963,9 @@ LogicalResult verifyCastInterfaceOp( namespace llvm { template -struct DenseMapInfo< - T, std::enable_if_t::value>> { +struct DenseMapInfo::value && + !mlir::detail::IsInterface::value>> { static inline T getEmptyKey() { auto *pointer = llvm::DenseMapInfo::getEmptyKey(); return T::getFromOpaquePointer(pointer); diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index 5af7389..4efc838 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -282,7 +282,8 @@ template <> struct DenseMapInfo { static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; } }; template -struct DenseMapInfo::value>> +struct DenseMapInfo::value && + !mlir::detail::IsInterface::value>> : public DenseMapInfo { static T getEmptyKey() { const void *pointer = llvm::DenseMapInfo::getEmptyKey(); diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h index 7cbd18f..02eaa1b 100644 --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -78,6 +78,7 @@ public: Interface; template using ExternalModel = typename Traits::template ExternalModel; + using ValueType = ValueT; /// This is a special trait that registers a given interface with an object. template @@ -104,6 +105,9 @@ public: assert((!t || impl) && "expected value to provide interface instance"); } + /// Constructor for DenseMapInfo's empty key and tombstone key. + Interface(ValueT t, std::nullptr_t) : BaseType(t), impl(nullptr) {} + /// Support 'classof' by checking if the given object defines the concrete /// interface. static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); } @@ -264,7 +268,40 @@ private: SmallVector> interfaces; }; +template class> class BaseTrait> +void isInterfaceImpl( + Interface &); + +template +using is_interface_t = decltype(isInterfaceImpl(std::declval())); + +template +using IsInterface = llvm::is_detected; + } // namespace detail } // namespace mlir +namespace llvm { + +template +struct DenseMapInfo::value>> { + using ValueTypeInfo = llvm::DenseMapInfo; + + static T getEmptyKey() { return T(ValueTypeInfo::getEmptyKey(), nullptr); } + + static T getTombstoneKey() { + return T(ValueTypeInfo::getTombstoneKey(), nullptr); + } + + static unsigned getHashValue(T val) { + return ValueTypeInfo::getHashValue(val); + } + + static bool isEqual(T lhs, T rhs) { return ValueTypeInfo::isEqual(lhs, rhs); } +}; + +} // namespace llvm + #endif diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt index 326fc0a..188df7e 100644 --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_unittest(MLIRIRTests AttributeTest.cpp DialectTest.cpp + InterfaceTest.cpp InterfaceAttachmentTest.cpp OperationSupportTest.cpp PatternMatchTest.cpp diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp new file mode 100644 index 0000000..e77e879 --- /dev/null +++ b/mlir/unittests/IR/InterfaceTest.cpp @@ -0,0 +1,77 @@ +//===- InterfaceTest.cpp - Test interfaces --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OwningOpRef.h" +#include "gtest/gtest.h" + +#include "../../test/lib/Dialect/Test/TestAttributes.h" +#include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestTypes.h" + +using namespace mlir; +using namespace test; + +TEST(InterfaceTest, OpInterfaceDenseMapKey) { + MLIRContext context; + context.loadDialect(); + + OwningOpRef module = ModuleOp::create(UnknownLoc::get(&context)); + OpBuilder builder(module->getBody(), module->getBody()->begin()); + auto op1 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + auto op2 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + auto op3 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + DenseSet opSet; + opSet.insert(op1); + opSet.insert(op2); + opSet.erase(op1); + EXPECT_FALSE(opSet.contains(op1)); + EXPECT_TRUE(opSet.contains(op2)); + EXPECT_FALSE(opSet.contains(op3)); +} + +TEST(InterfaceTest, AttrInterfaceDenseMapKey) { + MLIRContext context; + context.loadDialect(); + + OpBuilder builder(&context); + + DenseSet attrSet; + auto attr1 = builder.getArrayAttr({}); + auto attr2 = builder.getI32ArrayAttr({0}); + auto attr3 = builder.getI32ArrayAttr({1}); + attrSet.insert(attr1); + attrSet.insert(attr2); + attrSet.erase(attr1); + EXPECT_FALSE(attrSet.contains(attr1)); + EXPECT_TRUE(attrSet.contains(attr2)); + EXPECT_FALSE(attrSet.contains(attr3)); +} + +TEST(InterfaceTest, TypeInterfaceDenseMapKey) { + MLIRContext context; + context.loadDialect(); + + OpBuilder builder(&context); + DenseSet typeSet; + auto type1 = builder.getType(1); + auto type2 = builder.getType(2); + auto type3 = builder.getType(3); + typeSet.insert(type1); + typeSet.insert(type2); + typeSet.erase(type1); + EXPECT_FALSE(typeSet.contains(type1)); + EXPECT_TRUE(typeSet.contains(type2)); + EXPECT_FALSE(typeSet.contains(type3)); +} -- 2.7.4