// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2015 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: David Gallup (dgallup@google.com)
//         Sameer Agarwal (sameeragarwal@google.com)
#include "ceres/canonical_views_clustering.h"

#include <ctime>
#include <limits>
#include <vector>

#include "ceres/collections_port.h"
#include "ceres/graph.h"
#include "ceres/internal/macros.h"
#include "ceres/map_util.h"
#include "glog/logging.h"
45 typedef HashMap<int, int> IntMap;
46 typedef HashSet<int> IntSet;
48 class CanonicalViewsClustering {
50 CanonicalViewsClustering() {}
52 // Compute the canonical views clustering of the vertices of the
53 // graph. centers will contain the vertices that are the identified
54 // as the canonical views/cluster centers, and membership is a map
55 // from vertices to cluster_ids. The i^th cluster center corresponds
56 // to the i^th cluster. It is possible depending on the
57 // configuration of the clustering algorithm that some of the
58 // vertices may not be assigned to any cluster. In this case they
59 // are assigned to a cluster with id = kInvalidClusterId.
60 void ComputeClustering(const CanonicalViewsClusteringOptions& options,
61 const WeightedGraph<int>& graph,
66 void FindValidViews(IntSet* valid_views) const;
67 double ComputeClusteringQualityDifference(const int candidate,
68 const vector<int>& centers) const;
69 void UpdateCanonicalViewAssignments(const int canonical_view);
70 void ComputeClusterMembership(const vector<int>& centers,
71 IntMap* membership) const;
73 CanonicalViewsClusteringOptions options_;
74 const WeightedGraph<int>* graph_;
75 // Maps a view to its representative canonical view (its cluster
77 IntMap view_to_canonical_view_;
78 // Maps a view to its similarity to its current cluster center.
79 HashMap<int, double> view_to_canonical_view_similarity_;
80 CERES_DISALLOW_COPY_AND_ASSIGN(CanonicalViewsClustering);
83 void ComputeCanonicalViewsClustering(
84 const CanonicalViewsClusteringOptions& options,
85 const WeightedGraph<int>& graph,
88 time_t start_time = time(NULL);
89 CanonicalViewsClustering cv;
90 cv.ComputeClustering(options, graph, centers, membership);
91 VLOG(2) << "Canonical views clustering time (secs): "
92 << time(NULL) - start_time;
95 // Implementation of CanonicalViewsClustering
96 void CanonicalViewsClustering::ComputeClustering(
97 const CanonicalViewsClusteringOptions& options,
98 const WeightedGraph<int>& graph,
100 IntMap* membership) {
102 CHECK_NOTNULL(centers)->clear();
103 CHECK_NOTNULL(membership)->clear();
107 FindValidViews(&valid_views);
108 while (valid_views.size() > 0) {
109 // Find the next best canonical view.
110 double best_difference = -std::numeric_limits<double>::max();
113 // TODO(sameeragarwal): Make this loop multi-threaded.
114 for (IntSet::const_iterator view = valid_views.begin();
115 view != valid_views.end();
117 const double difference =
118 ComputeClusteringQualityDifference(*view, *centers);
119 if (difference > best_difference) {
120 best_difference = difference;
125 CHECK_GT(best_difference, -std::numeric_limits<double>::max());
127 // Add canonical view if quality improves, or if minimum is not
128 // yet met, otherwise break.
129 if ((best_difference <= 0) &&
130 (centers->size() >= options_.min_views)) {
134 centers->push_back(best_view);
135 valid_views.erase(best_view);
136 UpdateCanonicalViewAssignments(best_view);
139 ComputeClusterMembership(*centers, membership);
142 // Return the set of vertices of the graph which have valid vertex
144 void CanonicalViewsClustering::FindValidViews(
145 IntSet* valid_views) const {
146 const IntSet& views = graph_->vertices();
147 for (IntSet::const_iterator view = views.begin();
150 if (graph_->VertexWeight(*view) != WeightedGraph<int>::InvalidWeight()) {
151 valid_views->insert(*view);
156 // Computes the difference in the quality score if 'candidate' were
157 // added to the set of canonical views.
158 double CanonicalViewsClustering::ComputeClusteringQualityDifference(
160 const vector<int>& centers) const {
163 options_.view_score_weight * graph_->VertexWeight(candidate);
165 // Compute how much the quality score changes if the candidate view
166 // was added to the list of canonical views and its nearest
167 // neighbors became members of its cluster.
168 const IntSet& neighbors = graph_->Neighbors(candidate);
169 for (IntSet::const_iterator neighbor = neighbors.begin();
170 neighbor != neighbors.end();
172 const double old_similarity =
173 FindWithDefault(view_to_canonical_view_similarity_, *neighbor, 0.0);
174 const double new_similarity = graph_->EdgeWeight(*neighbor, candidate);
175 if (new_similarity > old_similarity) {
176 difference += new_similarity - old_similarity;
180 // Number of views penalty.
181 difference -= options_.size_penalty_weight;
184 for (int i = 0; i < centers.size(); ++i) {
185 difference -= options_.similarity_penalty_weight *
186 graph_->EdgeWeight(centers[i], candidate);
192 // Reassign views if they're more similar to the new canonical view.
193 void CanonicalViewsClustering::UpdateCanonicalViewAssignments(
194 const int canonical_view) {
195 const IntSet& neighbors = graph_->Neighbors(canonical_view);
196 for (IntSet::const_iterator neighbor = neighbors.begin();
197 neighbor != neighbors.end();
199 const double old_similarity =
200 FindWithDefault(view_to_canonical_view_similarity_, *neighbor, 0.0);
201 const double new_similarity =
202 graph_->EdgeWeight(*neighbor, canonical_view);
203 if (new_similarity > old_similarity) {
204 view_to_canonical_view_[*neighbor] = canonical_view;
205 view_to_canonical_view_similarity_[*neighbor] = new_similarity;
210 // Assign a cluster id to each view.
211 void CanonicalViewsClustering::ComputeClusterMembership(
212 const vector<int>& centers,
213 IntMap* membership) const {
214 CHECK_NOTNULL(membership)->clear();
216 // The i^th cluster has cluster id i.
217 IntMap center_to_cluster_id;
218 for (int i = 0; i < centers.size(); ++i) {
219 center_to_cluster_id[centers[i]] = i;
222 static const int kInvalidClusterId = -1;
224 const IntSet& views = graph_->vertices();
225 for (IntSet::const_iterator view = views.begin();
228 IntMap::const_iterator it =
229 view_to_canonical_view_.find(*view);
230 int cluster_id = kInvalidClusterId;
231 if (it != view_to_canonical_view_.end()) {
232 cluster_id = FindOrDie(center_to_cluster_id, it->second);
235 InsertOrDie(membership, *view, cluster_id);
}  // namespace internal