Support 0-d tensor type attributes
authorFeng Liu <fengliuai@google.com>
Mon, 1 Apr 2019 17:01:47 +0000 (10:01 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 1 Apr 2019 17:59:59 +0000 (10:59 -0700)
    This CL fixes the parser and printer to support the 0-d tensor type attributes.

--

PiperOrigin-RevId: 241345329

mlir/lib/IR/AsmPrinter.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir

index a9d5dfe..14946f8 100644 (file)
@@ -652,6 +652,12 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
   SmallVector<Attribute, 16> elements;
   attr.getValues(elements);
 
+  // Special case for 0-d tensors;
+  if (rank == 0) {
+    printAttribute(elements[0]);
+    return;
+  }
+
   // Special case for degenerate tensors.
   if (elements.empty()) {
     for (int i = 0; i < rank; ++i)
index 6af2195..d840439 100644 (file)
@@ -763,9 +763,10 @@ public:
   TensorLiteralParser(Parser &p, Type eltTy) : p(p), eltTy(eltTy) {}
 
   ParseResult parse() {
-    if (p.getToken().isNot(Token::l_square))
-      return p.emitError("expected '[' in tensor literal list");
-    return parseList(shape);
+    if (p.getToken().is(Token::l_square)) {
+      return parseList(shape);
+    }
+    return parseElement();
   }
 
   ArrayRef<Attribute> getValues() const { return storage; }
@@ -773,9 +774,11 @@ public:
   ArrayRef<int64_t> getShape() const { return shape; }
 
 private:
-  /// Parse either a single element or a list of elements. Return the dimensions
-  /// of the parsed sub-tensor in dims.
-  ParseResult parseElementOrList(llvm::SmallVectorImpl<int64_t> &dims);
+  /// Parse a single element, returning failure if it isn't a valid element
+  /// literal. For example:
+  /// parseElement(1) -> Success, 1
+  /// parseElement([1]) -> Failure
+  ParseResult parseElement();
 
   /// Parse a list of either lists or elements, returning the dimensions of the
   /// parsed sub-tensors in dims. For example:
@@ -792,13 +795,8 @@ private:
 };
 } // namespace
 
