splatData = splatData.drop_back();
}
- // Check that the data buffer corresponds to a splat.
- return llvm::is_splat(splatData) ? generateSplatKey()
- : KeyTy(ty, data, llvm::hash_value(data));
+ // Check that the data buffer corresponds to a splat of the proper mask.
+ char mask = splatValue ? ~0 : 0;
+ return llvm::all_of(splatData, [mask](char c) { return c == mask; })
+ ? generateSplatKey()
+ : KeyTy(ty, data, llvm::hash_value(data));
}
/// Hash the key for the storage.
EXPECT_EQ(detectedSplat, falseSplat);
}
+TEST(DenseSplatTest, BoolNonSplat) {
+ MLIRContext context;
+ IntegerType boolTy = IntegerType::get(1, &context);
+ VectorType shape = VectorType::get({6}, boolTy);
+
+ // Check that we properly handle non-splat values.
+ DenseElementsAttr nonSplat =
+ DenseElementsAttr::get(shape, {false, false, true, false, false, true});
+ EXPECT_FALSE(nonSplat.isSplat());
+}
+
TEST(DenseSplatTest, OddIntSplat) {
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
MLIRContext context;