From 95b5a4fd675bd125086d3878df1d4cd3d47d2485 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 12 Dec 2019 07:32:36 -0800 Subject: [PATCH] Move cpu runner utils templates to .h This allows reusing the implementation in various places by just including and permits more easily writing test functions without explicit template instantiations. This also modifies UnrankedMemRefType to take a template type parameter since it cannot be type agnostic atm. PiperOrigin-RevId: 285187711 --- .../mlir-cpu-runner/include/mlir_runner_utils.h | 181 +++++++++++++++++++-- mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp | 149 +---------------- mlir/test/mlir-cpu-runner/utils.mlir | 4 +- 3 files changed, 175 insertions(+), 159 deletions(-) diff --git a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h index 8a58493..ba68295 100644 --- a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h +++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h @@ -17,7 +17,9 @@ #ifndef MLIR_CPU_RUNNER_MLIRUTILS_H_ #define MLIR_CPU_RUNNER_MLIRUTILS_H_ +#include #include +#include #ifdef _WIN32 #ifndef MLIR_RUNNER_UTILS_EXPORT @@ -79,7 +81,7 @@ template struct StridedMemRefType { }; // Unranked MemRef -struct UnrankedMemRefType { +template struct UnrankedMemRefType { int64_t rank; void *descriptor; }; @@ -103,26 +105,12 @@ void printMemRefMetaData(StreamType &os, StridedMemRefType &V) { << " offset = " << V.offset; } -extern "C" MLIR_RUNNER_UTILS_EXPORT void -print_memref_f32(UnrankedMemRefType *M); - -template -void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType &V) { +template +void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType &V) { os << "Unranked Memref rank = " << V.rank << " " << "descriptor@ = " << reinterpret_cast(V.descriptor) << " "; } -extern "C" MLIR_RUNNER_UTILS_EXPORT void -print_memref_0d_f32(StridedMemRefType *M); -extern "C" MLIR_RUNNER_UTILS_EXPORT void -print_memref_1d_f32(StridedMemRefType *M); -extern "C" MLIR_RUNNER_UTILS_EXPORT void -print_memref_2d_f32(StridedMemRefType *M); -extern "C" MLIR_RUNNER_UTILS_EXPORT void -print_memref_3d_f32(StridedMemRefType *M); -extern "C" MLIR_RUNNER_UTILS_EXPORT void -print_memref_4d_f32(StridedMemRefType *M); - template struct Vector { Vector vector[Dim]; }; @@ -135,6 +123,165 @@ using Vector3D = Vector; template using Vector4D = Vector; +//////////////////////////////////////////////////////////////////////////////// +// Templated instantiation follows. +//////////////////////////////////////////////////////////////////////////////// +namespace impl { +template +std::ostream &operator<<(std::ostream &os, const Vector &v); + +template struct StaticSizeMult { + static constexpr int value = 1; +}; + +template struct StaticSizeMult { + static constexpr int value = N * StaticSizeMult::value; +}; + +static void printSpace(std::ostream &os, int count) { + for (int i = 0; i < count; ++i) { + os << ' '; + } +} + +template struct VectorDataPrinter { + static void print(std::ostream &os, const Vector &val); +}; + +template +void VectorDataPrinter::print(std::ostream &os, + const Vector &val) { + static_assert(M > 0, "0 dimensioned tensor"); + static_assert(sizeof(val) == M * StaticSizeMult::value * sizeof(T), + "Incorrect vector size!"); + // First + os << "(" << val.vector[0]; + if (M > 1) + os << ", "; + if (sizeof...(Dims) > 1) + os << "\n"; + // Kernel + for (unsigned i = 1; i + 1 < M; ++i) { + printSpace(os, 2 * sizeof...(Dims)); + os << val.vector[i] << ", "; + if (sizeof...(Dims) > 1) + os << "\n"; + } + // Last + if (M > 1) { + printSpace(os, sizeof...(Dims)); + os << val.vector[M - 1]; + } + os << ")"; +} + +template +std::ostream &operator<<(std::ostream &os, const Vector &v) { + VectorDataPrinter::print(os, v); + return os; +} + +template struct MemRefDataPrinter { + static void print(std::ostream &os, T *base, int64_t rank, int64_t offset, + int64_t *sizes, int64_t *strides); + static void printFirst(std::ostream &os, T *base, int64_t rank, + int64_t offset, int64_t *sizes, int64_t *strides); + static void printLast(std::ostream &os, T *base, int64_t rank, int64_t offset, + int64_t *sizes, int64_t *strides); +}; + +template struct MemRefDataPrinter { + static void print(std::ostream &os, T *base, int64_t rank, int64_t offset, + int64_t *sizes = nullptr, int64_t *strides = nullptr); +}; + +template +void MemRefDataPrinter::printFirst(std::ostream &os, T *base, + int64_t rank, int64_t offset, + int64_t *sizes, int64_t *strides) { + os << "["; + MemRefDataPrinter::print(os, base, rank, offset, sizes + 1, + strides + 1); + // If single element, close square bracket and return early. + if (sizes[0] <= 1) { + os << "]"; + return; + } + os << ", "; + if (N > 1) + os << "\n"; +} + +template +void MemRefDataPrinter::print(std::ostream &os, T *base, int64_t rank, + int64_t offset, int64_t *sizes, + int64_t *strides) { + printFirst(os, base, rank, offset, sizes, strides); + for (unsigned i = 1; i + 1 < sizes[0]; ++i) { + printSpace(os, rank - N + 1); + MemRefDataPrinter::print(os, base, rank, offset + i * strides[0], + sizes + 1, strides + 1); + os << ", "; + if (N > 1) + os << "\n"; + } + if (sizes[0] <= 1) + return; + printLast(os, base, rank, offset, sizes, strides); +} + +template +void MemRefDataPrinter::printLast(std::ostream &os, T *base, int64_t rank, + int64_t offset, int64_t *sizes, + int64_t *strides) { + printSpace(os, rank - N + 1); + MemRefDataPrinter::print(os, base, rank, + offset + (sizes[0] - 1) * (*strides), + sizes + 1, strides + 1); + os << "]"; +} + +template +void MemRefDataPrinter::print(std::ostream &os, T *base, int64_t rank, + int64_t offset, int64_t *sizes, + int64_t *strides) { + os << base[offset]; +} + +template void printMemRef(StridedMemRefType &M) { + static_assert(N > 0, "Expected N > 0"); + printMemRefMetaData(std::cout, M); + std::cout << " data = " << std::endl; + MemRefDataPrinter::print(std::cout, M.data, N, M.offset, M.sizes, + M.strides); + std::cout << std::endl; +} + +template void printMemRef(StridedMemRefType &M) { + std::cout << "\nMemref base@ = " << M.data << " rank = " << 0 + << " offset = " << M.offset << " data = " << std::endl; + std::cout << "["; + MemRefDataPrinter::print(std::cout, M.data, 0, M.offset); + std::cout << "]" << std::endl; +} +} // namespace impl + +//////////////////////////////////////////////////////////////////////////////// +// Currently exposed C API. +//////////////////////////////////////////////////////////////////////////////// +extern "C" MLIR_RUNNER_UTILS_EXPORT void +print_memref_f32(UnrankedMemRefType *M); +extern "C" MLIR_RUNNER_UTILS_EXPORT void +print_memref_0d_f32(StridedMemRefType *M); +extern "C" MLIR_RUNNER_UTILS_EXPORT void +print_memref_1d_f32(StridedMemRefType *M); +extern "C" MLIR_RUNNER_UTILS_EXPORT void +print_memref_2d_f32(StridedMemRefType *M); +extern "C" MLIR_RUNNER_UTILS_EXPORT void +print_memref_3d_f32(StridedMemRefType *M); +extern "C" MLIR_RUNNER_UTILS_EXPORT void +print_memref_4d_f32(StridedMemRefType *M); + extern "C" MLIR_RUNNER_UTILS_EXPORT void print_memref_vector_4x4xf32(StridedMemRefType, 2> *M); diff --git a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp index 056ff65..f8007d7 100644 --- a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp +++ b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp @@ -21,153 +21,20 @@ //===----------------------------------------------------------------------===// #include "include/mlir_runner_utils.h" -#include -#include - -template struct StaticSizeMult { - static constexpr int value = 1; -}; - -template struct StaticSizeMult { - static constexpr int value = N * StaticSizeMult::value; -}; - -static void printSpace(std::ostream &os, int count) { - for (int i = 0; i < count; ++i) { - os << ' '; - } -} - -template struct VectorDataPrinter { - static void print(std::ostream &os, const Vector &val); -}; - -template -void VectorDataPrinter::print(std::ostream &os, - const Vector &val) { - static_assert(M > 0, "0 dimensioned tensor"); - static_assert(sizeof(val) == M * StaticSizeMult::value * sizeof(T), - "Incorrect vector size!"); - // First - os << "(" << val.vector[0]; - if (M > 1) - os << ", "; - if (sizeof...(Dims) > 1) - os << "\n"; - // Kernel - for (unsigned i = 1; i + 1 < M; ++i) { - printSpace(os, 2 * sizeof...(Dims)); - os << val.vector[i] << ", "; - if (sizeof...(Dims) > 1) - os << "\n"; - } - // Last - printSpace(os, sizeof...(Dims)); - os << val.vector[M - 1] << ")"; -} - -template -std::ostream &operator<<(std::ostream &os, const Vector &v) { - VectorDataPrinter::print(os, v); - return os; -} - -template struct MemRefDataPrinter { - static void print(std::ostream &os, T *base, int64_t rank, int64_t offset, - int64_t *sizes, int64_t *strides); - static void printFirst(std::ostream &os, T *base, int64_t rank, - int64_t offset, int64_t *sizes, int64_t *strides); - static void printLast(std::ostream &os, T *base, int64_t rank, int64_t offset, - int64_t *sizes, int64_t *strides); -}; - -template struct MemRefDataPrinter { - static void print(std::ostream &os, T *base, int64_t rank, int64_t offset, - int64_t *sizes = nullptr, int64_t *strides = nullptr); -}; - -template -void MemRefDataPrinter::printFirst(std::ostream &os, T *base, - int64_t rank, int64_t offset, - int64_t *sizes, int64_t *strides) { - os << "["; - MemRefDataPrinter::print(os, base, rank, offset, sizes + 1, - strides + 1); - // If single element, close square bracket and return early. - if (sizes[0] <= 1) { - os << "]"; - return; - } - os << ", "; - if (N > 1) - os << "\n"; -} - -template -void MemRefDataPrinter::print(std::ostream &os, T *base, int64_t rank, - int64_t offset, int64_t *sizes, - int64_t *strides) { - printFirst(os, base, rank, offset, sizes, strides); - for (unsigned i = 1; i + 1 < sizes[0]; ++i) { - printSpace(os, rank - N + 1); - MemRefDataPrinter::print(os, base, rank, offset + i * strides[0], - sizes + 1, strides + 1); - os << ", "; - if (N > 1) - os << "\n"; - } - if (sizes[0] <= 1) - return; - printLast(os, base, rank, offset, sizes, strides); -} - -template -void MemRefDataPrinter::printLast(std::ostream &os, T *base, int64_t rank, - int64_t offset, int64_t *sizes, - int64_t *strides) { - printSpace(os, rank - N + 1); - MemRefDataPrinter::print(os, base, rank, - offset + (sizes[0] - 1) * (*strides), - sizes + 1, strides + 1); - os << "]"; -} - -template -void MemRefDataPrinter::print(std::ostream &os, T *base, int64_t rank, - int64_t offset, int64_t *sizes, - int64_t *strides) { - os << base[offset]; -} - -template void printMemRef(StridedMemRefType &M) { - static_assert(N > 0, "Expected N > 0"); - printMemRefMetaData(std::cout, M); - std::cout << " data = " << std::endl; - MemRefDataPrinter::print(std::cout, M.data, N, M.offset, M.sizes, - M.strides); - std::cout << std::endl; -} - -template void printMemRef(StridedMemRefType &M) { - std::cout << "\nMemref base@ = " << M.data << " rank = " << 0 - << " offset = " << M.offset << " data = ["; - MemRefDataPrinter::print(std::cout, M.data, 0, M.offset); - std::cout << "]" << std::endl; -} extern "C" void print_memref_vector_4x4xf32(StridedMemRefType, 2> *M) { - printMemRef(*M); + impl::printMemRef(*M); } -extern "C" void print_memref_f32(UnrankedMemRefType *M) { +extern "C" void print_memref_f32(UnrankedMemRefType *M) { printUnrankedMemRefMetaData(std::cout, *M); int rank = M->rank; void *ptr = M->descriptor; #define MEMREF_CASE(RANK) \ case RANK: \ - printMemRef(*(static_cast *>(ptr))); \ + impl::printMemRef(*(static_cast *>(ptr))); \ break switch (rank) { @@ -182,17 +49,17 @@ extern "C" void print_memref_f32(UnrankedMemRefType *M) { } extern "C" void print_memref_0d_f32(StridedMemRefType *M) { - printMemRef(*M); + impl::printMemRef(*M); } extern "C" void print_memref_1d_f32(StridedMemRefType *M) { - printMemRef(*M); + impl::printMemRef(*M); } extern "C" void print_memref_2d_f32(StridedMemRefType *M) { - printMemRef(*M); + impl::printMemRef(*M); } extern "C" void print_memref_3d_f32(StridedMemRefType *M) { - printMemRef(*M); + impl::printMemRef(*M); } extern "C" void print_memref_4d_f32(StridedMemRefType *M) { - printMemRef(*M); + impl::printMemRef(*M); } diff --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir index 099b856..2a56920 100644 --- a/mlir/test/mlir-cpu-runner/utils.mlir +++ b/mlir/test/mlir-cpu-runner/utils.mlir @@ -12,7 +12,9 @@ func @print_0d() { dealloc %A : memref return } -// PRINT-0D: Memref base@ = {{.*}} rank = 0 offset = 0 data = [2] +// PRINT-0D: Unranked Memref rank = 0 descriptor@ = {{.*}} +// PRINT-0D: Memref base@ = {{.*}} rank = 0 offset = 0 data = +// PRINT-0D: [2] func @print_1d() { %f = constant 2.00000e+00 : f32 -- 2.7.4