EXPECT_TRUE(attr.getValue({0}) == value);
}
+TEST(SparseElementsAttrTest, GetZero) {
+ MLIRContext context;
+ context.allowUnregisteredDialects();
+
+ IntegerType intTy = IntegerType::get(&context, 32);
+ FloatType floatTy = FloatType::getF32(&context);
+ Type stringTy = OpaqueType::get(Identifier::get("test", &context), "string");
+
+ ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
+ ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
+ ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
+
+ auto indicesType =
+ RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
+ auto indices =
+ DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
+
+ RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
+ auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
+
+ RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
+ auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
+
+ RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
+ auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
+
+ auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
+ auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
+ auto sparseString =
+ SparseElementsAttr::get(tensorString, indices, stringValue);
+
+ // Only index (0, 0) contains an element, others are supposed to return
+ // the zero/empty value.
+ auto zeroIntValue = sparseInt.getValue({1, 1});
+ EXPECT_EQ(zeroIntValue.cast<IntegerAttr>().getInt(), 0);
+ EXPECT_TRUE(zeroIntValue.getType() == intTy);
+
+ auto zeroFloatValue = sparseFloat.getValue({1, 1});
+ EXPECT_EQ(zeroFloatValue.cast<FloatAttr>().getValueAsDouble(), 0.0f);
+ EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
+
+ auto zeroStringValue = sparseString.getValue({1, 1});
+ EXPECT_TRUE(zeroStringValue.cast<StringAttr>().getValue().empty());
+ EXPECT_TRUE(zeroStringValue.getType() == stringTy);
+}
+
} // end namespace