added gpu::graphcut for float sources (CUDA 5.0)
author Vladislav Vinogradov <vlad.vinogradov@itseez.com>
Tue, 31 Jul 2012 06:45:40 +0000 (10:45 +0400)
committer Vladislav Vinogradov <vlad.vinogradov@itseez.com>
Tue, 31 Jul 2012 08:46:04 +0000 (12:46 +0400)
modules/gpu/src/graphcuts.cpp

index aba9ee3..0546ce3 100644 (file)
@@ -71,24 +71,32 @@ namespace
             return pState;\r
         }\r
 \r
-    private:        \r
+    private:\r
         NppiGraphcutState* pState;\r
     };\r
 }\r
 \r
 void cv::gpu::graphcut(GpuMat& terminals, GpuMat& leftTransp, GpuMat& rightTransp, GpuMat& top, GpuMat& bottom, GpuMat& labels, GpuMat& buf, Stream& s)\r
 {\r
+#if (CUDA_VERSION < 5000)\r
+    CV_Assert(terminals.type() == CV_32S);\r
+#else\r
+    CV_Assert(terminals.type() == CV_32S || terminals.type() == CV_32F);\r
+#endif\r
+\r
     Size src_size = terminals.size();\r
 \r
-    CV_Assert(terminals.type() == CV_32S);\r
     CV_Assert(leftTransp.size() == Size(src_size.height, src_size.width));\r
-    CV_Assert(leftTransp.type() == CV_32S);\r
+    CV_Assert(leftTransp.type() == terminals.type());\r
+\r
     CV_Assert(rightTransp.size() == Size(src_size.height, src_size.width));\r
-    CV_Assert(rightTransp.type() == CV_32S);\r
+    CV_Assert(rightTransp.type() == terminals.type());\r
+\r
     CV_Assert(top.size() == src_size);\r
-    CV_Assert(top.type() == CV_32S);\r
+    CV_Assert(top.type() == terminals.type());\r
+\r
     CV_Assert(bottom.size() == src_size);\r
-    CV_Assert(bottom.type() == CV_32S);\r
+    CV_Assert(bottom.type() == terminals.type());\r
 \r
     labels.create(src_size, CV_8U);\r
 \r
@@ -106,44 +114,61 @@ void cv::gpu::graphcut(GpuMat& terminals, GpuMat& leftTransp, GpuMat& rightTrans
     NppStreamHandler h(stream);\r
 \r
     NppiGraphcutStateHandler state(sznpp, buf.ptr<Npp8u>(), nppiGraphcutInitAlloc);\r
-    \r
+\r
+#if (CUDA_VERSION < 5000)\r
     nppSafeCall( nppiGraphcut_32s8u(terminals.ptr<Npp32s>(), leftTransp.ptr<Npp32s>(), rightTransp.ptr<Npp32s>(), top.ptr<Npp32s>(), bottom.ptr<Npp32s>(),\r
         static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) );\r
+#else\r
+    if (terminals.type() == CV_32S)\r
+    {\r
+        nppSafeCall( nppiGraphcut_32s8u(terminals.ptr<Npp32s>(), leftTransp.ptr<Npp32s>(), rightTransp.ptr<Npp32s>(), top.ptr<Npp32s>(), bottom.ptr<Npp32s>(),\r
+            static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) );\r
+    }\r
+    else\r
+    {\r
+        nppSafeCall( nppiGraphcut_32f8u(terminals.ptr<Npp32f>(), leftTransp.ptr<Npp32f>(), rightTransp.ptr<Npp32f>(), top.ptr<Npp32f>(), bottom.ptr<Npp32f>(),\r
+            static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) );\r
+    }\r
+#endif\r
 \r
     if (stream == 0)\r
         cudaSafeCall( cudaDeviceSynchronize() );\r
 }\r
 \r
-void cv::gpu::graphcut(GpuMat& terminals, GpuMat& leftTransp, GpuMat& rightTransp, GpuMat& top, GpuMat& topLeft, GpuMat& topRight, \r
+void cv::gpu::graphcut(GpuMat& terminals, GpuMat& leftTransp, GpuMat& rightTransp, GpuMat& top, GpuMat& topLeft, GpuMat& topRight,\r
               GpuMat& bottom, GpuMat& bottomLeft, GpuMat& bottomRight, GpuMat& labels, GpuMat& buf, Stream& s)\r
 {\r
-    Size src_size = terminals.size();\r
-\r
+#if (CUDA_VERSION < 5000)\r
     CV_Assert(terminals.type() == CV_32S);\r
+#else\r
+    CV_Assert(terminals.type() == CV_32S || terminals.type() == CV_32F);\r
+#endif\r
+\r
+    Size src_size = terminals.size();\r
 \r
     CV_Assert(leftTransp.size() == Size(src_size.height, src_size.width));\r
-    CV_Assert(leftTransp.type() == CV_32S);\r
+    CV_Assert(leftTransp.type() == terminals.type());\r
 \r
     CV_Assert(rightTransp.size() == Size(src_size.height, src_size.width));\r
-    CV_Assert(rightTransp.type() == CV_32S);\r
+    CV_Assert(rightTransp.type() == terminals.type());\r
 \r
     CV_Assert(top.size() == src_size);\r
-    CV_Assert(top.type() == CV_32S);\r
+    CV_Assert(top.type() == terminals.type());\r
 \r
     CV_Assert(topLeft.size() == src_size);\r
-    CV_Assert(topLeft.type() == CV_32S);\r
+    CV_Assert(topLeft.type() == terminals.type());\r
 \r
     CV_Assert(topRight.size() == src_size);\r
-    CV_Assert(topRight.type() == CV_32S);\r
+    CV_Assert(topRight.type() == terminals.type());\r
 \r
     CV_Assert(bottom.size() == src_size);\r
-    CV_Assert(bottom.type() == CV_32S);\r
+    CV_Assert(bottom.type() == terminals.type());\r
 \r
     CV_Assert(bottomLeft.size() == src_size);\r
-    CV_Assert(bottomLeft.type() == CV_32S);\r
+    CV_Assert(bottomLeft.type() == terminals.type());\r
 \r
     CV_Assert(bottomRight.size() == src_size);\r
-    CV_Assert(bottomRight.type() == CV_32S);\r
+    CV_Assert(bottomRight.type() == terminals.type());\r
 \r
     labels.create(src_size, CV_8U);\r
 \r
@@ -161,11 +186,28 @@ void cv::gpu::graphcut(GpuMat& terminals, GpuMat& leftTransp, GpuMat& rightTrans
     NppStreamHandler h(stream);\r
 \r
     NppiGraphcutStateHandler state(sznpp, buf.ptr<Npp8u>(), nppiGraphcut8InitAlloc);\r
-    \r
-    nppSafeCall( nppiGraphcut8_32s8u(terminals.ptr<Npp32s>(), leftTransp.ptr<Npp32s>(), rightTransp.ptr<Npp32s>(), \r
+\r
+#if (CUDA_VERSION < 5000)\r
+    nppSafeCall( nppiGraphcut8_32s8u(terminals.ptr<Npp32s>(), leftTransp.ptr<Npp32s>(), rightTransp.ptr<Npp32s>(),\r
         top.ptr<Npp32s>(), topLeft.ptr<Npp32s>(), topRight.ptr<Npp32s>(),\r
         bottom.ptr<Npp32s>(), bottomLeft.ptr<Npp32s>(), bottomRight.ptr<Npp32s>(),\r
         static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) );\r
+#else\r
+    if (terminals.type() == CV_32S)\r
+    {\r
+        nppSafeCall( nppiGraphcut8_32s8u(terminals.ptr<Npp32s>(), leftTransp.ptr<Npp32s>(), rightTransp.ptr<Npp32s>(),\r
+            top.ptr<Npp32s>(), topLeft.ptr<Npp32s>(), topRight.ptr<Npp32s>(),\r
+            bottom.ptr<Npp32s>(), bottomLeft.ptr<Npp32s>(), bottomRight.ptr<Npp32s>(),\r
+            static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) );\r
+    }\r
+    else\r
+    {\r
+        nppSafeCall( nppiGraphcut8_32f8u(terminals.ptr<Npp32f>(), leftTransp.ptr<Npp32f>(), rightTransp.ptr<Npp32f>(),\r
+            top.ptr<Npp32f>(), topLeft.ptr<Npp32f>(), topRight.ptr<Npp32f>(),\r
+            bottom.ptr<Npp32f>(), bottomLeft.ptr<Npp32f>(), bottomRight.ptr<Npp32f>(),\r
+            static_cast<int>(terminals.step), static_cast<int>(leftTransp.step), sznpp, labels.ptr<Npp8u>(), static_cast<int>(labels.step), state) );\r
+    }\r
+#endif\r
 \r
     if (stream == 0)\r
         cudaSafeCall( cudaDeviceSynchronize() );\r