From 674f2df4fe0b6af901fc7c7e8bd3fb37e1e8516c Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 17 Aug 2020 20:25:28 +0200 Subject: [PATCH] [mlir] Fix printing of unranked memrefs in non-default memory space The type printer was ignoring the memory space on unranked memrefs. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D86096 --- mlir/lib/IR/AsmPrinter.cpp | 3 +++ mlir/test/IR/core-ops.mlir | 5 +++++ mlir/test/IR/invalid-ops.mlir | 2 +- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index c8b4a86..61eecb8 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1650,6 +1650,9 @@ void ModulePrinter::printType(Type type) { .Case([&](UnrankedMemRefType memrefTy) { os << "memref<*x"; printType(memrefTy.getElementType()); + // Only print the memory space if it is the non-default one. + if (memrefTy.getMemorySpace()) + os << ", " << memrefTy.getMemorySpace(); os << '>'; }) .Case([&](ComplexType complexTy) { diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index 89bcd75f..7447071 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -703,6 +703,11 @@ func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref, %arg2 : memref<64 return } +// Check that unranked memrefs with non-default memory space roundtrip +// properly. +// CHECK-LABEL: @unranked_memref_roundtrip(memref<*xf32, 4>) +func @unranked_memref_roundtrip(memref<*xf32, 4>) + // CHECK-LABEL: func @memref_view(%arg0 func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<2048xi8> diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 6302a8a..5573911 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -1076,7 +1076,7 @@ func @invalid_prefetch_locality_hint(%i : index) { // incompatible memory space func @invalid_memref_cast() { %0 = alloc() : memref<2x5xf32, 0> - // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xf32>' are cast incompatible}} + // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xf32, 1>' are cast incompatible}} %1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xf32, 1> return } -- 2.7.4