[mlir] Handle StringAttr in SparseElementsAttr::getZeroAttr.
authorChia-hung Duan <chiahungduan@google.com>
Thu, 4 Nov 2021 18:11:50 +0000 (18:11 +0000)
committerChia-hung Duan <chiahungduan@google.com>
Thu, 4 Nov 2021 18:18:20 +0000 (18:18 +0000)
Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D111203

mlir/lib/IR/BuiltinAttributes.cpp
mlir/unittests/IR/AttributeTest.cpp

index 72891d9..2acc386 100644 (file)
@@ -1336,8 +1336,11 @@ Attribute SparseElementsAttr::getZeroAttr() const {
   if (eltType.isa<FloatType>())
     return FloatAttr::get(eltType, 0);
 
+  // Handle string type.
+  if (getValues().isa<DenseStringElementsAttr>())
+    return StringAttr::get("", eltType);
+
   // Otherwise, this is an integer.
-  // TODO: Handle StringAttr here.
   return IntegerAttr::get(eltType, 0);
 }
 
index 891abd1..aaff61e 100644 (file)
@@ -205,4 +205,50 @@ TEST(DenseScalarTest, ExtractZeroRankElement) {
   EXPECT_TRUE(attr.getValue({0}) == value);
 }
 
+TEST(SparseElementsAttrTest, GetZero) {
+  MLIRContext context;
+  context.allowUnregisteredDialects();
+
+  IntegerType intTy = IntegerType::get(&context, 32);
+  FloatType floatTy = FloatType::getF32(&context);
+  Type stringTy = OpaqueType::get(Identifier::get("test", &context), "string");
+
+  ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
+  ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
+  ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
+
+  auto indicesType =
+      RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
+  auto indices =
+      DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
+
+  RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
+  auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
+
+  RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
+  auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
+
+  RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
+  auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
+
+  auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
+  auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
+  auto sparseString =
+      SparseElementsAttr::get(tensorString, indices, stringValue);
+
+  // Only index (0, 0) contains an element, others are supposed to return
+  // the zero/empty value.
+  auto zeroIntValue = sparseInt.getValue({1, 1});
+  EXPECT_EQ(zeroIntValue.cast<IntegerAttr>().getInt(), 0);
+  EXPECT_TRUE(zeroIntValue.getType() == intTy);
+
+  auto zeroFloatValue = sparseFloat.getValue({1, 1});
+  EXPECT_EQ(zeroFloatValue.cast<FloatAttr>().getValueAsDouble(), 0.0f);
+  EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
+
+  auto zeroStringValue = sparseString.getValue({1, 1});
+  EXPECT_TRUE(zeroStringValue.cast<StringAttr>().getValue().empty());
+  EXPECT_TRUE(zeroStringValue.getType() == stringTy);
+}
+
 } // end namespace