-/// Parse either a single element or a list of elements. Return the dimensions
-/// of the parsed sub-tensor in dims.
-ParseResult
-TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int64_t> &dims) {
+ParseResult TensorLiteralParser::parseElement() {
   switch (p.getToken().getKind()) {
-  case Token::l_square:
-    return parseList(dims);
   case Token::floatliteral:
   case Token::integer:
   case Token::minus: {
@@ -842,7 +840,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int64_t> &dims) {
     break;
   }
   default:
-    return p.emitError("expected '[' or scalar constant inside tensor literal");
+    return p.emitError("expected element literal of primitive type");
   }
   return ParseSuccess;
 }
@@ -870,8 +868,12 @@ TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) {
   unsigned size = 0;
   auto parseCommaSeparatedList = [&]() {
     llvm::SmallVector<int64_t, 4> thisDims;
-    if (parseElementOrList(thisDims))
+    if (p.getToken().getKind() == Token::l_square) {
+      if (parseList(thisDims))
+        return ParseFailure;
+    } else if (parseElement()) {
       return ParseFailure;
+    }
     ++size;
     if (!first)
       return checkDims(newDims, thisDims);
@@ -1162,18 +1164,14 @@ Attribute Parser::parseAttribute(Type type) {
     if (!type)
       return nullptr;
 
-    switch (getToken().getKind()) {
-    case Token::l_square: {
-      auto attr = parseDenseElementsAttr(type);
-      if (!attr)
-        return nullptr;
-      if (parseToken(Token::greater, "expected '>'"))
-        return nullptr;
-      return attr;
-    }
-    default:
-      return (emitError("expected '[' to start dense tensor literal"), nullptr);
-    }
+    auto attr = parseDenseElementsAttr(type);
+    if (!attr)
+      return nullptr;
+
+    if (parseToken(Token::greater, "expected '>'"))
+      return nullptr;
+
+    return attr;
   }
   case Token::kw_sparse: {
     consumeToken(Token::kw_sparse);
@@ -1202,8 +1200,11 @@ Attribute Parser::parseAttribute(Type type) {
         return nullptr;
 
       /// Sanity check.
-      auto indicesType = indices.getType();
       auto valuesType = values.getType();
+      if (valuesType.getRank() != 1) {
+        return (emitError("expected 1-d tensor for values"), nullptr);
+      }
+      auto indicesType = indices.getType();
       auto sameShape = (indicesType.getRank() == 1) ||
                        (type.getRank() == indicesType.getDimSize(1));
       auto sameElementNum =
@@ -1277,6 +1278,7 @@ DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) {
   TensorLiteralParser literalParser(*this, eltTy);
   if (literalParser.parse())
     return nullptr;
+
   if (literalParser.getShape() != type.getShape()) {
     std::string str;
     llvm::raw_string_ostream s(str);
@@ -1287,6 +1289,7 @@ DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) {
     s << "])";
     return (emitError(s.str()), nullptr);
   }
+
   return builder.getDenseElementsAttr(type, literalParser.getValues())
       .cast<DenseElementsAttr>();
 }
index 69d6a0b..ca2a79b 100644 (file)
@@ -647,7 +647,7 @@ func @elementsattr_invalid() -> () {
 
 func @elementsattr_badtoken() -> () {
 ^bb0:
-  "foo"(){bar: dense<tensor<1xi32>, [tf_opaque]>} : () -> () // expected-error {{expected '[' or scalar constant inside tensor literal}}
+  "foo"(){bar: dense<tensor<1xi32>, [tf_opaque]>} : () -> () // expected-error {{expected element literal of primitive type}}
 }
 
 // -----
@@ -813,19 +813,19 @@ func @complex_loops() {
 // -----
 
 func @mi() {
-  // expected-error @+1 {{expected '[' or scalar constant inside tensor literal}}
+  // expected-error @+1 {{expected element literal of primitive type}}
   "fooi64"(){bar: sparse<vector<1xi64>,[,[,1]
 
 // -----
 
 func @invalid_tensor_literal() {
-  // expected-error @+1 {{expected '[' in tensor literal list}}
+  // expected-error @+1 {{expected 1-d tensor for values}}
   "foof16"(){bar: sparse<vector<1x1x1xf16>, [[0, 0, 0]],  -2.0]>} : () -> ()
 
 // -----
 
 func @invalid_tensor_literal() {
-  // expected-error @+1 {{expected '[' or scalar constant inside tensor literal}}
+  // expected-error @+1 {{expected element literal of primitive type}}
   "fooi16"(){bar: sparse<tensor<2x2x2xi16>, [[1, 1, 0], [0, 1, 0], [0,, [[0, 0, 0]], [-2.0]>} : () -> ()
 
 // -----
index 1aa4858..e67832e 100644 (file)
@@ -592,6 +592,11 @@ func @splattensorattr() -> () {
 
   // CHECK: "splatFloatVector"() {bar: splat<vector<2x1x4xf16>, -5.000000e+00>} : () -> ()
   "splatFloatVector"(){bar: splat<vector<2x1x4xf16>, -5.0>} : () -> ()
+
+  // CHECK: "splatIntScalar"() {bar: splat<tensor<i9>, 5>} : () -> ()
+  "splatIntScalar"() {bar: splat<tensor<i9>, 5>} : () -> ()
+  // CHECK: "splatFloatScalar"() {bar: splat<tensor<f16>, -5.000000e+00>} : () -> ()
+  "splatFloatScalar"() {bar: splat<tensor<f16>, -5.0>} : () -> ()
   return
 }
 
@@ -649,6 +654,11 @@ func @densetensorattr() -> () {
   "float32"(){bar: dense<tensor<2x1x4xf32>, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> ()
 // CHECK: "float64"() {bar: dense<tensor<2x1x4xf64>, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> ()
   "float64"(){bar: dense<tensor<2x1x4xf64>, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> ()
+
+// CHECK: "intscalar"() {bar: dense<tensor<i32>, 1>} : () -> ()
+  "intscalar"(){bar: dense<tensor<i32>, 1>} : () -> ()
+// CHECK: "floatscalar"() {bar: dense<tensor<f32>, 5.000000e+00>} : () -> ()
+  "floatscalar"(){bar: dense<tensor<f32>, 5.0>} : () -> ()
   return
 }
 
@@ -696,7 +706,9 @@ func @sparsetensorattr() -> () {
   "fooi64"(){bar: sparse<tensor<1xi64>, [[0]], [-1]>} : () -> ()
 // CHECK: "foo2"() {bar: sparse<tensor<0xi32>, {{\[}}], {{\[}}]>} : () -> ()
   "foo2"(){bar: sparse<tensor<0xi32>, [], []>} : () -> ()
-  
+// CHECK: "foo3"() {bar: sparse<tensor<i32>, {{\[}}], {{\[}}]>} : () -> ()
+  "foo3"(){bar: sparse<tensor<i32>, [], []>} : () -> ()
+
 // CHECK: "foof16"() {bar: sparse<tensor<1x1x1xf16>, {{\[\[}}0, 0, 0]], {{\[}}-2.000000e+00]>} : () -> ()
   "foof16"(){bar: sparse<tensor<1x1x1xf16>, [[0, 0, 0]], [-2.0]>} : () -> ()
 // CHECK: "foobf16"() {bar: sparse<tensor<2x2x2xbf16>, {{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2.000000e+00, -1.000000e+00, 5.000000e+00]>} : () -> ()
@@ -707,6 +719,8 @@ func @sparsetensorattr() -> () {
   "foof64"(){bar: sparse<tensor<1xf64>, [[0]], [-1.0]>} : () -> ()
 // CHECK: "foof320"() {bar: sparse<tensor<0xf32>, {{\[}}], {{\[}}]>} : () -> ()
   "foof320"(){bar: sparse<tensor<0xf32>, [], []>} : () -> ()
+// CHECK: "foof321"() {bar: sparse<tensor<f32>, {{\[}}], {{\[}}]>} : () -> ()
+  "foof321"(){bar: sparse<tensor<f32>, [], []>} : () -> ()
   return
 }