// Check for the splat case.
if (attr.isSplat()) {
- processElt(*attr.begin(), /*index=*/0);
+ if (bitWidth == 1) {
+ // Handle the special encoding of splat of bool.
+ data[0] = mapping(*attr.begin()).isZero() ? 0 : -1;
+ } else {
+ processElt(*attr.begin(), /*index=*/0);
+ }
return newArrayType;
}
auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
}
+
+TEST(DenseSplatMapValuesTest, I32ToTrue) {
+ MLIRContext context;
+ const int elementValue = 12;
+ IntegerType boolTy = IntegerType::get(&context, 1);
+ IntegerType intTy = IntegerType::get(&context, 32);
+ RankedTensorType shape = RankedTensorType::get({4}, intTy);
+
+ auto attr =
+ DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
+ .mapValues(boolTy, [](const APInt &x) {
+ return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
+ });
+ EXPECT_EQ(attr.getNumElements(), 4);
+ EXPECT_TRUE(attr.isSplat());
+ EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
+}
+
+TEST(DenseSplatMapValuesTest, I32ToFalse) {
+ MLIRContext context;
+ const int elementValue = 0;
+ IntegerType boolTy = IntegerType::get(&context, 1);
+ IntegerType intTy = IntegerType::get(&context, 32);
+ RankedTensorType shape = RankedTensorType::get({4}, intTy);
+
+ auto attr =
+ DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
+ .mapValues(boolTy, [](const APInt &x) {
+ return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
+ });
+ EXPECT_EQ(attr.getNumElements(), 4);
+ EXPECT_TRUE(attr.isSplat());
+ EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
+}
} // namespace
//===----------------------------------------------------------------------===//