Exposed HierarchicalClusteringIndex in OpenCV wrapper
authorMarius Muja <mariusm@cs.ubc.ca>
Thu, 27 Sep 2012 10:58:17 +0000 (03:58 -0700)
committerMarius Muja <mariusm@cs.ubc.ca>
Thu, 27 Sep 2012 10:58:17 +0000 (03:58 -0700)
modules/flann/include/opencv2/flann/defines.h
modules/flann/include/opencv2/flann/hierarchical_clustering_index.h
modules/flann/include/opencv2/flann/miniflann.hpp
modules/flann/src/miniflann.cpp

index 178f07b..13833b3 100644 (file)
@@ -137,6 +137,7 @@ enum flann_distance_t
     FLANN_DIST_CS         = 7,
     FLANN_DIST_KULLBACK_LEIBLER  = 8,
     FLANN_DIST_KL                = 8,
+    FLANN_DIST_HAMMING          = 9,
 
     // deprecated constants, should use the FLANN_DIST_* ones instead
     EUCLIDEAN = 1,
index 3b61bd2..a7c6e1b 100644 (file)
@@ -619,13 +619,13 @@ private:
             if (checks>=maxChecks) {
                 if (result.full()) return;
             }
-            checks += node->size;
             for (int i=0; i<node->size; ++i) {
                 int index = node->indices[i];
                 if (!checked[index]) {
                     DistanceType dist = distance(dataset[index], vec, veclen_);
                     result.addPoint(dist, index);
                     checked[index] = true;
+                    ++checks;
                 }
             }
         }
index d7fd90f..04249bf 100644 (file)
@@ -100,6 +100,12 @@ struct CV_EXPORTS AutotunedIndexParams : public IndexParams
                          float memory_weight = 0, float sample_fraction = 0.1);
 };
     
