[PP GAPI] Split/Merge kernels; support for 8S, 16U, 16S, 32S (#2276)
author Anton Potapov <anton.potapov@intel.com>
Tue, 22 Sep 2020 15:18:26 +0000 (18:18 +0300)
committer GitHub <noreply@github.com>
Tue, 22 Sep 2020 15:18:26 +0000 (18:18 +0300)
- introduced type_dispatch primitive
 - refactored SplitX and MergeX kernels to use type_dispatch
 - extended SplitX and MergeX to support 8S, 16U, 16S, 32S types

inference-engine/src/preprocessing/ie_preprocess_gapi_kernels.cpp
inference-engine/tests_deprecated/fluid_preproc/cpu/fluid_tests_cpu.cpp

index e104dcf..153eaed 100644 (file)
@@ -34,6 +34,7 @@
 #include <type_traits>
 #include <utility>
 #include <vector>
+#include <functional>
 
 #if defined(__GNUC__) && (__GNUC__ <= 5)
 #include <cmath>
@@ -431,13 +432,93 @@ void splitRow(const uint8_t* in, std::array<uint8_t*, chs>& outs, int length) {
     }
 }
 
+namespace {  // compile-time helpers for the type_dispatch machinery defined further down
+    template<typename type>
+    struct cv_type_to_depth;  // declared only: a type without a specialization below cannot be used with it
+    // Each specialization exposes the OpenCV depth constant of one element type as `depth`.
+    template<> struct cv_type_to_depth<std::uint8_t>    { enum { depth = CV_8U  }; };
+    template<> struct cv_type_to_depth<std::int8_t>     { enum { depth = CV_8S  }; };
+    template<> struct cv_type_to_depth<std::uint16_t>   { enum { depth = CV_16U }; };
+    template<> struct cv_type_to_depth<std::int16_t>    { enum { depth = CV_16S }; };
+    template<> struct cv_type_to_depth<std::int32_t>    { enum { depth = CV_32S }; };
+    template<> struct cv_type_to_depth<float>           { enum { depth = CV_32F }; };
+    // Empty tag whose only purpose is to carry a pack of types.
+    template<typename ... types>
+    struct typelist {};
+    // head<L>::type is the first type argument of the template instantiation L.
+    template<typename type_list>
+    struct head;
+    // The partial specialization is not tied to typelist: any template whose arguments are all types matches.
+    template<template<typename ...> class list, typename head_t, typename ... types>
+    struct head<list<head_t, types...>> { using type = head_t;};
+    // Shorthand for typename head<...>::type.
+    template<typename typelist>
+    using head_t = typename head<typelist>::type;
+    // Empty tag object that lets a type be handed to a callable as an ordinary function argument.
+    template<typename type>
+    struct type_to_type {};
+}
+
+namespace {  // implementation detail of type_dispatch
+    template <typename typelist>
+    struct type_dispatch_impl;
+    // Specialization that unpacks the element types of the list into the pack `type`.
+    template <template<typename ...> class typelist, typename... type>
+    struct type_dispatch_impl<typelist<type...>> {
+        // Returns type_to_value(tag) of the list type whose type_to_id(tag) equals type_id (the last one if several match), otherwise default_value.
+        template <typename result_t, typename default_t, typename type_id_t, typename type_to_id_t, typename type_to_value_t>
+        static result_t dispatch(type_id_t type_id, type_to_id_t&& type_to_id, type_to_value_t&& type_to_value, default_t default_value) {
+            result_t res = default_value;
+            // Pack expansion inside a braced list: one comparison per list type, evaluated left to right; type_to_value is called only on a match.
+            std::initializer_list<int> ({(type_id == type_to_id(type_to_type<type>{}) ? (res = type_to_value(type_to_type<type>{})), 0 : 0)...});
+            return res;
+        }
+    };
+}
+
+template<typename typelist, typename default_t, typename type_id_t, typename type_to_id_t, typename type_to_value_t,  // the callers visible in this patch spell out only typelist; the rest is deduced
+         typename result_t = decltype(std::declval<type_to_value_t>()(type_to_type<head_t<typelist>> {}))>  // result type: what type_to_value returns for the first type of the list
+result_t type_dispatch(type_id_t type_id, type_to_id_t&& type_to_id, type_to_value_t&& type_to_value, default_t default_value = {}){  // NOTE(review): default_t is not deducible from the `= {}` default, so callers still have to pass default_value
+    return type_dispatch_impl<typelist>::template dispatch<result_t>(std::forward<type_id_t>(type_id),  // runtime id that is compared against the id of every list type
+                                                                     std::forward<type_to_id_t>(type_to_id),  // callable: type_to_type<T> tag -> id of T
+                                                                     std::forward<type_to_value_t>(type_to_value),  // callable: type_to_type<T> tag -> value to return for T
+                                                                     std::forward<default_t>(default_value));  // returned when no list type has a matching id
+}
+
+namespace {
+    struct cv_type_id {  // type_to_id functor for type_dispatch: maps a type_to_type<T> tag to the OpenCV depth constant of T
+        template <typename type>
+        int operator()(type_to_type<type> ) const { return cv_type_to_depth<type>::depth;}  // plain int: const on a by-value return is ignored; const call operator because the functor is stateless
+    };
+
+}
+template<typename typelist>  // typelist: the set of element types to test against
+bool is_cv_type_in_list(const int type_id){  // true when type_id equals the OpenCV depth constant of some type in typelist
+    return type_dispatch<typelist>(type_id, cv_type_id{}, [](...){ return true;}, false);  // any matching type yields true; `false` is the no-match default
+}
+
+namespace {
+    using merge_supported_types = typelist<uint8_t, int8_t, uint16_t, int16_t, int32_t, float>;  // element types the Merge kernels dispatch over
+    // type_to_value functor for type_dispatch: picks the mergeRow instantiation for `chs` channels.
+    template<int chs>
+    struct typed_merge_row {
+        using p_f = void (*)(const std::array<const uint8_t*, chs>& ins, uint8_t* out, int length);  // byte-pointer signature, identical for every element type
+        // Returns mergeRow<type, chs> as a plain function pointer.
+        template <typename type>
+        p_f operator()(type_to_type<type> ){ return mergeRow<type,chs>;}
+    };
+}
+
 GAPI_FLUID_KERNEL(FMerge2, Merge2, false) {
     static const int LPI = 4;
     static const int Window = 1;
     static void run(const cv::gapi::fluid::View& a,
                     const cv::gapi::fluid::View& b,
                           cv::gapi::fluid::Buffer& out) {
-        const auto rowFunc = (a.meta().depth == CV_8U) ? &mergeRow<uint8_t, 2> : &mergeRow<float, 2>;
+
+        GAPI_DbgAssert(is_cv_type_in_list<merge_supported_types>(out.meta().depth));
+
+        const auto rowFunc = type_dispatch<merge_supported_types>(out.meta().depth, cv_type_id{}, typed_merge_row<2>{}, nullptr);
         for (int l = 0; l < out.lpi(); l++) {
             rowFunc({a.InLineB(l), b.InLineB(l)}, out.OutLineB(l), a.length());
         }
@@ -451,7 +532,10 @@ GAPI_FLUID_KERNEL(FMerge3, Merge3, false) {
                     const cv::gapi::fluid::View& b,
                     const cv::gapi::fluid::View& c,
                           cv::gapi::fluid::Buffer& out) {
-        const auto rowFunc = (a.meta().depth == CV_8U) ? &mergeRow<uint8_t, 3> : &mergeRow<float, 3>;
+
+        GAPI_DbgAssert(is_cv_type_in_list<merge_supported_types>(out.meta().depth));
+
+        const auto rowFunc = type_dispatch<merge_supported_types>(out.meta().depth, cv_type_id{}, typed_merge_row<3>{}, nullptr);
         for (int l = 0; l < out.lpi(); l++) {
             rowFunc({a.InLineB(l), b.InLineB(l), c.InLineB(l)}, out.OutLineB(l), a.length());
         }
@@ -466,13 +550,29 @@ GAPI_FLUID_KERNEL(FMerge4, Merge4, false) {
                     const cv::gapi::fluid::View& c,
                     const cv::gapi::fluid::View& d,
                           cv::gapi::fluid::Buffer& out) {
-        const auto rowFunc = (a.meta().depth == CV_8U) ? &mergeRow<uint8_t, 4> : &mergeRow<float, 4>;
+
+        GAPI_DbgAssert(is_cv_type_in_list<merge_supported_types>(out.meta().depth));
+
+        const auto rowFunc = type_dispatch<merge_supported_types>(out.meta().depth, cv_type_id{}, typed_merge_row<4>{}, nullptr);
         for (int l = 0; l < out.lpi(); l++) {
             rowFunc({a.InLineB(l), b.InLineB(l), c.InLineB(l), d.InLineB(l)}, out.OutLineB(l), a.length());
         }
     }
 };
 
+
+namespace {
+using split_supported_types = typelist<uint8_t, int8_t, uint16_t, int16_t, int32_t, float>;  // element types the Split kernels dispatch over
+// type_to_value functor for type_dispatch: picks the splitRow instantiation for `chs` channels.
+template<int chs>
+struct typed_split_row {
+    using p_f = void (*)(const uint8_t* in, std::array<uint8_t*, chs>& outs, int length);  // byte-pointer signature, identical for every element type
+    // Returns splitRow<type, chs> as a plain function pointer.
+    template <typename type>
+    p_f operator()(type_to_type<type> ){ return splitRow<type,chs>;}
+};
+
+}
 GAPI_FLUID_KERNEL(FSplit2, Split2, false) {
     static const int LPI = 4;
     static const int Window = 1;
@@ -484,10 +584,9 @@ GAPI_FLUID_KERNEL(FSplit2, Split2, false) {
         GAPI_DbgAssert(1 == out2.meta().chan);
         GAPI_DbgAssert(in.meta().depth == out1.meta().depth);
         GAPI_DbgAssert(in.meta().depth == out2.meta().depth);
-        GAPI_DbgAssert(CV_8U == in.meta().depth || CV_32F == in.meta().depth);
-        const auto rowFunc = (in.meta().depth == CV_8U) ?
-                             &splitRow<uint8_t, 2> :
-                             &splitRow<float  , 2>;
+        GAPI_DbgAssert(is_cv_type_in_list<split_supported_types>(in.meta().depth));
+
+        const auto rowFunc = type_dispatch<split_supported_types>(in.meta().depth, cv_type_id{}, typed_split_row<2>{}, nullptr);
         for (int i = 0, lpi = out1.lpi(); i < lpi; i++) {
             std::array<uint8_t*, 2> outs = {out1.OutLineB(i), out2.OutLineB(i)};
             rowFunc(in.InLineB(i), outs, in.length());
@@ -509,10 +608,10 @@ GAPI_FLUID_KERNEL(FSplit3, Split3, false) {
         GAPI_DbgAssert(in.meta().depth == out1.meta().depth);
         GAPI_DbgAssert(in.meta().depth == out2.meta().depth);
         GAPI_DbgAssert(in.meta().depth == out3.meta().depth);
-        GAPI_DbgAssert(CV_8U == in.meta().depth || CV_32F == in.meta().depth);
-        const auto rowFunc = (in.meta().depth == CV_8U) ?
-                             &splitRow<uint8_t, 3> :
-                             &splitRow<float  , 3>;
+
+        GAPI_DbgAssert(is_cv_type_in_list<split_supported_types>(in.meta().depth));
+
+        const auto rowFunc = type_dispatch<split_supported_types>(in.meta().depth, cv_type_id{}, typed_split_row<3>{}, nullptr);
         for (int i = 0, lpi = out1.lpi(); i < lpi; i++) {
             std::array<uint8_t*, 3> outs = {out1.OutLineB(i), out2.OutLineB(i),
                                             out3.OutLineB(i)};
@@ -538,10 +637,9 @@ GAPI_FLUID_KERNEL(FSplit4, Split4, false) {
         GAPI_DbgAssert(in.meta().depth == out2.meta().depth);
         GAPI_DbgAssert(in.meta().depth == out3.meta().depth);
         GAPI_DbgAssert(in.meta().depth == out4.meta().depth);
-        GAPI_DbgAssert(CV_8U == in.meta().depth || CV_32F == in.meta().depth);
-        const auto rowFunc = (in.meta().depth == CV_8U) ?
-                             &splitRow<uint8_t, 4> :
-                             &splitRow<float  , 4>;
+        GAPI_DbgAssert(is_cv_type_in_list<split_supported_types>(in.meta().depth));
+
+        const auto rowFunc = type_dispatch<split_supported_types>(in.meta().depth, cv_type_id{}, typed_split_row<4>{}, nullptr);
         for (int i = 0, lpi = out1.lpi(); i < lpi; i++) {
             std::array<uint8_t*, 4> outs = {out1.OutLineB(i), out2.OutLineB(i),
                                             out3.OutLineB(i), out4.OutLineB(i)};
index acdac01..6164bb3 100644 (file)
@@ -132,7 +132,7 @@ INSTANTIATE_TEST_CASE_P(ResizeTestFluid_F32, ResizeTestGAPI,
 
 INSTANTIATE_TEST_CASE_P(SplitTestFluid, SplitTestGAPI,
                         Combine(Values(2, 3, 4),
-                                Values(CV_8U, CV_32F),
+                                Values(CV_8U, CV_8S, CV_16U, CV_16S, CV_32F, CV_32S),
                                 Values(TEST_SIZES),
                                 Values(0)));
 
@@ -144,7 +144,7 @@ INSTANTIATE_TEST_CASE_P(ChanToPlaneTestFluid, ChanToPlaneTestGAPI,
 
 INSTANTIATE_TEST_CASE_P(MergeTestFluid, MergeTestGAPI,
                         Combine(Values(2, 3, 4),
-                                Values(CV_8U, CV_32F),
+                                Values(CV_8U, CV_8S, CV_16U, CV_16S, CV_32F, CV_32S),
                                 Values(TEST_SIZES),
                                 Values(0)));