[mlir] Add support for sparse DenseStringElements.
authorRob Suderman <rob.suderman@gmail.com>
Sat, 25 Apr 2020 08:17:09 +0000 (01:17 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Sat, 25 Apr 2020 08:21:40 +0000 (01:21 -0700)
Summary: Added support for sparse strings elements. This is a follow up from the original DenseStringElements.

Differential Revision: https://reviews.llvm.org/D78844

mlir/include/mlir/IR/Attributes.h
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/parser.mlir

index 656c28b..b4f1d02 100644 (file)
@@ -1274,6 +1274,14 @@ private:
   getZeroValue() const {
     return getZeroAPFloat();
   }
+
+  /// Get a zero for a StringRef.
+  template <typename T>
+  typename std::enable_if<std::is_same<StringRef, T>::value, T>::type
+  getZeroValue() const {
+    return StringRef();
+  }
+
   /// Get a zero for an C++ integer or float type.
   template <typename T>
   typename std::enable_if<std::numeric_limits<T>::is_integer ||
index bdaf15c..f17d8fd 100644 (file)
@@ -976,6 +976,11 @@ protected:
   /// Print a dense string elements attribute.
   void printDenseStringElementsAttr(DenseStringElementsAttr attr);
 
+  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
+  /// used instead of individual elements when the elements attr is large.
+  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
+                                     bool allowHex);
+
   void printDialectAttribute(Attribute attr);
   void printDialectType(Type type);
 
@@ -1396,13 +1401,13 @@ void ModulePrinter::printAttribute(Attribute attr,
     break;
   }
   case StandardAttributes::DenseIntOrFPElements: {
-    auto eltsAttr = attr.cast<DenseElementsAttr>();
+    auto eltsAttr = attr.cast<DenseIntOrFPElementsAttr>();
     if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
       printElidedElementsAttr(os);
       break;
     }
     os << "dense<";
-    printDenseElementsAttr(eltsAttr, /*allowHex=*/true);
+    printDenseIntOrFPElementsAttr(eltsAttr, /*allowHex=*/true);
     os << '>';
     break;
   }
@@ -1425,7 +1430,8 @@ void ModulePrinter::printAttribute(Attribute attr,
       break;
     }
     os << "sparse<";
-    printDenseElementsAttr(elementsAttr.getIndices(), /*allowHex=*/false);
+    printDenseIntOrFPElementsAttr(elementsAttr.getIndices(),
+                                  /*allowHex=*/false);
     os << ", ";
     printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
     os << '>';
@@ -1477,6 +1483,17 @@ static void printDenseStringElement(DenseStringElementsAttr attr,
 
 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
                                            bool allowHex) {
+  if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
+    printDenseStringElementsAttr(stringAttr);
+    return;
+  }
+
+  printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
+                                allowHex);
+}
+
+void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
+                                                  bool allowHex) {
   auto type = attr.getType();
   auto shape = type.getShape();
   auto rank = type.getRank();
index 7a02d77..2170927 100644 (file)
@@ -764,6 +764,11 @@ func @sparsetensorattr() -> () {
   "foof320"(){bar = sparse<[], []> : tensor<0xf32>} : () -> ()
 // CHECK: "foof321"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<f32>} : () -> ()
   "foof321"(){bar = sparse<[], []> : tensor<f32>} : () -> ()
+
+// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
+  "foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
+// CHECK: "foostr"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}"a", "b", "c"]> : tensor<2x2x2x!unknown<"">>} : () -> ()
+  "foostr"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], ["a", "b", "c"]> : tensor<2x2x2x!unknown<"">>} : () -> ()
   return
 }