From 24de676a6480aae151b1a2c80a483ec3ef98ab2a Mon Sep 17 00:00:00 2001
From: Anatoliy Talamanov
Date: Fri, 6 Aug 2021 13:26:49 +0300
Subject: [PATCH] Merge pull request #20476 from TolyaTalamanov:at/support-unet-camvid-0001-segm-sample

[G-API] Support postprocessing for not argmaxed outputs

* Support postprocessing for not argmaxed outputs

* Fix typo

* Add assert

* Remove static cast

* CamelCast to snake_case

* Fix windows warning

* Add static_cast to uint8_t

* Add const to variables
---
Reviewer notes (not part of the commit message):
 - This copy was rebuilt from a whitespace-collapsed extract that had also
   lost every angle-bracket span. Template arguments were restored from the
   surrounding declarations; indentation is reconstructed as 4 spaces, so
   use "git apply --ignore-whitespace" if a context line fails to match.
   The author e-mail address could not be recovered and is left out.
 - probsToClasses() seeds its running maximum with 0 and updates only on a
   strictly greater score, so a pixel whose scores are all <= 0 is labelled
   class 0. Harmless for softmax probabilities; confirm before feeding it
   raw (possibly negative) scores.
 - probsToClasses() reads only size[1..3] of the blob and indexes it as
   c * H * W + h * W + w; presumably the batch dimension is 1 - confirm.

 modules/gapi/samples/semantic_segmentation.cpp | 83 +++++++++++++++++++-------
 1 file changed, 63 insertions(+), 20 deletions(-)

diff --git a/modules/gapi/samples/semantic_segmentation.cpp b/modules/gapi/samples/semantic_segmentation.cpp
index 0a6e723..4cdb14c 100644
--- a/modules/gapi/samples/semantic_segmentation.cpp
+++ b/modules/gapi/samples/semantic_segmentation.cpp
@@ -47,6 +47,53 @@ std::string get_weights_path(const std::string &model_path) {
     CV_Assert(ext == ".xml");
     return model_path.substr(0u, sz - EXT_LEN) + ".bin";
 }
+
+void classesToColors(const cv::Mat &out_blob,
+                           cv::Mat &mask_img) {
+    const int H = out_blob.size[0];
+    const int W = out_blob.size[1];
+
+    mask_img.create(H, W, CV_8UC3);
+    GAPI_Assert(out_blob.type() == CV_8UC1);
+    const uint8_t* const classes = out_blob.ptr<uint8_t>();
+
+    for (int rowId = 0; rowId < H; ++rowId) {
+        for (int colId = 0; colId < W; ++colId) {
+            uint8_t class_id = classes[rowId * W + colId];
+            mask_img.at<cv::Vec3b>(rowId, colId) =
+                class_id < colors.size()
+                ? colors[class_id]
+                : cv::Vec3b{0, 0, 0}; // NB: sample supports 20 classes
+        }
+    }
+}
+
+void probsToClasses(const cv::Mat& probs, cv::Mat& classes) {
+    const int C = probs.size[1];
+    const int H = probs.size[2];
+    const int W = probs.size[3];
+
+    classes.create(H, W, CV_8UC1);
+    GAPI_Assert(probs.depth() == CV_32F);
+    float* out_p = reinterpret_cast<float*>(probs.data);
+    uint8_t* classes_p = reinterpret_cast<uint8_t*>(classes.data);
+
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            double max = 0;
+            int class_id = 0;
+            for (int c = 0; c < C; ++c) {
+                int idx = c * H * W + h * W + w;
+                if (out_p[idx] > max) {
+                    max = out_p[idx];
+                    class_id = c;
+                }
+            }
+            classes_p[h * W + w] = static_cast<uint8_t>(class_id);
+        }
+    }
+}
+
 } // anonymous namespace
 
 namespace custom {
@@ -57,25 +104,21 @@ G_API_OP(PostProcessing, <cv::GMat(cv::GMat, cv::GMat)>, "sample.custom.post_pro
 };
 
 GAPI_OCV_KERNEL(OCVPostProcessing, PostProcessing) {
-    static void run(const cv::Mat &in, const cv::Mat &detected_classes, cv::Mat &out) {
-        // This kernel constructs output image by class table and colors vector
-
-        // The semantic-segmentation-adas-0001 output a blob with the shape
-        // [B, C=1, H=1024, W=2048]
-        const int outHeight = 1024;
-        const int outWidth = 2048;
-        cv::Mat maskImg(outHeight, outWidth, CV_8UC3);
-        const int* const classes = detected_classes.ptr<int>();
-        for (int rowId = 0; rowId < outHeight; ++rowId) {
-            for (int colId = 0; colId < outWidth; ++colId) {
-                size_t classId = static_cast<size_t>(classes[rowId * outWidth + colId]);
-                maskImg.at<cv::Vec3b>(rowId, colId) =
-                    classId < colors.size()
-                    ? colors[classId]
-                    : cv::Vec3b{0, 0, 0}; // sample detects 20 classes
-            }
+    static void run(const cv::Mat &in, const cv::Mat &out_blob, cv::Mat &out) {
+        cv::Mat classes;
+        // NB: If output has more than single plane, it contains probabilities
+        // otherwise class id.
+        if (out_blob.size[1] > 1) {
+            probsToClasses(out_blob, classes);
+        } else {
+            out_blob.convertTo(classes, CV_8UC1);
+            classes = classes.reshape(1, out_blob.size[2]);
         }
-        cv::resize(maskImg, out, in.size());
+
+        cv::Mat mask_img;
+        classesToColors(classes, mask_img);
+
+        cv::resize(mask_img, out, in.size());
         const float blending = 0.3f;
         out = in * blending + out * (1 - blending);
     }
@@ -104,8 +147,8 @@ int main(int argc, char *argv[]) {
 
     // Now build the graph
     cv::GMat in;
-    cv::GMat detected_classes = cv::gapi::infer<SemSegmNet>(in);
-    cv::GMat out = custom::PostProcessing::on(in, detected_classes);
+    cv::GMat out_blob = cv::gapi::infer<SemSegmNet>(in);
+    cv::GMat out = custom::PostProcessing::on(in, out_blob);
 
     cv::GStreamingCompiled pipeline = cv::GComputation(cv::GIn(in), cv::GOut(out))
         .compileStreaming(cv::compile_args(kernels, networks));
-- 
2.7.4