[mlir] Fix DenseElementsAttr::mapValues(i1, splat).
authorChenguang Wang <w3cing@gmail.com>
Tue, 6 Sep 2022 19:28:23 +0000 (21:28 +0200)
committerBenjamin Kramer <benny.kra@googlemail.com>
Tue, 6 Sep 2022 19:28:25 +0000 (21:28 +0200)
Splat of bool is encoded as a byte with all-ones in it [1]. Without this
change, this piece of code:

    auto xs = builder.getI32TensorAttr({42, 42, 42, 42});
    auto xs2 = xs.mapValues(builder.getI1Type(), [](const llvm::APInt &x) {
      return x.isZero() ? llvm::APInt::getZero(1) : llvm::APInt::getAllOnes(1);
    });
    xs2.dump();

Prints:

    dense<[true, false, false, false]> : tensor<4xi1>

Because only the first bit is set. This applies to both
DenseIntElementsAttr::mapValues() and DenseFPElementsAttr::mapValues().

[1]: https://github.com/llvm/llvm-project/blob/e877b42e2c70813352c1963ea33e992f481d5cba/mlir/lib/IR/BuiltinAttributes.cpp#L984

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D132767

mlir/lib/IR/BuiltinAttributes.cpp
mlir/unittests/IR/AttributeTest.cpp

index 810672e..22eff2d 100644 (file)
@@ -1526,7 +1526,12 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
 
   // Check for the splat case.
   if (attr.isSplat()) {
-    processElt(*attr.begin(), /*index=*/0);
+    if (bitWidth == 1) {
+      // Handle the special encoding of splat of bool.
+      data[0] = mapping(*attr.begin()).isZero() ? 0 : -1;
+    } else {
+      processElt(*attr.begin(), /*index=*/0);
+    }
     return newArrayType;
   }
 
index e393b83..cffac41 100644 (file)
@@ -209,6 +209,40 @@ TEST(DenseScalarTest, ExtractZeroRankElement) {
   auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
   EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
 }
+
+TEST(DenseSplatMapValuesTest, I32ToTrue) {
+  MLIRContext context;
+  const int elementValue = 12;
+  IntegerType boolTy = IntegerType::get(&context, 1);
+  IntegerType intTy = IntegerType::get(&context, 32);
+  RankedTensorType shape = RankedTensorType::get({4}, intTy);
+
+  auto attr =
+      DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
+          .mapValues(boolTy, [](const APInt &x) {
+            return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
+          });
+  EXPECT_EQ(attr.getNumElements(), 4);
+  EXPECT_TRUE(attr.isSplat());
+  EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
+}
+
+TEST(DenseSplatMapValuesTest, I32ToFalse) {
+  MLIRContext context;
+  const int elementValue = 0;
+  IntegerType boolTy = IntegerType::get(&context, 1);
+  IntegerType intTy = IntegerType::get(&context, 32);
+  RankedTensorType shape = RankedTensorType::get({4}, intTy);
+
+  auto attr =
+      DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
+          .mapValues(boolTy, [](const APInt &x) {
+            return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
+          });
+  EXPECT_EQ(attr.getNumElements(), 4);
+  EXPECT_TRUE(attr.isSplat());
+  EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
+}
 } // namespace
 
 //===----------------------------------------------------------------------===//