From 998003eaf7b91b5620eeb0b77c81654b9f36efb7 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 24 Jul 2023 20:39:40 -0700 Subject: [PATCH] [mlir] Fix for MSVC bool splat issue encountered. When building MLIR using bazel on windows with MSVC2019, bool splats were being created incorrectly: ``` dense<[true,true,true,true]> : tensor<4xi1> -(parse with mlir-opt)-> dense<[true, false, false, false]> : tensor<4xi1> ``` Appears that a Windows bazel build produces a corrupt DenseIntOrFPElementsAttr. Unable to repro using MSVC and cmake. Issue first discovered here: https://github.com/google/jax/issues/16394 Added test point for reproduction: ``` $ bazel test @llvm-project//mlir/unittests:ir_tests --test_arg=--gtest_filter=DenseSplatTest.BoolSplatSmall ``` Differential Revision: https://reviews.llvm.org/D155745 --- mlir/lib/IR/AttributeDetail.h | 7 +++++-- mlir/lib/IR/BuiltinAttributes.cpp | 3 +++ mlir/unittests/IR/AttributeTest.cpp | 14 ++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index b76da1b..dcd24af 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -193,8 +193,11 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage { ArrayRef data; /// The values used to denote a boolean splat value. - static constexpr char kSplatTrue = ~0; - static constexpr char kSplatFalse = 0; + // This is not using constexpr declaration due to compilation failure + // encountered with MSVC where it would inline these values, which makes it + // unsafe to refer by reference in KeyTy. + static const char kSplatTrue; + static const char kSplatFalse; }; /// An attribute representing a reference to a dense vector or tensor object diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index c670359..7328fcc 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -438,6 +438,9 @@ LogicalResult OpaqueAttr::verify(function_ref emitError, // DenseElementsAttr Utilities //===----------------------------------------------------------------------===// +const char DenseIntOrFPElementsAttrStorage::kSplatTrue = ~0; +const char DenseIntOrFPElementsAttrStorage::kSplatFalse = 0; + /// Get the bitwidth of a dense element type within the buffer. /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. static size_t getDenseElementStorageWidth(size_t origWidth) { diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index f01cc02..a707dfa 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -74,6 +74,20 @@ TEST(DenseSplatTest, BoolSplatRawRoundtrip) { EXPECT_EQ(trueSplat, trueSplatFromRaw); } +TEST(DenseSplatTest, BoolSplatSmall) { + MLIRContext context; + Builder builder(&context); + + // Check that splats that don't fill entire byte are handled properly. + auto tensorType = RankedTensorType::get({4}, builder.getI1Type()); + std::vector data{0b00001111}; + auto trueSplatFromRaw = + DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data); + EXPECT_TRUE(trueSplatFromRaw.isSplat()); + DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true); + EXPECT_EQ(trueSplat, trueSplatFromRaw); +} + TEST(DenseSplatTest, LargeBoolSplat) { constexpr int64_t boolCount = 56; -- 2.7.4