From 1a64ca328fdcd7c5f60e73b19938f945d0b2fc7b Mon Sep 17 00:00:00 2001 From: Junyan He Date: Tue, 24 Jun 2014 16:35:58 +0800 Subject: [PATCH] Add the support for vector type in printf. Signed-off-by: Junyan He Reviewed-by: Zhigang Gong --- backend/src/ir/printf.cpp | 144 +++++++++++++++++--------------- backend/src/ir/printf.hpp | 11 ++- backend/src/llvm/llvm_printf_parser.cpp | 83 ++++++++++++++++-- kernels/test_printf.cl | 10 ++- 4 files changed, 167 insertions(+), 81 deletions(-) diff --git a/backend/src/ir/printf.cpp b/backend/src/ir/printf.cpp index 58711e2..68b2ce4 100644 --- a/backend/src/ir/printf.cpp +++ b/backend/src/ir/printf.cpp @@ -84,8 +84,6 @@ namespace gbe str += num_str; } - // TODO: Handle the vector here. - switch (state.length_modifier) { case PRINTF_LM_HH: str += "hh"; @@ -97,7 +95,7 @@ namespace gbe str += "l"; break; case PRINTF_LM_HL: - str += "hl"; + str += ""; break; default: assert(state.length_modifier == PRINTF_LM_NONE); @@ -105,12 +103,12 @@ namespace gbe } #define PRINT_SOMETHING(target_ty, conv) do { \ - pf_str = pf_str + std::string(#conv); \ + if (!vec_i) \ + pf_str = pf_str + std::string(#conv); \ printf(pf_str.c_str(), \ ((target_ty *)((char *)buf_addr + slot.state->out_buf_sizeof_offset * \ global_wk_sz0 * global_wk_sz1 * global_wk_sz2)) \ - [k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i]); \ - pf_str = ""; \ + [(k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i) * vec_num + vec_i]);\ } while (0) @@ -126,80 +124,88 @@ namespace gbe for (i = 0; i < global_wk_sz0; i++) { for (j = 0; j < global_wk_sz1; j++) { for (k = 0; k < global_wk_sz2; k++) { - int flag = ((int *)index_addr)[stmt*global_wk_sz0*global_wk_sz1*global_wk_sz2 + k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i]; + + int flag = ((int *)index_addr)[stmt*global_wk_sz0*global_wk_sz1*global_wk_sz2 + + k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i]; if (flag) { - pf_str = ""; for (auto &slot : pf) { + pf_str = ""; + int vec_num; + if (slot.type == PRINTF_SLOT_TYPE_STRING) { - pf_str = pf_str + std::string(slot.str); + printf("%s", slot.str); continue; } assert(slot.type == PRINTF_SLOT_TYPE_STATE); generatePrintfFmtString(*slot.state, pf_str); - switch (slot.state->conversion_specifier) { - case PRINTF_CONVERSION_D: - case PRINTF_CONVERSION_I: - PRINT_SOMETHING(int, d); - break; - - case PRINTF_CONVERSION_O: - PRINT_SOMETHING(int, o); - break; - case PRINTF_CONVERSION_U: - PRINT_SOMETHING(int, u); - break; - case PRINTF_CONVERSION_X: - PRINT_SOMETHING(int, X); - break; - case PRINTF_CONVERSION_x: - PRINT_SOMETHING(int, x); - break; - - case PRINTF_CONVERSION_C: - PRINT_SOMETHING(char, c); - break; - - case PRINTF_CONVERSION_F: - PRINT_SOMETHING(float, F); - break; - case PRINTF_CONVERSION_f: - PRINT_SOMETHING(float, f); - break; - case PRINTF_CONVERSION_E: - PRINT_SOMETHING(float, E); - break; - case PRINTF_CONVERSION_e: - PRINT_SOMETHING(float, e); - break; - case PRINTF_CONVERSION_G: - PRINT_SOMETHING(float, G); - break; - case PRINTF_CONVERSION_g: - PRINT_SOMETHING(float, g); - break; - case PRINTF_CONVERSION_A: - PRINT_SOMETHING(float, A); - break; - case PRINTF_CONVERSION_a: - PRINT_SOMETHING(float, a); - break; - - case PRINTF_CONVERSION_S: - pf_str = pf_str + "s"; - printf(pf_str.c_str(), slot.state->str.c_str()); - pf_str = ""; - break; - - default: - assert(0); - return; + vec_num = slot.state->vector_n > 0 ? slot.state->vector_n : 1; + + for (int vec_i = 0; vec_i < vec_num; vec_i++) { + if (vec_i) + printf(","); + + switch (slot.state->conversion_specifier) { + case PRINTF_CONVERSION_D: + case PRINTF_CONVERSION_I: + PRINT_SOMETHING(int, d); + break; + + case PRINTF_CONVERSION_O: + PRINT_SOMETHING(int, o); + break; + case PRINTF_CONVERSION_U: + PRINT_SOMETHING(int, u); + break; + case PRINTF_CONVERSION_X: + PRINT_SOMETHING(int, X); + break; + case PRINTF_CONVERSION_x: + PRINT_SOMETHING(int, x); + break; + + case PRINTF_CONVERSION_C: + PRINT_SOMETHING(char, c); + break; + + case PRINTF_CONVERSION_F: + PRINT_SOMETHING(float, F); + break; + case PRINTF_CONVERSION_f: + PRINT_SOMETHING(float, f); + break; + case PRINTF_CONVERSION_E: + PRINT_SOMETHING(float, E); + break; + case PRINTF_CONVERSION_e: + PRINT_SOMETHING(float, e); + break; + case PRINTF_CONVERSION_G: + PRINT_SOMETHING(float, G); + break; + case PRINTF_CONVERSION_g: + PRINT_SOMETHING(float, g); + break; + case PRINTF_CONVERSION_A: + PRINT_SOMETHING(float, A); + break; + case PRINTF_CONVERSION_a: + PRINT_SOMETHING(float, a); + break; + + case PRINTF_CONVERSION_S: + pf_str = pf_str + "s"; + printf(pf_str.c_str(), slot.state->str.c_str()); + break; + + default: + assert(0); + return; + } } - } - if (pf_str != "") { - printf("%s", pf_str.c_str()); + pf_str = ""; } } } diff --git a/backend/src/ir/printf.hpp b/backend/src/ir/printf.hpp index 8b759d4..680b8e6 100644 --- a/backend/src/ir/printf.hpp +++ b/backend/src/ir/printf.hpp @@ -182,6 +182,13 @@ namespace gbe uint32_t getPrintfBufferElementSize(uint32_t i) { PrintfSlot* slot = slots[i]; + int vec_num = 1; + if (slot->state->vector_n > 0) { + vec_num = slot->state->vector_n; + } + + assert(vec_num > 0 && vec_num <= 16); + switch (slot->state->conversion_specifier) { case PRINTF_CONVERSION_I: case PRINTF_CONVERSION_D: @@ -191,7 +198,7 @@ namespace gbe case PRINTF_CONVERSION_x: /* Char will be aligned to sizeof(int) here. */ case PRINTF_CONVERSION_C: - return (uint32_t)sizeof(int); + return (uint32_t)(sizeof(int) * vec_num); case PRINTF_CONVERSION_E: case PRINTF_CONVERSION_e: case PRINTF_CONVERSION_F: @@ -200,7 +207,7 @@ namespace gbe case PRINTF_CONVERSION_g: case PRINTF_CONVERSION_A: case PRINTF_CONVERSION_a: - return (uint32_t)sizeof(float); + return (uint32_t)(sizeof(float) * vec_num); case PRINTF_CONVERSION_S: return (uint32_t)0; default: diff --git a/backend/src/llvm/llvm_printf_parser.cpp b/backend/src/llvm/llvm_printf_parser.cpp index dcad036..ff8e259 100644 --- a/backend/src/llvm/llvm_printf_parser.cpp +++ b/backend/src/llvm/llvm_printf_parser.cpp @@ -98,7 +98,7 @@ namespace gbe return -1; #define FMT_PLUS_PLUS do { \ - if (fmt + 1 < end) fmt++; \ + if (fmt + 1 <= end) fmt++; \ else { \ printf("Error, line: %d, fmt > end\n", __LINE__); \ return -1; \ @@ -627,20 +627,21 @@ error: conversion need to be applied. */ switch (arg->getType()->getTypeID()) { case Type::IntegerTyID: { + bool sign = false; switch (slot.state->conversion_specifier) { case PRINTF_CONVERSION_I: case PRINTF_CONVERSION_D: - /* Int to Int, just store. */ - dst_type = Type::getInt32PtrTy(module->getContext(), 1); - sizeof_size = sizeof(int); - return true; - + sign = true; case PRINTF_CONVERSION_O: case PRINTF_CONVERSION_U: case PRINTF_CONVERSION_x: case PRINTF_CONVERSION_X: - /* To uint, add a conversion. */ - arg = builder->CreateIntCast(arg, Type::getInt32Ty(module->getContext()), true); + /* If the bits change, we need to consider the signed. */ + if (arg->getType() != Type::getInt32Ty(module->getContext())) { + arg = builder->CreateIntCast(arg, Type::getInt32Ty(module->getContext()), sign); + } + + /* Int to Int, just store. */ dst_type = Type::getInt32PtrTy(module->getContext(), 1); sizeof_size = sizeof(int); return true; @@ -745,6 +746,72 @@ error: } } + case Type::VectorTyID: { + Type* vect_type = arg->getType(); + Type* elt_type = vect_type->getVectorElementType(); + int vec_num = vect_type->getVectorNumElements(); + bool sign = false; + + if (vec_num != slot.state->vector_n) { + return false; + } + + switch (slot.state->conversion_specifier) { + case PRINTF_CONVERSION_I: + case PRINTF_CONVERSION_D: + sign = true; + case PRINTF_CONVERSION_O: + case PRINTF_CONVERSION_U: + case PRINTF_CONVERSION_x: + case PRINTF_CONVERSION_X: + if (elt_type->getTypeID() != Type::IntegerTyID) + return false; + + /* If the bits change, we need to consider the signed. */ + if (elt_type != Type::getInt32Ty(elt_type->getContext())) { + Value *II = NULL; + for (int i = 0; i < vec_num; i++) { + Value *vec = II ? II : UndefValue::get(VectorType::get(Type::getInt32Ty(elt_type->getContext()), vec_num)); + Value *cv = ConstantInt::get(Type::getInt32Ty(elt_type->getContext()), i); + Value *org = builder->CreateExtractElement(arg, cv); + Value *cvt = builder->CreateIntCast(org, Type::getInt32Ty(module->getContext()), sign); + II = builder->CreateInsertElement(vec, cvt, cv); + } + arg = II; + } + + dst_type = arg->getType()->getPointerTo(1); + sizeof_size = sizeof(int) * vec_num; + return true; + + case PRINTF_CONVERSION_F: + case PRINTF_CONVERSION_f: + case PRINTF_CONVERSION_E: + case PRINTF_CONVERSION_e: + case PRINTF_CONVERSION_G: + case PRINTF_CONVERSION_g: + case PRINTF_CONVERSION_A: + case PRINTF_CONVERSION_a: + if (elt_type->getTypeID() != Type::DoubleTyID && elt_type->getTypeID() != Type::FloatTyID) + return false; + + if (elt_type->getTypeID() != Type::FloatTyID) { + Value *II = NULL; + for (int i = 0; i < vec_num; i++) { + Value *vec = II ? II : UndefValue::get(VectorType::get(Type::getFloatTy(elt_type->getContext()), vec_num)); + Value *cv = ConstantInt::get(Type::getInt32Ty(elt_type->getContext()), i); + Value *org = builder->CreateExtractElement(arg, cv); + Value* cvt = builder->CreateFPCast(org, Type::getFloatTy(module->getContext())); + II = builder->CreateInsertElement(vec, cvt, cv); + } + arg = II; + } + } + dst_type = arg->getType()->getPointerTo(1); + sizeof_size = sizeof(int) * vec_num; + return true; + } + default: return false; } diff --git a/kernels/test_printf.cl b/kernels/test_printf.cl index c21ee98..84bb478 100644 --- a/kernels/test_printf.cl +++ b/kernels/test_printf.cl @@ -6,6 +6,10 @@ test_printf(void) int z = (int)get_global_id(2); uint a = 'x'; float f = 5.0f; + int3 vec; + vec.x = x; + vec.y = y; + vec.z = z; if (x == 0 && y == 0 && z == 0) { printf("--- Welcome to the printf test of %s ---\n", "Intel Beignet"); @@ -16,8 +20,8 @@ test_printf(void) if (x % 15 == 0) if (y % 3 == 0) if (z % 7 == 0) - printf("######## global_id(x, y, z) = (%d, %d, %d), global_size(d0, d1, d3) = (%d, %d, %d)\n", - x, y, z, get_global_size(0), get_global_size(1), get_global_size(2)); + printf("######## global_id(x, y, z) = %v3d, global_size(d0, d1, d3) = (%d, %d, %d)\n", + vec, get_global_size(0), get_global_size(1), get_global_size(2)); if (x == 1) if (y == 0) { @@ -26,7 +30,9 @@ test_printf(void) else printf("#### output a float to int is %d\n", f); } + if (x == 0 && y == 0 && z == 0) { printf("--- End to the printf test ---\n"); } + } -- 2.7.4