+struct CV_EXPORTS HierarchicalClusteringIndexParams : public IndexParams
+{
+    HierarchicalClusteringIndexParams(int branching = 32, 
+                      cvflann::flann_centers_init_t centers_init = cvflann::FLANN_CENTERS_RANDOM, int trees = 4, int leaf_size = 100 );
+};
+
 struct CV_EXPORTS KMeansIndexParams : public IndexParams
 {
     KMeansIndexParams(int branching = 32, int iterations = 11,
index e1dbe5b..972ae72 100644 (file)
@@ -256,17 +256,33 @@ KMeansIndexParams::KMeansIndexParams(int branching, int iterations,
     // cluster boundary index. Used when searching the kmeans tree
     p["cb_index"] = cb_index;
 }
+
+HierarchicalClusteringIndexParams::HierarchicalClusteringIndexParams(int branching ,
+                                      flann_centers_init_t centers_init,
+                                      int trees, int leaf_size)
+{
+    ::cvflann::IndexParams& p = get_params(*this);
+    p["algorithm"] = FLANN_INDEX_HIERARCHICAL;
+    // The branching factor used in the hierarchical clustering
+    p["branching"] = branching;
+    // Algorithm used for picking the initial cluster centers
+    p["centers_init"] = centers_init;
+    // number of parallel trees to build
+    p["trees"] = trees;
+    // maximum leaf size
+    p["leaf_size"] = leaf_size;
+}
     
 LshIndexParams::LshIndexParams(int table_number, int key_size, int multi_probe_level)
 {
     ::cvflann::IndexParams& p = get_params(*this);
     p["algorithm"] = FLANN_INDEX_LSH;
     // The number of hash tables to use
-    p["table_number"] = (unsigned)table_number;
+    p["table_number"] = table_number;
     // The length of the key in the hash tables
-    p["key_size"] = (unsigned)key_size;
+    p["key_size"] = key_size;
     // Number of levels to use in multi-probe (0 for standard LSH)
-    p["multi_probe_level"] = (unsigned)multi_probe_level;
+    p["multi_probe_level"] = multi_probe_level;
 }    
     
 SavedIndexParams::SavedIndexParams(const std::string& _filename)
@@ -317,7 +333,6 @@ typedef ::cvflann::Hamming<uchar> HammingDistance;
 #else
 typedef ::cvflann::HammingLUT HammingDistance;
 #endif
-typedef ::cvflann::LshIndex<HammingDistance> LshIndex;
 
 Index::Index()
 {
@@ -351,14 +366,11 @@ void Index::build(InputArray _data, const IndexParams& params, flann_distance_t
     featureType = data.type();
     distType = _distType;
 
-    if( algo == FLANN_INDEX_LSH )
-    {
-        buildIndex_<HammingDistance, LshIndex>(index, data, params);
-        return;
-    }
-    
     switch( distType )
     {
+    case FLANN_DIST_HAMMING:
+        buildIndex< HammingDistance >(index, data, params);
+        break;
     case FLANN_DIST_L2:
         buildIndex< ::cvflann::L2<float> >(index, data, params);
         break;
@@ -406,15 +418,12 @@ void Index::release()
 {
     if( !index )
         return;
-    if( algo == FLANN_INDEX_LSH )
-    {
-        deleteIndex_<LshIndex>(index);
-    }
-    else
+        
+    switch( distType )
     {
-        CV_Assert( featureType == CV_32F );
-        switch( distType )
-        {
+        case FLANN_DIST_HAMMING:
+            deleteIndex< HammingDistance >(index);
+            break;
         case FLANN_DIST_L2:
             deleteIndex< ::cvflann::L2<float> >(index);
             break;
@@ -440,7 +449,6 @@ void Index::release()
 #endif
         default:
             CV_Error(CV_StsBadArg, "Unknown/unsupported distance type");
-        }
     }
     index = 0;
 }
@@ -539,18 +547,15 @@ void Index::knnSearch(InputArray _query, OutputArray _indices,
                OutputArray _dists, int knn, const SearchParams& params)
 {
     Mat query = _query.getMat(), indices, dists;
-    int dtype = algo == FLANN_INDEX_LSH ? CV_32S : CV_32F;
+    int dtype = distType == FLANN_DIST_HAMMING ? CV_32S : CV_32F;
     
     createIndicesDists( _indices, _dists, indices, dists, query.rows, knn, knn, dtype );
     
-    if( algo == FLANN_INDEX_LSH )
-    {
-        runKnnSearch_<HammingDistance, LshIndex>(index, query, indices, dists, knn, params);
-        return;
-    }
-    
     switch( distType )
     {
+    case FLANN_DIST_HAMMING:
+        runKnnSearch<HammingDistance>(index, query, indices, dists, knn, params);
+        break;
     case FLANN_DIST_L2:
         runKnnSearch< ::cvflann::L2<float> >(index, query, indices, dists, knn, params);
         break;
@@ -584,7 +589,7 @@ int Index::radiusSearch(InputArray _query, OutputArray _indices,
                         const SearchParams& params)
 {
     Mat query = _query.getMat(), indices, dists;
-    int dtype = algo == FLANN_INDEX_LSH ? CV_32S : CV_32F;
+    int dtype = distType == FLANN_DIST_HAMMING ? CV_32S : CV_32F;
     CV_Assert( maxResults > 0 );
     createIndicesDists( _indices, _dists, indices, dists, query.rows, maxResults, INT_MAX, dtype );
     
@@ -593,6 +598,9 @@ int Index::radiusSearch(InputArray _query, OutputArray _indices,
     
     switch( distType )
     {
+    case FLANN_DIST_HAMMING:
+        return runRadiusSearch< HammingDistance >(index, query, indices, dists, radius, params);
+
     case FLANN_DIST_L2:
         return runRadiusSearch< ::cvflann::L2<float> >(index, query, indices, dists, radius, params);
     case FLANN_DIST_L1:
@@ -647,15 +655,11 @@ void Index::save(const std::string& filename) const
     if (fout == NULL)
         CV_Error_( CV_StsError, ("Can not open file %s for writing FLANN index\n", filename.c_str()) );
     
-    if( algo == FLANN_INDEX_LSH )
-    {
-        saveIndex_<LshIndex>(this, index, fout);
-        fclose(fout);
-        return;
-    }
-    
     switch( distType )
     {
+    case FLANN_DIST_HAMMING:
+        saveIndex< HammingDistance >(this, index, fout);
+        break;
     case FLANN_DIST_L2:
         saveIndex< ::cvflann::L2<float> >(this, index, fout);
         break;
@@ -739,54 +743,51 @@ bool Index::load(InputArray _data, const std::string& filename)
         return false;
     }
     
-    if( !((algo == FLANN_INDEX_LSH && featureType == CV_8U) ||
-          (algo != FLANN_INDEX_LSH && featureType == CV_32F)) )
+    int idistType = 0;
+    ::cvflann::load_value(fin, idistType);
+    distType = (flann_distance_t)idistType;
+
+    if( !((distType == FLANN_DIST_HAMMING && featureType == CV_8U) || 
+          (distType != FLANN_DIST_HAMMING && featureType == CV_32F)) )
     {
         fprintf(stderr, "Reading FLANN index error: unsupported feature type %d for the index type %d\n", featureType, algo);
         fclose(fin);
         return false;
     }
-    int idistType = 0;
-    ::cvflann::load_value(fin, idistType);
-    distType = (flann_distance_t)idistType;
     
-    if( algo == FLANN_INDEX_LSH )
-    {
-        loadIndex_<HammingDistance, LshIndex>(this, index, data, fin);
-    }
-    else
+    switch( distType )
     {
-        switch( distType )
-        {
-        case FLANN_DIST_L2:
-            loadIndex< ::cvflann::L2<float> >(this, index, data, fin);
-            break;
-        case FLANN_DIST_L1:
-            loadIndex< ::cvflann::L1<float> >(this, index, data, fin);
-            break;
-    #if MINIFLANN_SUPPORT_EXOTIC_DISTANCE_TYPES
-        case FLANN_DIST_MAX:
-            loadIndex< ::cvflann::MaxDistance<float> >(this, index, data, fin);
-            break;
-        case FLANN_DIST_HIST_INTERSECT:
-            loadIndex< ::cvflann::HistIntersectionDistance<float> >(index, data, fin);
-            break;
-        case FLANN_DIST_HELLINGER:
-            loadIndex< ::cvflann::HellingerDistance<float> >(this, index, data, fin);
-            break;
-        case FLANN_DIST_CHI_SQUARE:
-            loadIndex< ::cvflann::ChiSquareDistance<float> >(this, index, data, fin);
-            break;
-        case FLANN_DIST_KL:
-            loadIndex< ::cvflann::KL_Divergence<float> >(this, index, data, fin);
-            break;
-    #endif
-        default:
-            fprintf(stderr, "Reading FLANN index error: unsupported distance type %d\n", distType);
-            ok = false;
-        }
+    case FLANN_DIST_HAMMING:
+        loadIndex< HammingDistance >(this, index, data, fin);
+        break;
+    case FLANN_DIST_L2:
+        loadIndex< ::cvflann::L2<float> >(this, index, data, fin);
+        break;
+    case FLANN_DIST_L1:
+        loadIndex< ::cvflann::L1<float> >(this, index, data, fin);
+        break;
+#if MINIFLANN_SUPPORT_EXOTIC_DISTANCE_TYPES
+    case FLANN_DIST_MAX:
+        loadIndex< ::cvflann::MaxDistance<float> >(this, index, data, fin);
+        break;
+    case FLANN_DIST_HIST_INTERSECT:
+        loadIndex< ::cvflann::HistIntersectionDistance<float> >(index, data, fin);
+        break;
+    case FLANN_DIST_HELLINGER:
+        loadIndex< ::cvflann::HellingerDistance<float> >(this, index, data, fin);
+        break;
+    case FLANN_DIST_CHI_SQUARE:
+        loadIndex< ::cvflann::ChiSquareDistance<float> >(this, index, data, fin);
+        break;
+    case FLANN_DIST_KL:
+        loadIndex< ::cvflann::KL_Divergence<float> >(this, index, data, fin);
+        break;
+#endif
+    default:
+        fprintf(stderr, "Reading FLANN index error: unsupported distance type %d\n", distType);
+        ok = false;
     }
-    
+
     if( fin )
         fclose(fin);
     return ok;