Add runtime utils support for print_memref_i8
authorNicolas Vasilache <ntv@google.com>
Thu, 19 Dec 2019 01:32:00 +0000 (17:32 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 19 Dec 2019 01:32:35 +0000 (17:32 -0800)
This CL adds print_memref_i8 along with a unit test.

PiperOrigin-RevId: 286299237

mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
mlir/test/mlir-cpu-runner/unranked_memref.mlir

index 7671db9..d4b6e1f 100644 (file)
@@ -89,7 +89,7 @@ template <typename T> struct UnrankedMemRefType {
 template <typename StreamType, typename T, int N>
 void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
   static_assert(N > 0, "Expected N > 0");
-  os << "Memref base@ = " << V.data << " rank = " << N
+  os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = " << N
      << " offset = " << V.offset << " sizes = [" << V.sizes[0];
   for (unsigned i = 1; i < N; ++i)
     os << ", " << V.sizes[i];
@@ -101,14 +101,14 @@ void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
 
 template <typename StreamType, typename T>
 void printMemRefMetaData(StreamType &os, StridedMemRefType<T, 0> &V) {
-  os << "Memref base@ = " << V.data << " rank = 0"
+  os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = 0"
      << " offset = " << V.offset;
 }
 
 template <typename T, typename StreamType>
 void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType<T> &V) {
   os << "Unranked Memref rank = " << V.rank << " "
-     << "descriptor@ = " << reinterpret_cast<float *>(V.descriptor) << " ";
+     << "descriptor@ = " << reinterpret_cast<void *>(V.descriptor) << "\n";
 }
 
 template <typename T, int Dim, int... Dims> struct Vector {
@@ -258,8 +258,8 @@ template <typename T, int N> void printMemRef(StridedMemRefType<T, N> &M) {
 }
 
 template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
-  std::cout << "\nMemref base@ = " << M.data << " rank = " << 0
-            << " offset = " << M.offset << " data = " << std::endl;
+  printMemRefMetaData(std::cout, M);
+  std::cout << " data = " << std::endl;
   std::cout << "[";
   MemRefDataPrinter<T, 0>::print(std::cout, M.data, 0, M.offset);
   std::cout << "]" << std::endl;
@@ -270,7 +270,10 @@ template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
 // Currently exposed C API.
 ////////////////////////////////////////////////////////////////////////////////
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
+print_memref_i8(UnrankedMemRefType<int8_t> *M);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void
 print_memref_f32(UnrankedMemRefType<float> *M);
+
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
 print_memref_0d_f32(StridedMemRefType<float, 0> *M);
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
index 9ff97cf..56829c6 100644 (file)
@@ -29,22 +29,38 @@ print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
   impl::printMemRef(*M);
 }
 
-extern "C" void print_memref_f32(UnrankedMemRefType<float> *M) {
+#define MEMREF_CASE(TYPE, RANK)                                                \
+  case RANK:                                                                   \
+    impl::printMemRef(*(static_cast<StridedMemRefType<TYPE, RANK> *>(ptr)));   \
+    break
+
+extern "C" void print_memref_i8(UnrankedMemRefType<int8_t> *M) {
   printUnrankedMemRefMetaData(std::cout, *M);
   int rank = M->rank;
   void *ptr = M->descriptor;
 
-#define MEMREF_CASE(RANK)                                                      \
-  case RANK:                                                                   \
-    impl::printMemRef(*(static_cast<StridedMemRefType<float, RANK> *>(ptr)));  \
-    break
+  switch (rank) {
+    MEMREF_CASE(int8_t, 0);
+    MEMREF_CASE(int8_t, 1);
+    MEMREF_CASE(int8_t, 2);
+    MEMREF_CASE(int8_t, 3);
+    MEMREF_CASE(int8_t, 4);
+  default:
+    assert(0 && "Unsupported rank to print");
+  }
+}
+
+extern "C" void print_memref_f32(UnrankedMemRefType<float> *M) {
+  printUnrankedMemRefMetaData(std::cout, *M);
+  int rank = M->rank;
+  void *ptr = M->descriptor;
 
   switch (rank) {
-    MEMREF_CASE(0);
-    MEMREF_CASE(1);
-    MEMREF_CASE(2);
-    MEMREF_CASE(3);
-    MEMREF_CASE(4);
+    MEMREF_CASE(float, 0);
+    MEMREF_CASE(float, 1);
+    MEMREF_CASE(float, 2);
+    MEMREF_CASE(float, 3);
+    MEMREF_CASE(float, 4);
   default:
     assert(0 && "Unsupported rank to print");
   }
index 4e721be..7447e9d 100644 (file)
@@ -1,19 +1,27 @@
 // RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
 
 // CHECK: rank = 2
+// CHECK: rank = 2
 // CHECK-SAME: sizes = [10, 3]
 // CHECK-SAME: strides = [3, 1]
 // CHECK-COUNT-10: [10, 10, 10]
 //
 // CHECK: rank = 2
+// CHECK: rank = 2
 // CHECK-SAME: sizes = [10, 3]
 // CHECK-SAME: strides = [3, 1]
 // CHECK-COUNT-10: [5, 5, 5]
 //
 // CHECK: rank = 2
+// CHECK: rank = 2
 // CHECK-SAME: sizes = [10, 3]
 // CHECK-SAME: strides = [3, 1]
 // CHECK-COUNT-10: [2, 2, 2]
+//
+// CHECK: rank = 0
+// CHECK: rank = 0
+// 122 is ASCII for 'z'.
+// CHECK: [z]
 func @main() -> () {
     %A = alloc() : memref<10x3xf32, 0>
     %f2 = constant 2.00000e+00 : f32
@@ -36,8 +44,16 @@ func @main() -> () {
     %U3 = memref_cast %V2 : memref<?x?xf32> to memref<*xf32>
     call @print_memref_f32(%U3) : (memref<*xf32>) -> ()
 
+    // 122 is ASCII for 'z'.
+    %i8_z = constant 122 : i8
+    %I8 = alloc() : memref<i8>
+    store %i8_z, %I8[]: memref<i8>
+    %U4 = memref_cast %I8 : memref<i8> to memref<*xi8>
+    call @print_memref_i8(%U4) : (memref<*xi8>) -> ()
+
     dealloc %A : memref<10x3xf32, 0>
     return
 }
 
+func @print_memref_i8(memref<*xi8>)
 func @print_memref_f32(memref<*xf32>)