From 50f9be6d2d62d36a5c7d6d11d8ed413dc91a4fca Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 18 Dec 2019 17:32:00 -0800 Subject: [PATCH] Add runtime utils support for print_memref_i8 This CL adds print_memref_i8 along with a unit test. PiperOrigin-RevId: 286299237 --- .../mlir-cpu-runner/include/mlir_runner_utils.h | 13 +++++--- mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp | 36 ++++++++++++++++------ mlir/test/mlir-cpu-runner/unranked_memref.mlir | 16 ++++++++++ 3 files changed, 50 insertions(+), 15 deletions(-) diff --git a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h index 7671db9..d4b6e1f 100644 --- a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h +++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h @@ -89,7 +89,7 @@ template struct UnrankedMemRefType { template void printMemRefMetaData(StreamType &os, StridedMemRefType &V) { static_assert(N > 0, "Expected N > 0"); - os << "Memref base@ = " << V.data << " rank = " << N + os << "Memref base@ = " << reinterpret_cast(V.data) << " rank = " << N << " offset = " << V.offset << " sizes = [" << V.sizes[0]; for (unsigned i = 1; i < N; ++i) os << ", " << V.sizes[i]; @@ -101,14 +101,14 @@ void printMemRefMetaData(StreamType &os, StridedMemRefType &V) { template void printMemRefMetaData(StreamType &os, StridedMemRefType &V) { - os << "Memref base@ = " << V.data << " rank = 0" + os << "Memref base@ = " << reinterpret_cast(V.data) << " rank = 0" << " offset = " << V.offset; } template void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType &V) { os << "Unranked Memref rank = " << V.rank << " " - << "descriptor@ = " << reinterpret_cast(V.descriptor) << " "; + << "descriptor@ = " << reinterpret_cast(V.descriptor) << "\n"; } template struct Vector { @@ -258,8 +258,8 @@ template void printMemRef(StridedMemRefType &M) { } template void printMemRef(StridedMemRefType &M) { - std::cout << "\nMemref base@ = " << M.data << " rank = " << 0 - << " offset = " << M.offset << " data = " << std::endl; + printMemRefMetaData(std::cout, M); + std::cout << " data = " << std::endl; std::cout << "["; MemRefDataPrinter::print(std::cout, M.data, 0, M.offset); std::cout << "]" << std::endl; @@ -270,7 +270,10 @@ template void printMemRef(StridedMemRefType &M) { // Currently exposed C API. //////////////////////////////////////////////////////////////////////////////// extern "C" MLIR_RUNNER_UTILS_EXPORT void +print_memref_i8(UnrankedMemRefType *M); +extern "C" MLIR_RUNNER_UTILS_EXPORT void print_memref_f32(UnrankedMemRefType *M); + extern "C" MLIR_RUNNER_UTILS_EXPORT void print_memref_0d_f32(StridedMemRefType *M); extern "C" MLIR_RUNNER_UTILS_EXPORT void diff --git a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp index 9ff97cf..56829c6 100644 --- a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp +++ b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp @@ -29,22 +29,38 @@ print_memref_vector_4x4xf32(StridedMemRefType, 2> *M) { impl::printMemRef(*M); } -extern "C" void print_memref_f32(UnrankedMemRefType *M) { +#define MEMREF_CASE(TYPE, RANK) \ + case RANK: \ + impl::printMemRef(*(static_cast *>(ptr))); \ + break + +extern "C" void print_memref_i8(UnrankedMemRefType *M) { printUnrankedMemRefMetaData(std::cout, *M); int rank = M->rank; void *ptr = M->descriptor; -#define MEMREF_CASE(RANK) \ - case RANK: \ - impl::printMemRef(*(static_cast *>(ptr))); \ - break + switch (rank) { + MEMREF_CASE(int8_t, 0); + MEMREF_CASE(int8_t, 1); + MEMREF_CASE(int8_t, 2); + MEMREF_CASE(int8_t, 3); + MEMREF_CASE(int8_t, 4); + default: + assert(0 && "Unsupported rank to print"); + } +} + +extern "C" void print_memref_f32(UnrankedMemRefType *M) { + printUnrankedMemRefMetaData(std::cout, *M); + int rank = M->rank; + void *ptr = M->descriptor; switch (rank) { - MEMREF_CASE(0); - MEMREF_CASE(1); - MEMREF_CASE(2); - MEMREF_CASE(3); - MEMREF_CASE(4); + MEMREF_CASE(float, 0); + MEMREF_CASE(float, 1); + MEMREF_CASE(float, 2); + MEMREF_CASE(float, 3); + MEMREF_CASE(float, 4); default: assert(0 && "Unsupported rank to print"); } diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir index 4e721be..7447e9d 100644 --- a/mlir/test/mlir-cpu-runner/unranked_memref.mlir +++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir @@ -1,19 +1,27 @@ // RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s // CHECK: rank = 2 +// CHECK: rank = 2 // CHECK-SAME: sizes = [10, 3] // CHECK-SAME: strides = [3, 1] // CHECK-COUNT-10: [10, 10, 10] // // CHECK: rank = 2 +// CHECK: rank = 2 // CHECK-SAME: sizes = [10, 3] // CHECK-SAME: strides = [3, 1] // CHECK-COUNT-10: [5, 5, 5] // // CHECK: rank = 2 +// CHECK: rank = 2 // CHECK-SAME: sizes = [10, 3] // CHECK-SAME: strides = [3, 1] // CHECK-COUNT-10: [2, 2, 2] +// +// CHECK: rank = 0 +// CHECK: rank = 0 +// 122 is ASCII for 'z'. +// CHECK: [z] func @main() -> () { %A = alloc() : memref<10x3xf32, 0> %f2 = constant 2.00000e+00 : f32 @@ -36,8 +44,16 @@ func @main() -> () { %U3 = memref_cast %V2 : memref to memref<*xf32> call @print_memref_f32(%U3) : (memref<*xf32>) -> () + // 122 is ASCII for 'z'. + %i8_z = constant 122 : i8 + %I8 = alloc() : memref + store %i8_z, %I8[]: memref + %U4 = memref_cast %I8 : memref to memref<*xi8> + call @print_memref_i8(%U4) : (memref<*xi8>) -> () + dealloc %A : memref<10x3xf32, 0> return } +func @print_memref_i8(memref<*xi8>) func @print_memref_f32(memref<*xf32>) -- 2.7.4