From b89b3cdb612f43e91c7330a13c3f60a1ad251e50 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Sun, 7 May 2023 19:28:46 -0400 Subject: [PATCH] [mlir][TestDialect] Fix invalid custom op printers This fixes a few custom printers which were printing IR that couldn't be round-tripped. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D150080 --- mlir/test/IR/parser.mlir | 13 ++++++++----- mlir/test/IR/pretty-region-args.mlir | 6 ++---- mlir/test/lib/Dialect/Test/TestDialect.cpp | 17 ++++++++++++----- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 2041ca9..66c9adc 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -1,4 +1,7 @@ // RUN: mlir-opt -allow-unregistered-dialect %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | FileCheck %s -check-prefix GENERIC +// RUN: mlir-opt -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect -mlir-print-op-generic | FileCheck %s -check-prefix GENERIC // CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0, d1, d2, d4, d3)> #map = affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0, d1, d2, d4, d3)> @@ -1128,10 +1131,10 @@ func.func @special_float_values_in_tensors() { // Test parsing of an op with multiple region arguments, and without a // delimiter. -// CHECK-LABEL: func @op_with_region_args +// GENERIC-LABEL: op_with_region_args func.func @op_with_region_args() { - // CHECK: "test.polyfor"() ({ - // CHECK-NEXT: ^bb{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index): + // GENERIC: "test.polyfor"() ({ + // GENERIC-NEXT: ^bb{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index): test.polyfor %i, %j, %k { "foo"() : () -> () } @@ -1185,9 +1188,9 @@ func.func @parse_wrapped_keyword_test() { return } -// CHECK-LABEL: func @parse_base64_test +// GENERIC-LABEL: parse_base64_test func.func @parse_base64_test() { - // CHECK: test.parse_b64 "hello world" + // GENERIC: "test.parse_b64"() <{b64 = "hello world"}> test.parse_b64 "aGVsbG8gd29ybGQ=" return } diff --git a/mlir/test/IR/pretty-region-args.mlir b/mlir/test/IR/pretty-region-args.mlir index 764bea2..f642014 100644 --- a/mlir/test/IR/pretty-region-args.mlir +++ b/mlir/test/IR/pretty-region-args.mlir @@ -6,8 +6,7 @@ func.func @custom_region_names() -> () { ^bb0(%arg0: index, %arg1: index, %arg2: index): "foo"() : () -> () }) { arg_names = ["i", "j", "k"] } : () -> () - // CHECK: test.polyfor - // CHECK-NEXT: ^bb{{.*}}(%i: index, %j: index, %k: index): + // CHECK: test.polyfor %i, %j, %k return } @@ -18,8 +17,7 @@ func.func @weird_names() -> () { ^bb0(%arg0: i32, %arg1: i32, %arg2: index): "foo"() : () -> i32 }) { arg_names = ["a .^x", "0"] } : () -> () - // CHECK: test.polyfor - // CHECK-NEXT: ^bb{{.*}}(%a_.5Ex: i32, %_0: i32, %arg0: index): + // CHECK: test.polyfor %a_.5Ex, %_0, %arg0 // CHECK-NEXT: %0 = "foo"() return } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index d167d02..1ec7698 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -34,6 +34,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Base64.h" #include #include @@ -1020,7 +1021,7 @@ ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, } void IsolatedRegionOp::print(OpAsmPrinter &p) { - p << "test.isolated_region "; + p << ' '; p.printOperand(getOperand()); p.shadowRegionArgs(getRegion(), getOperand()); p << ' '; @@ -1054,7 +1055,7 @@ ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { } void AffineScopeOp::print(OpAsmPrinter &p) { - p << "test.affine_scope "; + p << " "; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); } @@ -1103,8 +1104,7 @@ ParseResult ParseB64BytesOp::parse(OpAsmParser &parser, } void ParseB64BytesOp::print(OpAsmPrinter &p) { - // Don't print the base64 version to check that we decoded it correctly. - p << " \"" << getB64() << "\""; + p << " \"" << llvm::encodeBase64(getB64()) << "\""; } //===----------------------------------------------------------------------===// @@ -1260,7 +1260,14 @@ ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { return parser.parseRegion(*body, ivsInfo); } -void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); } +void PolyForOp::print(OpAsmPrinter &p) { + p << " "; + llvm::interleaveComma(getRegion().getArguments(), p, [&](auto arg) { + p.printRegionArgument(arg, /*argAttrs =*/{}, /*omitType=*/true); + }); + p << " "; + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); +} void PolyForOp::getAsmBlockArgumentNames(Region ®ion, OpAsmSetValueNameFn setNameFn) { -- 2.7.4