From 64ce74a6c8f23481f8062830e3ca7f38e171d74c Mon Sep 17 00:00:00 2001 From: Chia-hung Duan Date: Thu, 4 Nov 2021 18:11:50 +0000 Subject: [PATCH] [mlir] Handle StringAttr in SparseElementsAttr::getZeroAttr. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D111203 --- mlir/lib/IR/BuiltinAttributes.cpp | 5 +++- mlir/unittests/IR/AttributeTest.cpp | 46 +++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 72891d9..2acc386 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -1336,8 +1336,11 @@ Attribute SparseElementsAttr::getZeroAttr() const { if (eltType.isa()) return FloatAttr::get(eltType, 0); + // Handle string type. + if (getValues().isa()) + return StringAttr::get("", eltType); + // Otherwise, this is an integer. - // TODO: Handle StringAttr here. return IntegerAttr::get(eltType, 0); } diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index 891abd1..aaff61e 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -205,4 +205,50 @@ TEST(DenseScalarTest, ExtractZeroRankElement) { EXPECT_TRUE(attr.getValue({0}) == value); } +TEST(SparseElementsAttrTest, GetZero) { + MLIRContext context; + context.allowUnregisteredDialects(); + + IntegerType intTy = IntegerType::get(&context, 32); + FloatType floatTy = FloatType::getF32(&context); + Type stringTy = OpaqueType::get(Identifier::get("test", &context), "string"); + + ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); + ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); + ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); + + auto indicesType = + RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); + auto indices = + DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); + + RankedTensorType intValueTy = RankedTensorType::get({1}, intTy); + auto intValue = DenseIntElementsAttr::get(intValueTy, {1}); + + RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); + auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); + + RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); + auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); + + auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); + auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); + auto sparseString = + SparseElementsAttr::get(tensorString, indices, stringValue); + + // Only index (0, 0) contains an element, others are supposed to return + // the zero/empty value. + auto zeroIntValue = sparseInt.getValue({1, 1}); + EXPECT_EQ(zeroIntValue.cast().getInt(), 0); + EXPECT_TRUE(zeroIntValue.getType() == intTy); + + auto zeroFloatValue = sparseFloat.getValue({1, 1}); + EXPECT_EQ(zeroFloatValue.cast().getValueAsDouble(), 0.0f); + EXPECT_TRUE(zeroFloatValue.getType() == floatTy); + + auto zeroStringValue = sparseString.getValue({1, 1}); + EXPECT_TRUE(zeroStringValue.cast().getValue().empty()); + EXPECT_TRUE(zeroStringValue.getType() == stringTy); +} + } // end namespace -- 2.7.4