Fix the detection of boolean splat values in DenseElementsAttr for arrays with <...
authorRiver Riddle <riverriddle@google.com>
Tue, 18 Jun 2019 02:46:31 +0000 (19:46 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 20 Jun 2019 06:05:19 +0000 (23:05 -0700)
PiperOrigin-RevId: 253711372

mlir/lib/IR/AttributeDetail.h
mlir/unittests/IR/AttributeTest.cpp

index 56d1762..f164f50 100644 (file)
@@ -478,9 +478,11 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
       splatData = splatData.drop_back();
     }
 
-    // Check that the data buffer corresponds to a splat.
-    return llvm::is_splat(splatData) ? generateSplatKey()
-                                     : KeyTy(ty, data, llvm::hash_value(data));
+    // Check that the data buffer corresponds to a splat of the proper mask.
+    char mask = splatValue ? ~0 : 0;
+    return llvm::all_of(splatData, [mask](char c) { return c == mask; })
+               ? generateSplatKey()
+               : KeyTy(ty, data, llvm::hash_value(data));
   }
 
   /// Hash the key for the storage.
index 275c142..3db87b2 100644 (file)
@@ -84,6 +84,17 @@ TEST(DenseSplatTest, LargeBoolSplat) {
   EXPECT_EQ(detectedSplat, falseSplat);
 }
 
+TEST(DenseSplatTest, BoolNonSplat) {
+  MLIRContext context;
+  IntegerType boolTy = IntegerType::get(1, &context);
+  VectorType shape = VectorType::get({6}, boolTy);
+
+  // Check that we properly handle non-splat values.
+  DenseElementsAttr nonSplat =
+      DenseElementsAttr::get(shape, {false, false, true, false, false, true});
+  EXPECT_FALSE(nonSplat.isSplat());
+}
+
 TEST(DenseSplatTest, OddIntSplat) {
   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
   MLIRContext context;