Fix emulator build error
[platform/framework/web/chromium-efl.git] / components / browsing_topics / annotator_impl.cc
1 // Copyright 2023 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "components/browsing_topics/annotator_impl.h"
6
7 #include "base/barrier_closure.h"
8 #include "base/containers/contains.h"
9 #include "base/dcheck_is_on.h"
10 #include "base/files/file_util.h"
11 #include "base/ranges/algorithm.h"
12 #include "base/strings/string_number_conversions.h"
13 #include "base/strings/string_util.h"
14 #include "base/task/sequenced_task_runner.h"
15 #include "components/optimization_guide/core/optimization_guide_model_provider.h"
16 #include "components/optimization_guide/proto/models.pb.h"
17 #include "components/optimization_guide/proto/page_topics_model_metadata.pb.h"
18 #include "components/optimization_guide/proto/page_topics_override_list.pb.h"
19 #include "third_party/abseil-cpp/absl/strings/ascii.h"
20 #include "third_party/blink/public/common/features.h"
21 #include "third_party/zlib/google/compression_utils.h"
22
23 namespace browsing_topics {
24
25 namespace {
26
// Taxonomy ID of the NONE category; this node always exists. Semantically,
// NONE is attached to data for which it is certain that no single label in
// the taxonomy is appropriate.
constexpr int32_t kNoneCategoryId = -2;

// Minimum model version that supports meaningless prefix v2. Compare this
// against the version provided in the model metadata.
constexpr int32_t kMeaninglessPrefixV2MinVersion = 2;
35
36 const base::FilePath::CharType kOverrideListBasePath[] =
37     FILE_PATH_LITERAL("override_list.pb.gz");
38
// Outcome of an attempt to load the override list file. These values are
// logged to UMA histograms, so do not change or reorder them. Keep
// |BrowsingTopicsOverrideListFileLoadResult| in
// //tools/metrics/histograms/enums.xml in sync with any addition.
enum class OverrideListFileLoadResult {
  kUnknown = 0,

  // The file was read, uncompressed and parsed.
  kSuccess = 1,

  // Reading the file from disk failed.
  kCouldNotReadFile = 2,

  // The file contents could not be gzip-uncompressed.
  kCouldNotUncompressFile = 3,

  // The uncompressed bytes did not parse as the override list proto.
  kCouldNotUnmarshalProtobuf = 4,

  // Must always alias the largest value above.
  kMaxValue = kCouldNotUnmarshalProtobuf,
};
51
52 void RecordOverrideListFileLoadResult(OverrideListFileLoadResult result) {
53   base::UmaHistogramEnumeration("BrowsingTopics.OverrideList.FileLoadResult",
54                                 result);
55 }
56
57 absl::optional<std::unordered_map<std::string, std::vector<int32_t>>>
58 LoadOverrideListFromFile(const base::FilePath& path) {
59   if (!path.IsAbsolute() ||
60       path.BaseName() != base::FilePath(kOverrideListBasePath)) {
61     NOTREACHED();
62     // This is enforced by calling code, so no UMA in this case.
63     return absl::nullopt;
64   }
65
66   std::string file_contents;
67   if (!base::ReadFileToString(path, &file_contents)) {
68     RecordOverrideListFileLoadResult(
69         OverrideListFileLoadResult::kCouldNotReadFile);
70     return absl::nullopt;
71   }
72
73   if (!compression::GzipUncompress(file_contents, &file_contents)) {
74     RecordOverrideListFileLoadResult(
75         OverrideListFileLoadResult::kCouldNotUncompressFile);
76     return absl::nullopt;
77   }
78
79   optimization_guide::proto::PageTopicsOverrideList override_list_pb;
80   if (!override_list_pb.ParseFromString(file_contents)) {
81     RecordOverrideListFileLoadResult(
82         OverrideListFileLoadResult::kCouldNotUnmarshalProtobuf);
83     return absl::nullopt;
84   }
85
86   std::unordered_map<std::string, std::vector<int32_t>> override_list;
87   for (const optimization_guide::proto::PageTopicsOverrideEntry& entry :
88        override_list_pb.entries()) {
89     override_list.emplace(
90         entry.domain(), std::vector<int32_t>{entry.topics().topic_ids().begin(),
91                                              entry.topics().topic_ids().end()});
92   }
93
94   RecordOverrideListFileLoadResult(OverrideListFileLoadResult::kSuccess);
95   return override_list;
96 }
97
98 // Returns the length of the leading meaningless prefix of a host name as
99 // defined for the Topics Model.
100 //
101 // The full list of meaningless prefixes are:
102 //   ^(www[0-9]*|web|ftp|wap|home)$
103 //   ^(m|mobile|amp|w)$
104 int MeaninglessPrefixLength(const std::string& host) {
105   size_t len = host.size();
106
107   int dots = base::ranges::count(host, '.');
108   if (dots < 2) {
109     return 0;
110   }
111
112   if (len > 4 && base::StartsWith(host, "www")) {
113     // Check that all characters after "www" and up to first "." are
114     // digits.
115     for (size_t i = 3; i < len; ++i) {
116       if (host[i] == '.') {
117         return i + 1;
118       }
119       if (!absl::ascii_isdigit(static_cast<unsigned char>(host[i]))) {
120         return 0;
121       }
122     }
123   } else {
124     static const auto* kMeaninglessPrefixesLenMap = new std::set<std::string>(
125         {"web", "ftp", "wap", "home", "m", "w", "amp", "mobile"});
126
127     size_t prefix_len = host.find('.');
128     std::string prefix = host.substr(0, prefix_len);
129     const auto& it = kMeaninglessPrefixesLenMap->find(prefix);
130     if (it != kMeaninglessPrefixesLenMap->end() && len > it->size() + 1) {
131       return it->size() + 1;
132     }
133   }
134   return 0;
135 }
136
137 bool IsModelTaxonomyVersionSupported(int model_taxonomy_version) {
138   // Taxonomy version 1 is a special case, where the server would send nothing
139   // (i.e. use 0) for the taxonomy version.
140   if (blink::features::kBrowsingTopicsTaxonomyVersion.Get() == 1) {
141     return model_taxonomy_version == 0;
142   }
143
144   return model_taxonomy_version ==
145          blink::features::kBrowsingTopicsTaxonomyVersion.Get();
146 }
147
148 }  // namespace
149
150 AnnotatorImpl::AnnotatorImpl(
151     optimization_guide::OptimizationGuideModelProvider* model_provider,
152     scoped_refptr<base::SequencedTaskRunner> background_task_runner,
153     const absl::optional<optimization_guide::proto::Any>& model_metadata)
154     : BertModelHandler(
155           model_provider,
156           background_task_runner,
157           optimization_guide::proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2,
158           model_metadata),
159       background_task_runner_(background_task_runner) {
160   // Unloading the model is done via custom logic in this class.
161   SetShouldUnloadModelOnComplete(false);
162 }
163 AnnotatorImpl::~AnnotatorImpl() = default;
164
165 void AnnotatorImpl::NotifyWhenModelAvailable(base::OnceClosure callback) {
166   if (GetBrowsingTopicsModelInfo().has_value()) {
167     std::move(callback).Run();
168     return;
169   }
170   model_available_callbacks_.AddUnsafe(std::move(callback));
171 }
172
173 absl::optional<optimization_guide::ModelInfo>
174 AnnotatorImpl::GetBrowsingTopicsModelInfo() const {
175 #if DCHECK_IS_ON()
176   if (GetModelInfo()) {
177     DCHECK(GetModelInfo()->GetModelMetadata());
178     absl::optional<optimization_guide::proto::PageTopicsModelMetadata>
179         model_metadata = optimization_guide::ParsedAnyMetadata<
180             optimization_guide::proto::PageTopicsModelMetadata>(
181             *GetModelInfo()->GetModelMetadata());
182     DCHECK(model_metadata);
183     DCHECK(IsModelTaxonomyVersionSupported(model_metadata->taxonomy_version()));
184   }
185 #endif  // DCHECK_IS_ON()
186   return GetModelInfo();
187 }
188
189 void AnnotatorImpl::BatchAnnotate(BatchAnnotationCallback callback,
190                                   const std::vector<std::string>& inputs) {
191   if (override_list_file_path_.has_value() && !override_list_.has_value()) {
192     background_task_runner_->PostTaskAndReplyWithResult(
193         FROM_HERE,
194         base::BindOnce(&LoadOverrideListFromFile, *override_list_file_path_),
195         base::BindOnce(&AnnotatorImpl::OnOverrideListLoadAttemptDone,
196                        weak_ptr_factory_.GetWeakPtr(), std::move(callback),
197                        inputs));
198     return;
199   }
200   StartBatchAnnotate(std::move(callback), inputs);
201 }
202
203 void AnnotatorImpl::OnOverrideListLoadAttemptDone(
204     BatchAnnotationCallback callback,
205     const std::vector<std::string>& inputs,
206     absl::optional<std::unordered_map<std::string, std::vector<int32_t>>>
207         override_list) {
208   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
209   DCHECK(override_list_file_path_);
210
211   // If the override list is supposed to be used, it must be. Otherwise do not
212   // compute any annotations.
213   if (!override_list) {
214     std::vector<Annotation> annotations;
215     for (const std::string& input : inputs) {
216       annotations.emplace_back(input);
217     }
218     std::move(callback).Run(annotations);
219     return;
220   }
221   override_list_ = override_list;
222   StartBatchAnnotate(std::move(callback), inputs);
223 }
224
225 void AnnotatorImpl::StartBatchAnnotate(BatchAnnotationCallback callback,
226                                        const std::vector<std::string>& inputs) {
227   in_progess_batches_++;
228
229   std::unique_ptr<std::vector<Annotation>> annotations =
230       std::make_unique<std::vector<Annotation>>();
231   annotations->reserve(inputs.size());
232   for (const std::string& input : inputs) {
233     annotations->push_back(Annotation(input));
234   }
235   std::vector<Annotation>* annotations_ptr = annotations.get();
236
237   // Note on Lifetime: |annotations| is owned by |on_batch_complete_closure|
238   // which is guaranteed to not be called until the |barrier_closure| has been
239   // invoked |inputs.size()| times. Thus, passing raw pointers to the
240   // heap-allocated |annotations| is safe.
241
242   base::OnceClosure on_batch_complete_closure = base::BindOnce(
243       &AnnotatorImpl::OnBatchComplete, weak_ptr_factory_.GetWeakPtr(),
244       std::move(callback), std::move(annotations));
245
246   base::RepeatingClosure barrier_closure =
247       base::BarrierClosure(inputs.size(), std::move(on_batch_complete_closure));
248
249   for (size_t i = 0; i < inputs.size(); i++) {
250     AnnotateSingleInput(
251         /*single_input_done_signal=*/barrier_closure,
252         /*annotation=*/(annotations_ptr->data() + i));
253   }
254 }
255
256 void AnnotatorImpl::OnBatchComplete(
257     BatchAnnotationCallback callback,
258     std::unique_ptr<std::vector<Annotation>> annotations_ptr) {
259   std::move(callback).Run(*annotations_ptr);
260
261   // Only unload the model once all batches have been completed.
262   DCHECK_GT(in_progess_batches_, 0U);
263   in_progess_batches_--;
264   if (in_progess_batches_ == 0) {
265     UnloadModel();
266   }
267 }
268
269 std::string AnnotatorImpl::PreprocessHost(const std::string& host) const {
270   std::string output = base::ToLowerASCII(host);
271
272   // Meaningless prefix v2 is only supported/required for
273   // |kMeaninglessPrefixV2MinVersion| and on.
274   if (version_ >= kMeaninglessPrefixV2MinVersion) {
275     int idx = MeaninglessPrefixLength(output);
276     if (idx > 0) {
277       output = output.substr(idx);
278     }
279   } else {
280     // Strip the 'www.' if it exists.
281     if (base::StartsWith(output, "www.")) {
282       output = output.substr(4);
283     }
284   }
285
286   static const char kCharsToReplaceWithSpace[] = {'-', '_', '.', '+'};
287   for (char c : kCharsToReplaceWithSpace) {
288     std::replace(output.begin(), output.end(), c, ' ');
289   }
290
291   return output;
292 }
293
294 void AnnotatorImpl::AnnotateSingleInput(
295     base::OnceClosure single_input_done_signal,
296     Annotation* annotation) {
297   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
298   std::string processed_input = PreprocessHost(annotation->input);
299
300   if (override_list_) {
301     DCHECK(override_list_file_path_);
302     auto iter = override_list_->find(processed_input);
303
304     base::UmaHistogramBoolean("BrowsingTopics.OverrideList.UsedOverride",
305                               iter != override_list_->end());
306
307     if (iter != override_list_->end()) {
308       annotation->topics = iter->second;
309       std::move(single_input_done_signal).Run();
310       // |annotation| may have been destroyed, do not use it past here.
311       return;
312     }
313   }
314
315   ExecuteModelWithInput(
316       base::BindOnce(
317           &AnnotatorImpl::PostprocessCategoriesToBatchAnnotationResult,
318           weak_ptr_factory_.GetWeakPtr(), std::move(single_input_done_signal),
319           annotation),
320       processed_input);
321 }
322
323 void AnnotatorImpl::PostprocessCategoriesToBatchAnnotationResult(
324     base::OnceClosure single_input_done_signal,
325     Annotation* annotation,
326     const absl::optional<std::vector<tflite::task::core::Category>>& output) {
327   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
328   if (output) {
329     annotation->topics = ExtractCategoriesFromModelOutput(*output).value_or(
330         std::vector<int32_t>{});
331   }
332
333   std::move(single_input_done_signal).Run();
334   // |annotation| may have been destroyed, do not use it past here.
335 }
336
337 absl::optional<std::vector<int32_t>>
338 AnnotatorImpl::ExtractCategoriesFromModelOutput(
339     const std::vector<tflite::task::core::Category>& model_output) const {
340   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
341
342   absl::optional<optimization_guide::proto::PageTopicsModelMetadata>
343       model_metadata = ParsedSupportedFeaturesForLoadedModel<
344           optimization_guide::proto::PageTopicsModelMetadata>();
345   if (!model_metadata) {
346     return absl::nullopt;
347   }
348
349   absl::optional<std::string> visibility_category_name =
350       model_metadata->output_postprocessing_params().has_visibility_params() &&
351               model_metadata->output_postprocessing_params()
352                   .visibility_params()
353                   .has_category_name()
354           ? absl::make_optional(model_metadata->output_postprocessing_params()
355                                     .visibility_params()
356                                     .category_name())
357           : absl::nullopt;
358
359   std::vector<std::pair<int32_t, float>> category_candidates;
360
361   for (const auto& category : model_output) {
362     if (visibility_category_name &&
363         category.class_name == *visibility_category_name) {
364       continue;
365     }
366     // Assume everything else is for categories.
367     int category_id;
368     if (base::StringToInt(category.class_name, &category_id)) {
369       category_candidates.emplace_back(category_id,
370                                        static_cast<float>(category.score));
371     }
372   }
373
374   // Postprocess categories.
375   if (!model_metadata->output_postprocessing_params().has_category_params()) {
376     // No parameters for postprocessing, so just return.
377     return absl::nullopt;
378   }
379   const optimization_guide::proto::PageTopicsCategoryPostprocessingParams
380       category_params =
381           model_metadata->output_postprocessing_params().category_params();
382
383   // Determine the categories with the highest weights.
384   std::sort(
385       category_candidates.begin(), category_candidates.end(),
386       [](const std::pair<int32_t, float>& a,
387          const std::pair<int32_t, float>& b) { return a.second > b.second; });
388   size_t max_categories = static_cast<size_t>(category_params.max_categories());
389   float total_weight = 0.0;
390   float sum_positive_scores = 0.0;
391   absl::optional<std::pair<size_t, float>> none_idx_and_weight;
392   std::vector<std::pair<int32_t, float>> categories;
393   categories.reserve(max_categories);
394   for (size_t i = 0; i < category_candidates.size() && i < max_categories;
395        i++) {
396     std::pair<int32_t, float> candidate = category_candidates[i];
397     categories.push_back(candidate);
398     total_weight += candidate.second;
399
400     if (candidate.second > 0) {
401       sum_positive_scores += candidate.second;
402     }
403
404     if (candidate.first == kNoneCategoryId) {
405       none_idx_and_weight = std::make_pair(i, candidate.second);
406     }
407   }
408
409   // Prune out categories that do not meet the minimum threshold.
410   if (category_params.min_category_weight() > 0) {
411     base::EraseIf(categories, [&](const std::pair<int32_t, float>& category) {
412       return category.second < category_params.min_category_weight();
413     });
414   }
415
416   // Prune out none weights.
417   if (total_weight == 0) {
418     return absl::nullopt;
419   }
420   if (none_idx_and_weight) {
421     if ((none_idx_and_weight->second / total_weight) >
422         category_params.min_none_weight()) {
423       // None weight is too strong.
424       return absl::nullopt;
425     }
426     // None weight doesn't matter, so prune it out. Note that it may have
427     // already been removed above if its weight was below the category min.
428     base::EraseIf(categories, [&](const std::pair<int32_t, float>& category) {
429       return category.first == kNoneCategoryId;
430     });
431   }
432
433   // Normalize category weights.
434   float normalization_factor =
435       sum_positive_scores > 0 ? sum_positive_scores : 1.0;
436   base::EraseIf(categories, [&](const std::pair<int32_t, float>& category) {
437     return (category.second / normalization_factor) <
438            category_params.min_normalized_weight_within_top_n();
439   });
440
441   std::vector<int32_t> final_categories;
442   final_categories.reserve(categories.size());
443   for (const auto& category : categories) {
444     // We expect the weight to be between 0 and 1.
445     DCHECK(category.second >= 0.0 && category.second <= 1.0);
446     final_categories.emplace_back(category.first);
447   }
448   DCHECK_LE(final_categories.size(), max_categories);
449
450   return final_categories;
451 }
452
453 void AnnotatorImpl::UnloadModel() {
454   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
455
456   optimization_guide::BertModelHandler::UnloadModel();
457   override_list_ = absl::nullopt;
458 }
459
460 void AnnotatorImpl::OnModelUpdated(
461     optimization_guide::proto::OptimizationTarget optimization_target,
462     base::optional_ref<const optimization_guide::ModelInfo> model_info) {
463   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
464   // First invoke parent to update internal status.
465   optimization_guide::BertModelHandler::OnModelUpdated(optimization_target,
466                                                        model_info);
467
468   if (optimization_target !=
469       optimization_guide::proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2) {
470     return;
471   }
472
473   if (!model_info.has_value() || !model_info->GetModelMetadata()) {
474     return;
475   }
476
477   absl::optional<optimization_guide::proto::PageTopicsModelMetadata>
478       model_metadata = optimization_guide::ParsedAnyMetadata<
479           optimization_guide::proto::PageTopicsModelMetadata>(
480           *model_info->GetModelMetadata());
481   if (!model_metadata) {
482     return;
483   }
484
485   if (!IsModelTaxonomyVersionSupported(model_metadata->taxonomy_version())) {
486     // Also clear the model in the underlying model executor code so that it
487     // cannot be accidentally called on the wrong taxonomy version.
488     optimization_guide::BertModelHandler::OnModelUpdated(
489         optimization_guide::proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2,
490         absl::nullopt);
491     return;
492   }
493   version_ = model_metadata->version();
494
495   // New model, new override list.
496   override_list_file_path_ = absl::nullopt;
497   override_list_ = absl::nullopt;
498   for (const base::FilePath& path : model_info->GetAdditionalFiles()) {
499     DCHECK(path.IsAbsolute());
500     if (path.BaseName() == base::FilePath(kOverrideListBasePath)) {
501       override_list_file_path_ = path;
502       break;
503     }
504   }
505
506   // Run any callbacks that were waiting for an updated model.
507   //
508   // This should always be the last statement in this method, after all internal
509   // state has been updated because these callbacks may trigger an immediate
510   // annotation request.
511   model_available_callbacks_.Notify();
512 }
513
514 }  // namespace browsing_topics