From 09c053bfd0a31143455aeab5285bba47328b0b5f Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sun, 14 Apr 2019 13:41:09 -0700 Subject: [PATCH] Expand the pretty dialect type system to support arbitrary punctuation and other characters within the <>'s now that we can. This will allow quantized types to use the pretty syntax (among others) after a few changes. -- PiperOrigin-RevId: 243521268 --- mlir/examples/Linalg/Linalg1/Example.cpp | 12 ++--- mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp | 2 +- mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp | 2 +- mlir/examples/Linalg/Linalg2/Example.cpp | 14 ++--- mlir/examples/Linalg/Linalg3/Example.cpp | 40 +++++++------- mlir/examples/Linalg/Linalg4/Example.cpp | 22 ++++---- mlir/g3doc/LangRef.md | 5 +- mlir/lib/IR/AsmPrinter.cpp | 56 +++++++++++-------- mlir/lib/Parser/Parser.cpp | 80 +++++++++++++++++++--------- mlir/test/IR/invalid.mlir | 4 +- mlir/test/IR/parser.mlir | 11 ++++ 11 files changed, 152 insertions(+), 96 deletions(-) diff --git a/mlir/examples/Linalg/Linalg1/Example.cpp b/mlir/examples/Linalg/Linalg1/Example.cpp index 9189d23..3ab54d0 100644 --- a/mlir/examples/Linalg/Linalg1/Example.cpp +++ b/mlir/examples/Linalg/Linalg1/Example.cpp @@ -60,8 +60,8 @@ TEST_FUNC(view_op) { // CHECK-LABEL: func @view_op // CHECK: %[[R:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[] : !linalg.view - // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[%[[R]]] : !linalg<"view"> - // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[%[[R]], %[[R]]] : !linalg<"view"> + // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[%[[R]]] : !linalg.view + // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[%[[R]], %[[R]]] : !linalg.view // clang-format on cleanupAndPrintFunction(f); @@ -95,12 +95,12 @@ TEST_FUNC(slice_op) { // CHECK-NEXT: %[[N:.*]] = dim %0, 1 : memref // CHECK-NEXT: %[[R1:.*]] = linalg.range {{.*}}:%[[M]]:{{.*}} : !linalg.range // CHECK-NEXT: %[[R2:.*]] = linalg.range {{.*}}:%[[N]]:{{.*}} : !linalg.range - // CHECK-NEXT: %[[V:.*]] = linalg.view %0[%[[R1]], %[[R2]]] : !linalg<"view"> + // CHECK-NEXT: %[[V:.*]] = linalg.view %0[%[[R1]], %[[R2]]] : !linalg.view // CHECK-NEXT: for %i0 = 0 to (d0) -> (d0)(%[[M]]) { // CHECK-NEXT: for %i1 = 0 to (d0) -> (d0)(%[[N]]) { - // CHECK-NEXT: %[[S1:.*]] = linalg.slice %[[V]][*, %i0] : !linalg<"view"> - // CHECK-NEXT: "some_consumer"(%[[S1]]) : (!linalg<"view">) -> () - // CHECK-NEXT: %[[S2:.*]] = linalg.slice %[[V]][%i1, *] : !linalg<"view"> + // CHECK-NEXT: %[[S1:.*]] = linalg.slice %[[V]][*, %i0] : !linalg.view + // CHECK-NEXT: "some_consumer"(%[[S1]]) : (!linalg.view) -> () + // CHECK-NEXT: %[[S2:.*]] = linalg.slice %[[V]][%i1, *] : !linalg.view // CHECK-NEXT: %[[S3:.*]] = linalg.slice %[[S2]][%i0] : !linalg.view // CHECK-NEXT: "some_consumer"(%[[S3]]) : (!linalg.view) -> () // clang-format on diff --git a/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp b/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp index 4f30fb1..818a770 100644 --- a/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp @@ -80,7 +80,7 @@ bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) { // A SliceOp prints as: // // ```{.mlir} -// linalg.slice %0[*, %i0] : !linalg<"view"> +// linalg.slice %0[*, %i0] : !linalg.view // ``` // // Where %0 is an ssa-value holding a `view`, %i0 is an ssa-value diff --git a/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp b/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp index 6564391..42be75a 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp @@ -97,7 +97,7 @@ bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) { // A ViewOp prints as: // // ```{.mlir} -// linalg.view %0[%1, %2] : !linalg<"view"> +// linalg.view %0[%1, %2] : !linalg.view // ``` // // Where %0 is an ssa-value holding a MemRef, %1 and %2 are ssa-value each diff --git a/mlir/examples/Linalg/Linalg2/Example.cpp b/mlir/examples/Linalg/Linalg2/Example.cpp index f7ee853..901fbe7 100644 --- a/mlir/examples/Linalg/Linalg2/Example.cpp +++ b/mlir/examples/Linalg/Linalg2/Example.cpp @@ -58,12 +58,12 @@ TEST_FUNC(linalg_ops) { dot(sA, sB, ssC); ret(); // CHECK-LABEL: func @linalg_ops(%arg0: index, %arg1: index, %arg2: index) { - // CHECK: {{.*}} = linalg.slice {{.*}}[*, {{.*}}] : !linalg<"view"> - // CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[*, {{.*}}] : !linalg<"view"> - // CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}, *] : !linalg<"view"> + // CHECK: {{.*}} = linalg.slice {{.*}}[*, {{.*}}] : !linalg.view + // CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[*, {{.*}}] : !linalg.view + // CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}, *] : !linalg.view // CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}] : !linalg.view - // CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : !linalg<"view"> - // CHECK-NEXT: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : !linalg<"view"> + // CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : !linalg.view + // CHECK-NEXT: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : !linalg.view // CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg.view // clang-format on @@ -97,8 +97,8 @@ TEST_FUNC(linalg_ops_folded_slices) { ret(); // CHECK-LABEL: func @linalg_ops_folded_slices(%arg0: index, %arg1: index, %arg2: index) { // CHECK-NOT: linalg.slice - // CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : !linalg<"view"> - // CHECK-NEXT: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : !linalg<"view"> + // CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : !linalg.view + // CHECK-NEXT: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : !linalg.view // CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg.view // clang-format on diff --git a/mlir/examples/Linalg/Linalg3/Example.cpp b/mlir/examples/Linalg/Linalg3/Example.cpp index 191f044..69a430b 100644 --- a/mlir/examples/Linalg/Linalg3/Example.cpp +++ b/mlir/examples/Linalg/Linalg3/Example.cpp @@ -69,11 +69,11 @@ TEST_FUNC(matmul_as_matvec) { // clang-format off // CHECK-LABEL: func @matmul_as_matvec(%arg0: memref, %arg1: memref, %arg2: memref) { // CHECK: %[[N:.*]] = dim %arg2, 1 : memref - // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg.view // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) { - // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view"> - // CHECK: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view"> - // CHECK: linalg.matvec(%[[vA]], %[[vB]], %[[vC]]) : !linalg<"view"> + // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg.view + // CHECK: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg.view + // CHECK: linalg.matvec(%[[vA]], %[[vB]], %[[vC]]) : !linalg.view // clang-format on cleanupAndPrintFunction(f); } @@ -90,9 +90,9 @@ TEST_FUNC(matmul_as_dot) { // CHECK: %[[M:.*]] = dim %arg0, 0 : memref // CHECK: %[[N:.*]] = dim %arg2, 1 : memref // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) { - // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view"> + // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg.view // CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) { - // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg.view // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg.view // CHECK-NEXT: linalg.dot(%[[vA]], %[[vB]], %[[vC]]) : !linalg.view // clang-format on @@ -113,20 +113,20 @@ TEST_FUNC(matmul_as_loops) { // CHECK: %[[rM:.*]] = linalg.range %c0:%[[M]]:%c1 : !linalg.range // CHECK: %[[rN:.*]] = linalg.range %c0:%[[N]]:%c1 : !linalg.range // CHECK: %[[rK:.*]] = linalg.range %c0:%[[K]]:%c1 : !linalg.range - // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[rM]], %[[rK]]] : !linalg<"view"> - // CHECK: %[[vB:.*]] = linalg.view %arg1[%[[rK]], %[[rN]]] : !linalg<"view"> - // CHECK: %[[vC:.*]] = linalg.view %arg2[%[[rM]], %[[rN]]] : !linalg<"view"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[rM]], %[[rK]]] : !linalg.view + // CHECK: %[[vB:.*]] = linalg.view %arg1[%[[rK]], %[[rN]]] : !linalg.view + // CHECK: %[[vC:.*]] = linalg.view %arg2[%[[rM]], %[[rN]]] : !linalg.view // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) { // CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) { // CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) { // CHECK: %{{.*}} = cmpi "eq", %{{.*}} : index - // CHECK: %{{.*}} = linalg.load %[[vC]][%i0, %i1] : !linalg<"view"> + // CHECK: %{{.*}} = linalg.load %[[vC]][%i0, %i1] : !linalg.view // CHECK: %{{.*}} = select {{.*}} : f32 - // CHECK: %{{.*}} = linalg.load %[[vB]][%i2, %i1] : !linalg<"view"> - // CHECK: %{{.*}} = linalg.load %[[vA]][%i0, %i2] : !linalg<"view"> + // CHECK: %{{.*}} = linalg.load %[[vB]][%i2, %i1] : !linalg.view + // CHECK: %{{.*}} = linalg.load %[[vA]][%i0, %i2] : !linalg.view // CHECK: %{{.*}} = mulf {{.*}} : f32 // CHECK: %{{.*}} = addf {{.*}} : f32 - // CHECK: linalg.store {{.*}}[%i0, %i1] : !linalg<"view"> + // CHECK: linalg.store {{.*}}[%i0, %i1] : !linalg.view // clang-format on cleanupAndPrintFunction(f); } @@ -144,20 +144,20 @@ TEST_FUNC(matmul_as_matvec_as_loops) { // CHECK: %[[M:.*]] = dim %arg0, 0 : memref // CHECK: %[[N:.*]] = dim %arg2, 1 : memref // CHECK: %[[K:.*]] = dim %arg0, 1 : memref - // CHECK: %[[vA:.*]] = linalg.view %arg0[{{.*}}, {{.*}}] : !linalg<"view"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[{{.*}}, {{.*}}] : !linalg.view // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) { - // CHECK: %[[vB:.*]] = linalg.view %arg1[{{.*}}, {{.*}}] : !linalg<"view"> - // CHECK: %[[vC:.*]] = linalg.view %arg2[{{.*}}, {{.*}}] : !linalg<"view"> + // CHECK: %[[vB:.*]] = linalg.view %arg1[{{.*}}, {{.*}}] : !linalg.view + // CHECK: %[[vC:.*]] = linalg.view %arg2[{{.*}}, {{.*}}] : !linalg.view // CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) { // CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) { // CHECK: %{{.*}} = cmpi "eq", %i2, %{{.*}} : index - // CHECK: %[[C:.*]] = linalg.load %[[vC]][%i1] : !linalg<"view"> + // CHECK: %[[C:.*]] = linalg.load %[[vC]][%i1] : !linalg.view // CHECK: %[[C2:.*]] = select %{{.*}}, %{{.*}}, %[[C]] : f32 - // CHECK: %[[B:.*]] = linalg.load %[[vB]][%i2] : !linalg<"view"> - // CHECK: %[[A:.*]] = linalg.load %[[vA]][%i1, %i2] : !linalg<"view"> + // CHECK: %[[B:.*]] = linalg.load %[[vB]][%i2] : !linalg.view + // CHECK: %[[A:.*]] = linalg.load %[[vA]][%i1, %i2] : !linalg.view // CHECK: %{{.*}} = mulf %[[A]], %[[B]] : f32 // CHECK: %{{.*}} = addf %[[C2]], %{{.*}} : f32 - // CHECK: linalg.store %{{.*}}, %{{.*}}[%i1] : !linalg<"view"> + // CHECK: linalg.store %{{.*}}, %{{.*}}[%i1] : !linalg.view // clang-format on cleanupAndPrintFunction(f); } diff --git a/mlir/examples/Linalg/Linalg4/Example.cpp b/mlir/examples/Linalg/Linalg4/Example.cpp index 8eedbbd..00da459 100644 --- a/mlir/examples/Linalg/Linalg4/Example.cpp +++ b/mlir/examples/Linalg/Linalg4/Example.cpp @@ -114,13 +114,13 @@ TEST_FUNC(matmul_tiled_views) { // CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0) // CHECK-NEXT: %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg.range // CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range - // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg<"view"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg.view // CHECK: %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1) // CHECK-NEXT: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1) // CHECK-NEXT: %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg.range - // CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : !linalg<"view"> - // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : !linalg<"view"> - // CHECK-NEXT: linalg.matmul(%[[vA]], %[[vB]], %[[vC]]) : !linalg<"view"> + // CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : !linalg.view + // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : !linalg.view + // CHECK-NEXT: linalg.matmul(%[[vA]], %[[vB]], %[[vC]]) : !linalg.view // clang-format on cleanupAndPrintFunction(f); } @@ -150,23 +150,23 @@ TEST_FUNC(matmul_tiled_views_as_loops) { // CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0) // CHECK-NEXT: %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg.range // CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range - // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg<"view"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg.view // CHECK: %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1) // CHECK-NEXT: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1) // CHECK-NEXT: %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg.range - // CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : !linalg<"view"> - // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : !linalg<"view"> + // CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : !linalg.view + // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : !linalg.view // CHECK-NEXT: affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0)(%[[i0max]]) { // CHECK-NEXT: affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0)(%[[i1max]]) { // CHECK-NEXT: affine.for %i4 = 0 to (d0) -> (d0)(%[[K]]) { // CHECK-NEXT: %{{.*}} = cmpi "eq", %i4, %c0 : index - // CHECK-NEXT: %{{.*}} = linalg.load %[[vC]][%i2, %i3] : !linalg<"view"> + // CHECK-NEXT: %{{.*}} = linalg.load %[[vC]][%i2, %i3] : !linalg.view // CHECK-NEXT: %{{.*}} = select %{{.*}}, %cst, %{{.*}} : f32 - // CHECK-NEXT: %{{.*}} = linalg.load %[[vB]][%i4, %i3] : !linalg<"view"> - // CHECK-NEXT: %{{.*}} = linalg.load %[[vA]][%i2, %i4] : !linalg<"view"> + // CHECK-NEXT: %{{.*}} = linalg.load %[[vB]][%i4, %i3] : !linalg.view + // CHECK-NEXT: %{{.*}} = linalg.load %[[vA]][%i2, %i4] : !linalg.view // CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32 // CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 - // CHECK-NEXT: linalg.store %{{.*}}, %[[vC]][%i2, %i3] : !linalg<"view"> + // CHECK-NEXT: linalg.store %{{.*}}, %[[vC]][%i2, %i3] : !linalg.view // clang-format on cleanupAndPrintFunction(f); } diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index bc823af..3cc5ea6 100644 --- a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -2308,7 +2308,10 @@ dialect-type ::= '!' alias-name pretty-dialect-type-body? pretty-dialect-type-body ::= '<' pretty-dialect-type-contents+ '>' pretty-dialect-type-contents ::= pretty-dialect-type-body - | '[0-9a-zA-Z.-]+' + | '(' pretty-dialect-type-contents+ ')' + | '[' pretty-dialect-type-contents+ ']' + | '{' pretty-dialect-type-contents+ '}' + | '[^[<({>\])}\0]+' ``` diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 4f18e3d..fb1d15e 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -722,38 +722,52 @@ static bool isDialectTypeSimpleEnoughForPrettyForm(StringRef typeName) { if (typeName.front() != '<' || typeName.back() != '>') return false; - unsigned bracketDepth = 0; - while (!typeName.empty()) { + SmallVector nestedPunctuation; + do { + // If we ran out of characters, then we had a punctuation mismatch. + if (typeName.empty()) + return false; + auto c = typeName.front(); + typeName = typeName.drop_front(); + switch (c) { + // We never allow nul characters. This is an EOF indicator for the lexer + // which we could handle, but isn't important for any known dialect. + case '\0': + return false; case '<': - ++bracketDepth; - break; + case '[': + case '(': + case '{': + nestedPunctuation.push_back(c); + continue; + // Reject types with mismatched brackets. case '>': - // Reject types with mismatched brackets. - if (bracketDepth == 0) + if (nestedPunctuation.pop_back_val() != '<') return false; - --bracketDepth; break; - - case '.': - case '-': - case ' ': - case ',': - // These are all ok. + case ']': + if (nestedPunctuation.pop_back_val() != '[') + return false; + break; + case ')': + if (nestedPunctuation.pop_back_val() != ')') + return false; + break; + case '}': + if (nestedPunctuation.pop_back_val() != '}') + return false; break; - default: - if (isalpha(c) || isdigit(c)) - break; - // Unknown character abort. - return false; + continue; } - typeName = typeName.drop_front(); - } + // We're done when the punctuation is fully matched. + } while (!nestedPunctuation.empty()); - return bracketDepth == 0; + // If there were extra characters, then we failed. + return typeName.empty(); } void ModulePrinter::printType(Type type) { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 665b1b8..1156fc6 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -487,39 +487,67 @@ Parser::parseDimensionListRanked(SmallVectorImpl &dimensions, /// /// pretty-dialect-type-body ::= '<' pretty-dialect-type-contents+ '>' /// pretty-dialect-type-contents ::= pretty-dialect-type-body -/// | '[0-9a-zA-Z.,-]+' +/// | '(' pretty-dialect-type-contents+ ')' +/// | '[' pretty-dialect-type-contents+ ']' +/// | '{' pretty-dialect-type-contents+ '}' +/// | '[^[<({>\])}\0]+' /// ParseResult Parser::parsePrettyDialectTypeName(StringRef &prettyName) { - consumeToken(Token::less); + // Pretty type names are a relatively unstructured format that contains a + // series of properly nested punctuation, with anything else in the middle. + // Scan ahead to find it and consume it if successful, otherwise emit an + // error. + auto *curPtr = getTokenSpelling().data(); + + SmallVector nestedPunctuation; + + // Scan over the nested punctuation, bailing out on error and consuming until + // we find the end. We know that we're currently looking at the '<', so we + // can go until we find the matching '>' character. + assert(*curPtr == '<'); + do { + char c = *curPtr++; + switch (c) { + case '\0': + // This also handles the EOF case. + return emitError("unexpected nul or EOF in pretty dialect name"); + case '<': + case '[': + case '(': + case '{': + nestedPunctuation.push_back(c); + continue; - while (1) { - if (getToken().is(Token::greater)) { - auto *start = prettyName.begin(); - // Update the size of the covered range to include all the tokens we have - // skipped over. - unsigned length = getTokenSpelling().end() - start; - prettyName = StringRef(start, length); - consumeToken(Token::greater); - return ParseSuccess; - } + case '>': + if (nestedPunctuation.pop_back_val() != '<') + return emitError("unbalanced '>' character in pretty dialect name"); + break; + case ']': + if (nestedPunctuation.pop_back_val() != '[') + return emitError("unbalanced ']' character in pretty dialect name"); + break; + case ')': + if (nestedPunctuation.pop_back_val() != '(') + return emitError("unbalanced ')' character in pretty dialect name"); + break; + case '}': + if (nestedPunctuation.pop_back_val() != '{') + return emitError("unbalanced '}' character in pretty dialect name"); + break; - if (getToken().is(Token::less)) { - if (parsePrettyDialectTypeName(prettyName)) - return ParseFailure; + default: continue; } + } while (!nestedPunctuation.empty()); - // Check to see if the token contains simple characters. - bool isSimple = true; - for (auto c : getTokenSpelling()) - isSimple &= - (isalpha(c) || isdigit(c) || c == '.' || c == '-' || c == ','); + // Ok, we succeeded, remember where we stopped, reset the lexer to know it is + // consuming all this stuff, and return. + state.lex.resetPointer(curPtr); - if (!isSimple || getToken().is(Token::eof)) - return emitError("expected simple name in pretty dialect type"); - - consumeToken(); - } + unsigned length = curPtr - prettyName.begin(); + prettyName = StringRef(prettyName.begin(), length); + consumeToken(); + return ParseSuccess; } /// Parse an extended type. @@ -582,7 +610,7 @@ Type Parser::parseExtendedType() { // of it into prettyName. if (getToken().is(Token::less) && prettyName.bytes_end() == getTokenSpelling().bytes_begin()) { - if (parsePrettyDialectTypeName(prettyName) == ParseFailure) + if (parsePrettyDialectTypeName(prettyName)) return nullptr; } diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index eb97289..8103fc7 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -927,8 +927,8 @@ func @invalid_nested_dominance() { // ----- -// expected-error @+1 {{expected simple name in pretty dialect type}} -func @invalid_unknown_type_dialect_name() -> !invalid.dialect +// expected-error @+1 {{unbalanced ']' character in pretty dialect name}} +func @invalid_unknown_type_dialect_name() -> !invalid.dialect // ----- diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 7233af0..d9368ba 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -860,5 +860,16 @@ func @pretty_dialect_type() { // CHECK: %2 = "foo.unknown_op"() : () -> !foo.complextype> %2 = "foo.unknown_op"() : () -> !foo.complextype> + + // CHECK: %3 = "foo.unknown_op"() : () -> !foo.complextype> + %3 = "foo.unknown_op"() : () -> !foo.complextype> + + // CHECK: %4 = "foo.unknown_op"() : () -> !foo.dialect + %4 = "foo.unknown_op"() : () -> !foo.dialect + + // Extraneous extra > character can't use the pretty syntax. + // CHECK: %5 = "foo.unknown_op"() : () -> !foo<"dialect>"> + %5 = "foo.unknown_op"() : () -> !foo<"dialect>"> + return } -- 2.7.4