Update Torch ENet sample
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 24 Nov 2017 18:20:18 +0000 (21:20 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 24 Nov 2017 18:20:18 +0000 (21:20 +0300)
samples/dnn/torch_enet.cpp

index 4f9ad21..6101d17 100644 (file)
@@ -20,21 +20,35 @@ const String keys =
                        "https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa }"
         "{model m   || path to Torch .net model file (model_best.net) }"
         "{image i   || path to image file }"
-        "{c_names c || path to file with classnames for channels (optional, categories.txt) }"
         "{result r  || path to save output blob (optional, binary format, NCHW order) }"
         "{show s    || whether to show all output channels or not}"
-        "{o_blob    || output blob's name. If empty, last blob's name in net is used}"
-        ;
+        "{o_blob    || output blob's name. If empty, last blob's name in net is used}";
 
-static void colorizeSegmentation(const Mat &score, Mat &segm,
-                                 Mat &legend, vector<String> &classNames, vector<Vec3b> &colors);
-static vector<Vec3b> readColors(const String &filename, vector<String>& classNames);
+static const int kNumClasses = 20;
+
+static const String classes[] = {
+    "Background", "Road", "Sidewalk", "Building", "Wall", "Fence", "Pole",
+    "TrafficLight", "TrafficSign", "Vegetation", "Terrain", "Sky", "Person",
+    "Rider", "Car", "Truck", "Bus", "Train", "Motorcycle", "Bicycle"
+};
+
+static const Vec3b colors[] = {
+    Vec3b(0, 0, 0), Vec3b(244, 126, 205), Vec3b(254, 83, 132), Vec3b(192, 200, 189),
+    Vec3b(50, 56, 251), Vec3b(65, 199, 228), Vec3b(240, 178, 193), Vec3b(201, 67, 188),
+    Vec3b(85, 32, 33), Vec3b(116, 25, 18), Vec3b(162, 33, 72), Vec3b(101, 150, 210),
+    Vec3b(237, 19, 16), Vec3b(149, 197, 72), Vec3b(80, 182, 21), Vec3b(141, 5, 207),
+    Vec3b(189, 156, 39), Vec3b(235, 170, 186), Vec3b(133, 109, 144), Vec3b(231, 160, 96)
+};
+
+static void showLegend();
+
+static void colorizeSegmentation(const Mat &score, Mat &segm);
 
 int main(int argc, char **argv)
 {
     CommandLineParser parser(argc, argv, keys);
 
-    if (parser.has("help"))
+    if (parser.has("help") || argc == 1)
     {
         parser.printMessage();
         return 0;
@@ -49,7 +63,6 @@ int main(int argc, char **argv)
         return 0;
     }
 
-    String classNamesFile = parser.get<String>("c_names");
     String resultFile = parser.get<String>("result");
 
     //! [Read model and initialize network]
@@ -63,17 +76,11 @@ int main(int argc, char **argv)
         exit(-1);
     }
 
-    Size origSize = img.size();
-    Size inputImgSize = cv::Size(1024, 512);
-
-    if (inputImgSize != origSize)
-        resize(img, img, inputImgSize);       //Resize image to input size
-
-    Mat inputBlob = blobFromImage(img, 1./255);   //Convert Mat to image batch
+    Mat inputBlob = blobFromImage(img, 1./255, Size(1024, 512), Scalar(), true, false);   //Convert Mat to image batch
     //! [Prepare blob]
 
     //! [Set input blob]
-    net.setInput(inputBlob, "");        //set the network input
+    net.setInput(inputBlob);        //set the network input
     //! [Set input blob]
 
     TickMeter tm;
@@ -102,41 +109,47 @@ int main(int argc, char **argv)
 
     if (parser.has("show"))
     {
-        std::vector<String> classNames;
-        vector<cv::Vec3b> colors;
-        if(!classNamesFile.empty()) {
-            colors = readColors(classNamesFile, classNames);
-        }
-        Mat segm, legend;
-        colorizeSegmentation(result, segm, legend, classNames, colors);
+        Mat segm, show;
+        colorizeSegmentation(result, segm);
+        showLegend();
 
-        Mat show;
+        cv::resize(segm, segm, img.size(), 0, 0, cv::INTER_NEAREST);
         addWeighted(img, 0.1, segm, 0.9, 0.0, show);
 
-        cv::resize(show, show, origSize, 0, 0, cv::INTER_NEAREST);
         imshow("Result", show);
-        if(classNames.size())
-            imshow("Legend", legend);
         waitKey();
     }
-
     return 0;
 } //main
 
