[mlir] Fix for MSVC bool splat issue encountered.
authorKevin Gleason <gleasonk@google.com>
Tue, 25 Jul 2023 03:39:40 +0000 (20:39 -0700)
committerJacques Pienaar <jpienaar@google.com>
Tue, 25 Jul 2023 03:45:43 +0000 (20:45 -0700)
When building MLIR using bazel on windows with MSVC2019, bool splats
were being created incorrectly:

```
dense<[true,true,true,true]> : tensor<4xi1>
-(parse with mlir-opt)-> dense<[true, false, false, false]> : tensor<4xi1>
```

Appears that a Windows bazel build produces a corrupt DenseIntOrFPElementsAttr.
Unable to repro using MSVC and cmake.

Issue first discovered here:
https://github.com/google/jax/issues/16394

Added test point for reproduction:

```
$ bazel test @llvm-project//mlir/unittests:ir_tests --test_arg=--gtest_filter=DenseSplatTest.BoolSplatSmall
```

Differential Revision: https://reviews.llvm.org/D155745

mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/BuiltinAttributes.cpp
mlir/unittests/IR/AttributeTest.cpp

index b76da1b..dcd24af 100644 (file)
@@ -193,8 +193,11 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
   ArrayRef<char> data;
 
   /// The values used to denote a boolean splat value.
-  static constexpr char kSplatTrue = ~0;
-  static constexpr char kSplatFalse = 0;
+  // This is not using constexpr declaration due to compilation failure
+  // encountered with MSVC where it would inline these values, which makes it
+  // unsafe to refer by reference in KeyTy.
+  static const char kSplatTrue;
+  static const char kSplatFalse;
 };
 
 /// An attribute representing a reference to a dense vector or tensor object
index c670359..7328fcc 100644 (file)
@@ -438,6 +438,9 @@ LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 // DenseElementsAttr Utilities
 //===----------------------------------------------------------------------===//
 
+const char DenseIntOrFPElementsAttrStorage::kSplatTrue = ~0;
+const char DenseIntOrFPElementsAttrStorage::kSplatFalse = 0;
+
 /// Get the bitwidth of a dense element type within the buffer.
 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
 static size_t getDenseElementStorageWidth(size_t origWidth) {
index f01cc02..a707dfa 100644 (file)
@@ -74,6 +74,20 @@ TEST(DenseSplatTest, BoolSplatRawRoundtrip) {
   EXPECT_EQ(trueSplat, trueSplatFromRaw);
 }
 
+TEST(DenseSplatTest, BoolSplatSmall) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Check that splats that don't fill entire byte are handled properly.
+  auto tensorType = RankedTensorType::get({4}, builder.getI1Type());
+  std::vector<char> data{0b00001111};
+  auto trueSplatFromRaw =
+      DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data);
+  EXPECT_TRUE(trueSplatFromRaw.isSplat());
+  DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true);
+  EXPECT_EQ(trueSplat, trueSplatFromRaw);
+}
+
 TEST(DenseSplatTest, LargeBoolSplat) {
   constexpr int64_t boolCount = 56;