From fb674e3329d8fa694d8e5ce179081890fb918556 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Sat, 25 Apr 2020 01:17:09 -0700 Subject: [PATCH] [mlir] Add support for sparse DenseStringElements. Summary: Added support for sparse strings elements. This is a follow up from the original DenseStringElements. Differential Revision: https://reviews.llvm.org/D78844 --- mlir/include/mlir/IR/Attributes.h | 8 ++++++++ mlir/lib/IR/AsmPrinter.cpp | 23 ++++++++++++++++++++--- mlir/test/IR/parser.mlir | 5 +++++ 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 656c28b..b4f1d02 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -1274,6 +1274,14 @@ private: getZeroValue() const { return getZeroAPFloat(); } + + /// Get a zero for a StringRef. + template + typename std::enable_if::value, T>::type + getZeroValue() const { + return StringRef(); + } + /// Get a zero for an C++ integer or float type. template typename std::enable_if::is_integer || diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index bdaf15c..f17d8fd 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -976,6 +976,11 @@ protected: /// Print a dense string elements attribute. void printDenseStringElementsAttr(DenseStringElementsAttr attr); + /// Print a dense elements attribute. If 'allowHex' is true, a hex string is + /// used instead of individual elements when the elements attr is large. + void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, + bool allowHex); + void printDialectAttribute(Attribute attr); void printDialectType(Type type); @@ -1396,13 +1401,13 @@ void ModulePrinter::printAttribute(Attribute attr, break; } case StandardAttributes::DenseIntOrFPElements: { - auto eltsAttr = attr.cast(); + auto eltsAttr = attr.cast(); if (printerFlags.shouldElideElementsAttr(eltsAttr)) { printElidedElementsAttr(os); break; } os << "dense<"; - printDenseElementsAttr(eltsAttr, /*allowHex=*/true); + printDenseIntOrFPElementsAttr(eltsAttr, /*allowHex=*/true); os << '>'; break; } @@ -1425,7 +1430,8 @@ void ModulePrinter::printAttribute(Attribute attr, break; } os << "sparse<"; - printDenseElementsAttr(elementsAttr.getIndices(), /*allowHex=*/false); + printDenseIntOrFPElementsAttr(elementsAttr.getIndices(), + /*allowHex=*/false); os << ", "; printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true); os << '>'; @@ -1477,6 +1483,17 @@ static void printDenseStringElement(DenseStringElementsAttr attr, void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr, bool allowHex) { + if (auto stringAttr = attr.dyn_cast()) { + printDenseStringElementsAttr(stringAttr); + return; + } + + printDenseIntOrFPElementsAttr(attr.cast(), + allowHex); +} + +void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, + bool allowHex) { auto type = attr.getType(); auto shape = type.getShape(); auto rank = type.getRank(); diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 7a02d77..2170927d 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -764,6 +764,11 @@ func @sparsetensorattr() -> () { "foof320"(){bar = sparse<[], []> : tensor<0xf32>} : () -> () // CHECK: "foof321"() {bar = sparse<{{\[}}], {{\[}}]> : tensor} : () -> () "foof321"(){bar = sparse<[], []> : tensor} : () -> () + +// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> () + "foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> () +// CHECK: "foostr"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}"a", "b", "c"]> : tensor<2x2x2x!unknown<"">>} : () -> () + "foostr"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], ["a", "b", "c"]> : tensor<2x2x2x!unknown<"">>} : () -> () return } -- 2.7.4