-static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vector<String> &classNames, vector<Vec3b> &colors)
+static void showLegend()
+{
+    static const int kBlockHeight = 30;
+
+    cv::Mat legend(kBlockHeight * kNumClasses, 200, CV_8UC3);
+    for(int i = 0; i < kNumClasses; i++)
+    {
+        cv::Mat block = legend.rowRange(i * kBlockHeight, (i + 1) * kBlockHeight);
+        block.setTo(colors[i]);
+        putText(block, classes[i], Point(0, kBlockHeight / 2), FONT_HERSHEY_SIMPLEX, 0.5, Vec3b(255, 255, 255));
+    }
+    imshow("Legend", legend);
+}
+
+static void colorizeSegmentation(const Mat &score, Mat &segm)
 {
     const int rows = score.size[2];
     const int cols = score.size[3];
     const int chns = score.size[1];
 
-    cv::Mat maxCl(rows, cols, CV_8UC1);
-    cv::Mat maxVal(rows, cols, CV_32FC1);
-    for (int ch = 0; ch < chns; ch++)
+    Mat maxCl = Mat::zeros(rows, cols, CV_8UC1);
+    Mat maxVal(rows, cols, CV_32FC1, score.data);
+    for (int ch = 1; ch < chns; ch++)
     {
         for (int row = 0; row < rows; row++)
         {
             const float *ptrScore = score.ptr<float>(0, ch, row);
-            uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
+            uint8_t *ptrMaxCl = maxCl.ptr<uint8_t>(row);
             float *ptrMaxVal = maxVal.ptr<float>(row);
             for (int col = 0; col < cols; col++)
             {
@@ -153,57 +166,10 @@ static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vecto
     for (int row = 0; row < rows; row++)
     {
         const uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
-        cv::Vec3b *ptrSegm = segm.ptr<cv::Vec3b>(row);
+        Vec3b *ptrSegm = segm.ptr<Vec3b>(row);
         for (int col = 0; col < cols; col++)
         {
             ptrSegm[col] = colors[ptrMaxCl[col]];
         }
     }
-
-    if (classNames.size() == colors.size())
-    {
-        int blockHeight = 30;
-        legend.create(blockHeight*(int)classNames.size(), 200, CV_8UC3);
-        for(int i = 0; i < (int)classNames.size(); i++)
-        {
-            cv::Mat block = legend.rowRange(i*blockHeight, (i+1)*blockHeight);
-            block = colors[i];
-            putText(block, classNames[i], Point(0, blockHeight/2), FONT_HERSHEY_SIMPLEX, 0.5, Scalar());
-        }
-    }
-}
-
-static vector<Vec3b> readColors(const String &filename, vector<String>& classNames)
-{
-    vector<cv::Vec3b> colors;
-    classNames.clear();
-
-    ifstream fp(filename.c_str());
-    if (!fp.is_open())
-    {
-        cerr << "File with colors not found: " << filename << endl;
-        exit(-1);
-    }
-
-    string line;
-    while (!fp.eof())
-    {
-        getline(fp, line);
-        if (line.length())
-        {
-            stringstream ss(line);
-
-            string name; ss >> name;
-            int temp;
-            cv::Vec3b color;
-            ss >> temp; color[0] = (uchar)temp;
-            ss >> temp; color[1] = (uchar)temp;
-            ss >> temp; color[2] = (uchar)temp;
-            classNames.push_back(name);
-            colors.push_back(color);
-        }
-    }
-
-    fp.close();
-    return colors;
 }