1 // Copyright 2023 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "components/browsing_topics/annotator_impl.h"
7 #include "base/barrier_closure.h"
8 #include "base/containers/contains.h"
9 #include "base/dcheck_is_on.h"
10 #include "base/files/file_util.h"
11 #include "base/ranges/algorithm.h"
12 #include "base/strings/string_number_conversions.h"
13 #include "base/strings/string_util.h"
14 #include "base/task/sequenced_task_runner.h"
15 #include "components/optimization_guide/core/optimization_guide_model_provider.h"
16 #include "components/optimization_guide/proto/models.pb.h"
17 #include "components/optimization_guide/proto/page_topics_model_metadata.pb.h"
18 #include "components/optimization_guide/proto/page_topics_override_list.pb.h"
19 #include "third_party/abseil-cpp/absl/strings/ascii.h"
20 #include "third_party/blink/public/common/features.h"
21 #include "third_party/zlib/google/compression_utils.h"
23 namespace browsing_topics {
27 // The ID of the NONE category in the taxonomy. This node always exists.
28 // Semantically, the none category is attached to data for which we can say
29 // with certainty that no single label in the taxonomy is appropriate.
// ExtractCategoriesFromModelOutput() compares candidate category IDs against
// this value to find the NONE entry and to prune it from the final result.
30 const int32_t kNoneCategoryId = -2;
32 // The minimum model version needed to support meaningless prefix v2.
33 // This should be compared with the version provided by the model metadata.
// PreprocessHost() only calls MeaninglessPrefixLength() when |version_| is at
// least this value.
34 const int32_t kMeaninglessPrefixV2MinVersion = 2;
// Base name of the override list file. LoadOverrideListFromFile() refuses any
// path whose base name differs, then gunzips the file and parses it as a
// PageTopicsOverrideList proto; OnModelUpdated() uses the same name to pick
// the file out of the model's additional files.
36 const base::FilePath::CharType kOverrideListBasePath[] =
37 FILE_PATH_LITERAL("override_list.pb.gz");
39 // The result of an override list file load attempt. These values are logged to
40 // UMA histograms, do not change or reorder values. Make sure to update
41 // |BrowsingTopicsOverrideListFileLoadResult| in
42 // //tools/metrics/histograms/enums.xml.
// NOTE(review): the enumerators for values 0 and 1 are not visible in this
// view; |kSuccess| is referenced by LoadOverrideListFromFile(), so it is
// presumably one of them. Confirm against the full file.
43 enum class OverrideListFileLoadResult {
// base::ReadFileToString() failed for the override list path.
46 kCouldNotReadFile = 2,
// The file was read but compression::GzipUncompress() failed.
47 kCouldNotUncompressFile = 3,
// The uncompressed bytes did not parse as a PageTopicsOverrideList proto.
48 kCouldNotUnmarshalProtobuf = 4,
// Alias of the largest enumerator. Presumably the upper bound consumed by
// base::UmaHistogramEnumeration(); update it when appending a new value.
49 kMaxValue = kCouldNotUnmarshalProtobuf,
// Logs |result| to the "BrowsingTopics.OverrideList.FileLoadResult" UMA
// enumeration histogram. Called by LoadOverrideListFromFile() once per load
// attempt that gets past the path sanity check.
52 void RecordOverrideListFileLoadResult(OverrideListFileLoadResult result) {
53 base::UmaHistogramEnumeration("BrowsingTopics.OverrideList.FileLoadResult",
// Loads the override list stored at |path| and converts it into a map from
// domain string to the list of topic IDs assigned to that domain.
//
// Steps: validate |path|, read the whole file, gunzip it in place, parse the
// bytes as a PageTopicsOverrideList proto, then copy each entry into the map.
// Each failing step after the path check, and the success case, records one
// OverrideListFileLoadResult sample.
//
// BatchAnnotate() binds this function into a task posted to
// |background_task_runner_|.
//
// NOTE(review): the return statements are not visible in this view.
// OnOverrideListLoadAttemptDone() treats an empty optional as "load failed",
// so the failure paths presumably return absl::nullopt and the success path
// the populated map. Confirm against the full file.
57 absl::optional<std::unordered_map<std::string, std::vector<int32_t>>>
58 LoadOverrideListFromFile(const base::FilePath& path) {
// Only an absolute path whose base name is |kOverrideListBasePath| is
// accepted.
59 if (!path.IsAbsolute() ||
60 path.BaseName() != base::FilePath(kOverrideListBasePath)) {
62 // This is enforced by calling code, so no UMA in this case.
66 std::string file_contents;
67 if (!base::ReadFileToString(path, &file_contents)) {
68 RecordOverrideListFileLoadResult(
69 OverrideListFileLoadResult::kCouldNotReadFile);
// The same string is passed as input and output, so on success
// |file_contents| holds the uncompressed bytes.
73 if (!compression::GzipUncompress(file_contents, &file_contents)) {
74 RecordOverrideListFileLoadResult(
75 OverrideListFileLoadResult::kCouldNotUncompressFile);
79 optimization_guide::proto::PageTopicsOverrideList override_list_pb;
80 if (!override_list_pb.ParseFromString(file_contents)) {
81 RecordOverrideListFileLoadResult(
82 OverrideListFileLoadResult::kCouldNotUnmarshalProtobuf);
// Flatten the proto into domain -> topic IDs. emplace() keeps the first entry
// if the same domain appears more than once.
86 std::unordered_map<std::string, std::vector<int32_t>> override_list;
87 for (const optimization_guide::proto::PageTopicsOverrideEntry& entry :
88 override_list_pb.entries()) {
89 override_list.emplace(
90 entry.domain(), std::vector<int32_t>{entry.topics().topic_ids().begin(),
91 entry.topics().topic_ids().end()});
94 RecordOverrideListFileLoadResult(OverrideListFileLoadResult::kSuccess);
98 // Returns the length of the leading meaningless prefix of a host name as
99 // defined for the Topics Model.
101 // The full list of meaningless prefixes is:
102 // ^(www[0-9]*|web|ftp|wap|home)$
103 // ^(m|mobile|amp|w)$
// For the fixed-word prefixes the value returned below is the prefix length
// plus one, so it also covers the '.' that ends the first label.
// PreprocessHost() passes the result to std::string::substr().
104 int MeaninglessPrefixLength(const std::string& host) {
105 size_t len = host.size();
// Number of '.' characters in |host|.
// NOTE(review): the statement that consumes |dots| is not visible in this
// view; confirm how the dot count gates the checks below.
107 int dots = base::ranges::count(host, '.');
// www[0-9]* family: only hosts longer than four characters that start with
// "www" take this branch.
112 if (len > 4 && base::StartsWith(host, "www")) {
113 // Check that all characters after "www" and up to first "." are
// digits. Scanning starts right after the three-letter "www".
115 for (size_t i = 3; i < len; ++i) {
// Reached the end of the first label.
116 if (host[i] == '.') {
// A non-digit before the first '.' means the label is not of the www[0-9]*
// form. The cast keeps a negative char value from reaching the classifier.
119 if (!absl::ascii_isdigit(static_cast<unsigned char>(host[i]))) {
// Fixed set of the remaining meaningless first labels. It is heap-allocated
// the first time this statement runs and is never deleted in this function.
// NOTE(review): despite the "LenMap" suffix this is a std::set, not a map.
124 static const auto* kMeaninglessPrefixesLenMap = new std::set<std::string>(
125 {"web", "ftp", "wap", "home", "m", "w", "amp", "mobile"});
// First label of |host|: everything before the first '.'. If there is no '.',
// find() yields npos and substr() returns the whole string.
127 size_t prefix_len = host.find('.');
128 std::string prefix = host.substr(0, prefix_len);
129 const auto& it = kMeaninglessPrefixesLenMap->find(prefix);
// Only report a match when something remains after "<prefix>.".
130 if (it != kMeaninglessPrefixesLenMap->end() && len > it->size() + 1) {
131 return it->size() + 1;
// Returns whether |model_taxonomy_version| (taken from the model metadata by
// the callers in this file) matches the taxonomy version configured through
// blink::features::kBrowsingTopicsTaxonomyVersion.
137 bool IsModelTaxonomyVersionSupported(int model_taxonomy_version) {
138 // Taxonomy version 1 is a special case, where the server would send nothing
139 // (i.e. use 0) for the taxonomy version.
140 if (blink::features::kBrowsingTopicsTaxonomyVersion.Get() == 1) {
141 return model_taxonomy_version == 0;
// Any other configured version must match the model's value exactly.
144 return model_taxonomy_version ==
145 blink::features::kBrowsingTopicsTaxonomyVersion.Get();
// Constructs the annotator.
//   model_provider: provider of the page topics model.
//   background_task_runner: forwarded to the base-class initializer and also
//       kept in |background_task_runner_|, which BatchAnnotate() uses to load
//       the override list file.
//   model_metadata: optional metadata accepted by this constructor.
// NOTE(review): the base-class initializer's opening line and remaining
// arguments are not visible in this view. Other methods call
// optimization_guide::BertModelHandler explicitly, so that is presumably the
// base being initialized here with OPTIMIZATION_TARGET_PAGE_TOPICS_V2.
150 AnnotatorImpl::AnnotatorImpl(
151 optimization_guide::OptimizationGuideModelProvider* model_provider,
152 scoped_refptr<base::SequencedTaskRunner> background_task_runner,
153 const absl::optional<optimization_guide::proto::Any>& model_metadata)
156 background_task_runner,
157 optimization_guide::proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2,
159 background_task_runner_(background_task_runner) {
160 // Unloading the model is done via custom logic in this class.
// See OnBatchComplete() and UnloadModel().
161 SetShouldUnloadModelOnComplete(false);
// No custom teardown; members clean up through their own destructors.
163 AnnotatorImpl::~AnnotatorImpl() = default;
// Runs |callback| once a browsing topics model is available. If
// GetBrowsingTopicsModelInfo() already has a value the callback is run right
// away; otherwise it is stored in |model_available_callbacks_|, which
// OnModelUpdated() notifies after a supported model arrives.
// NOTE(review): the return after the immediate Run() is not visible in this
// view; |callback| is moved in both places, so the two paths must be
// exclusive. Confirm against the full file.
165 void AnnotatorImpl::NotifyWhenModelAvailable(base::OnceClosure callback) {
166 if (GetBrowsingTopicsModelInfo().has_value()) {
167 std::move(callback).Run();
170 model_available_callbacks_.AddUnsafe(std::move(callback));
// Returns the result of GetModelInfo() unchanged. Before returning, a present
// model is sanity-checked with DCHECKs: it must carry model metadata, that
// metadata must parse as PageTopicsModelMetadata, and its taxonomy version
// must pass IsModelTaxonomyVersionSupported().
// NOTE(review): the `#endif  // DCHECK_IS_ON()` below implies the checks sit
// inside a DCHECK_IS_ON() conditional whose opening line is not visible here.
173 absl::optional<optimization_guide::ModelInfo>
174 AnnotatorImpl::GetBrowsingTopicsModelInfo() const {
176 if (GetModelInfo()) {
177 DCHECK(GetModelInfo()->GetModelMetadata());
178 absl::optional<optimization_guide::proto::PageTopicsModelMetadata>
179 model_metadata = optimization_guide::ParsedAnyMetadata<
180 optimization_guide::proto::PageTopicsModelMetadata>(
181 *GetModelInfo()->GetModelMetadata());
182 DCHECK(model_metadata);
183 DCHECK(IsModelTaxonomyVersionSupported(model_metadata->taxonomy_version()));
185 #endif // DCHECK_IS_ON()
186 return GetModelInfo();
// Entry point for annotating a batch of |inputs|; |callback| receives the
// resulting annotations.
//
// If an override list file path is known but its contents are not loaded yet,
// LoadOverrideListFromFile() is posted to |background_task_runner_| and the
// reply is bound to OnOverrideListLoadAttemptDone() through a weak pointer.
// Otherwise the batch is started directly.
// NOTE(review): the remaining bound arguments and the return after posting
// are not visible in this view; |callback| is moved on both paths, so they
// must be exclusive. Confirm against the full file.
189 void AnnotatorImpl::BatchAnnotate(BatchAnnotationCallback callback,
190 const std::vector<std::string>& inputs) {
191 if (override_list_file_path_.has_value() && !override_list_.has_value()) {
192 background_task_runner_->PostTaskAndReplyWithResult(
194 base::BindOnce(&LoadOverrideListFromFile, *override_list_file_path_),
195 base::BindOnce(&AnnotatorImpl::OnOverrideListLoadAttemptDone,
196 weak_ptr_factory_.GetWeakPtr(), std::move(callback),
200 StartBatchAnnotate(std::move(callback), inputs);
// Reply half of the override list load started by BatchAnnotate().
//   callback: receives the batch result.
//   inputs: the hosts of the pending batch.
//   override_list (name taken from its uses below): the loaded map, or empty
//       if the load failed.
// On failure |callback| is run with one Annotation per input, each built from
// the input string alone. On success the list is cached in |override_list_|
// and the batch proceeds through StartBatchAnnotate().
203 void AnnotatorImpl::OnOverrideListLoadAttemptDone(
204 BatchAnnotationCallback callback,
205 const std::vector<std::string>& inputs,
206 absl::optional<std::unordered_map<std::string, std::vector<int32_t>>>
208 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
209 DCHECK(override_list_file_path_);
211 // If the override list is supposed to be used, it must be. Otherwise do not
212 // compute any annotations.
213 if (!override_list) {
214 std::vector<Annotation> annotations;
215 for (const std::string& input : inputs) {
216 annotations.emplace_back(input);
// NOTE(review): |callback| is consumed here and moved again below, so this
// branch presumably returns right after; the return is not visible here.
218 std::move(callback).Run(annotations);
// NOTE(review): this copies the whole map although |override_list| is a
// by-value parameter that is not used afterwards in the visible code;
// std::move(override_list) would avoid the copy.
221 override_list_ = override_list;
222 StartBatchAnnotate(std::move(callback), inputs);
// Starts annotating |inputs| as one batch. A heap-allocated vector holding one
// Annotation per input is created, each element is handed out by raw pointer
// to a per-input call, and a barrier closure invokes OnBatchComplete() with
// |callback| and the vector once every input has signalled completion.
225 void AnnotatorImpl::StartBatchAnnotate(BatchAnnotationCallback callback,
226 const std::vector<std::string>& inputs) {
// Counted so OnBatchComplete() can tell when the last outstanding batch ends.
227 in_progess_batches_++;
// One placeholder Annotation per input, in input order. reserve() plus exactly
// inputs.size() push_backs means the storage is not reallocated in this
// function, so the element pointers taken below stay valid.
229 std::unique_ptr<std::vector<Annotation>> annotations =
230 std::make_unique<std::vector<Annotation>>();
231 annotations->reserve(inputs.size());
232 for (const std::string& input : inputs) {
233 annotations->push_back(Annotation(input));
// Grab the raw pointer before ownership moves into the completion closure.
235 std::vector<Annotation>* annotations_ptr = annotations.get();
237 // Note on Lifetime: |annotations| is owned by |on_batch_complete_closure|
238 // which is guaranteed to not be called until the |barrier_closure| has been
239 // invoked |inputs.size()| times. Thus, passing raw pointers to the
240 // heap-allocated |annotations| is safe.
242 base::OnceClosure on_batch_complete_closure = base::BindOnce(
243 &AnnotatorImpl::OnBatchComplete, weak_ptr_factory_.GetWeakPtr(),
244 std::move(callback), std::move(annotations));
246 base::RepeatingClosure barrier_closure =
247 base::BarrierClosure(inputs.size(), std::move(on_batch_complete_closure));
// NOTE(review): the callee line of the per-input call is not visible in this
// view. The argument comments match AnnotateSingleInput()'s parameter names,
// so it is presumably that method; confirm against the full file.
249 for (size_t i = 0; i < inputs.size(); i++) {
251 /*single_input_done_signal=*/barrier_closure,
252 /*annotation=*/(annotations_ptr->data() + i));
// Bound by StartBatchAnnotate() as the completion closure of its barrier.
// Delivers the finished annotations to |callback|, then decrements the count
// of in-flight batches.
//   callback: the batch's result callback.
//   annotations_ptr: owns the annotation vector filled in for this batch.
256 void AnnotatorImpl::OnBatchComplete(
257 BatchAnnotationCallback callback,
258 std::unique_ptr<std::vector<Annotation>> annotations_ptr) {
259 std::move(callback).Run(*annotations_ptr);
261 // Only unload the model once all batches have been completed.
262 DCHECK_GT(in_progess_batches_, 0U);
263 in_progess_batches_--;
// NOTE(review): the body of this branch is not visible in this view; per the
// comment above it presumably unloads the model.
264 if (in_progess_batches_ == 0) {
// Normalizes |host| before it is used as the lookup key in
// AnnotateSingleInput(): lowercases ASCII letters, strips a leading prefix,
// and replaces the characters '-', '_', '.', '+' with spaces.
// NOTE(review): the return statement and the branch structure between the v2
// prefix stripping and the legacy "www." stripping are not visible in this
// view. Presumably the legacy strip is the else-branch of the version check
// and |output| is returned; confirm against the full file.
269 std::string AnnotatorImpl::PreprocessHost(const std::string& host) const {
270 std::string output = base::ToLowerASCII(host);
272 // Meaningless prefix v2 is only supported/required for
273 // |kMeaninglessPrefixV2MinVersion| and on.
274 if (version_ >= kMeaninglessPrefixV2MinVersion) {
// |idx| is the number of leading characters to drop, as computed by
// MeaninglessPrefixLength() on the lowercased host.
275 int idx = MeaninglessPrefixLength(output);
277 output = output.substr(idx);
280 // Strip the 'www.' if it exists.
281 if (base::StartsWith(output, "www.")) {
282 output = output.substr(4);
// Each listed separator character is replaced by a single space, one
// std::replace pass per character.
286 static const char kCharsToReplaceWithSpace[] = {'-', '_', '.', '+'};
287 for (char c : kCharsToReplaceWithSpace) {
288 std::replace(output.begin(), output.end(), c, ' ');
// Annotates one input of a batch.
//   single_input_done_signal: run once this input's result has been written.
//   annotation: in/out; |annotation->input| is read and |annotation->topics|
//       is written. Not owned here.
// The preprocessed host is first looked up in |override_list_| when one is
// loaded; a hit supplies the topics directly. Otherwise the model is executed
// and PostprocessCategoriesToBatchAnnotationResult() finishes the work.
294 void AnnotatorImpl::AnnotateSingleInput(
295 base::OnceClosure single_input_done_signal,
296 Annotation* annotation) {
297 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
298 std::string processed_input = PreprocessHost(annotation->input);
300 if (override_list_) {
301 DCHECK(override_list_file_path_);
302 auto iter = override_list_->find(processed_input);
// Records, for every lookup, whether the override list had an entry.
304 base::UmaHistogramBoolean("BrowsingTopics.OverrideList.UsedOverride",
305 iter != override_list_->end());
307 if (iter != override_list_->end()) {
308 annotation->topics = iter->second;
// NOTE(review): the signal is consumed here and moved again below, so this
// branch presumably returns right after; the return is not visible here.
309 std::move(single_input_done_signal).Run();
310 // |annotation| may have been destroyed, do not use it past here.
// No override entry (or no override list): run the model. The result callback
// is bound through a weak pointer.
// NOTE(review): the remaining arguments of this call are not visible here.
315 ExecuteModelWithInput(
317 &AnnotatorImpl::PostprocessCategoriesToBatchAnnotationResult,
318 weak_ptr_factory_.GetWeakPtr(), std::move(single_input_done_signal),
// Model-execution callback bound by AnnotateSingleInput(). Converts the raw
// model |output| into topic IDs, stores them in |annotation->topics| (an
// empty list when ExtractCategoriesFromModelOutput() yields no value), and
// then signals that this input is done.
//   single_input_done_signal: run after the topics have been written.
//   annotation: destination for the topics. Not owned here.
//   output: the model's categories, or empty.
323 void AnnotatorImpl::PostprocessCategoriesToBatchAnnotationResult(
324 base::OnceClosure single_input_done_signal,
325 Annotation* annotation,
326 const absl::optional<std::vector<tflite::task::core::Category>>& output) {
327 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
// NOTE(review): |output| is dereferenced here. No has-value guard is visible
// in this view; confirm one encloses this statement in the full file.
329 annotation->topics = ExtractCategoriesFromModelOutput(*output).value_or(
330 std::vector<int32_t>{});
333 std::move(single_input_done_signal).Run();
334 // |annotation| may have been destroyed, do not use it past here.
// Turns the raw categories in |model_output| into the final list of topic IDs
// using the postprocessing parameters from the loaded model's metadata.
//
// Visible pipeline:
//   1. Parse each class name as an integer category ID, handling the
//      visibility category separately when one is configured.
//   2. Order candidates with a higher-score-first comparator and keep at most
//      |max_categories| of them, accumulating weight totals.
//   3. Drop categories below |min_category_weight|.
//   4. Give up if the NONE category is too strong, otherwise drop it.
//   5. Drop categories whose normalized weight is below
//      |min_normalized_weight_within_top_n|.
//
// Returns absl::nullopt when the metadata cannot be parsed, when there are no
// category postprocessing params, when the total weight of the kept
// candidates is exactly zero, or when the NONE weight is too strong.
// Otherwise returns the surviving category IDs (possibly empty).
337 absl::optional<std::vector<int32_t>>
338 AnnotatorImpl::ExtractCategoriesFromModelOutput(
339 const std::vector<tflite::task::core::Category>& model_output) const {
340 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
342 absl::optional<optimization_guide::proto::PageTopicsModelMetadata>
343 model_metadata = ParsedSupportedFeaturesForLoadedModel<
344 optimization_guide::proto::PageTopicsModelMetadata>();
345 if (!model_metadata) {
346 return absl::nullopt;
// Name of the visibility category, if the metadata configures one.
// NOTE(review): the rest of this conditional expression is not visible in
// this view.
349 absl::optional<std::string> visibility_category_name =
350 model_metadata->output_postprocessing_params().has_visibility_params() &&
351 model_metadata->output_postprocessing_params()
354 ? absl::make_optional(model_metadata->output_postprocessing_params()
// (category ID, score) pairs parsed from the model output.
359 std::vector<std::pair<int32_t, float>> category_candidates;
361 for (const auto& category : model_output) {
// NOTE(review): what happens to the visibility category inside this branch is
// not visible in this view.
362 if (visibility_category_name &&
363 category.class_name == *visibility_category_name) {
366 // Assume everything else is for categories.
// Class names that do not parse as an integer are not added as candidates.
368 if (base::StringToInt(category.class_name, &category_id)) {
369 category_candidates.emplace_back(category_id,
370 static_cast<float>(category.score));
374 // Postprocess categories.
375 if (!model_metadata->output_postprocessing_params().has_category_params()) {
376 // No parameters for postprocessing, so just return.
377 return absl::nullopt;
379 const optimization_guide::proto::PageTopicsCategoryPostprocessingParams
381 model_metadata->output_postprocessing_params().category_params();
383 // Determine the categories with the highest weights.
// The comparator places higher scores first.
// NOTE(review): the line naming the ordering algorithm is not visible here.
385 category_candidates.begin(), category_candidates.end(),
386 [](const std::pair<int32_t, float>& a,
387 const std::pair<int32_t, float>& b) { return a.second > b.second; });
388 size_t max_categories = static_cast<size_t>(category_params.max_categories());
// |total_weight| sums every kept candidate's score, while
// |sum_positive_scores| sums only the strictly positive ones.
389 float total_weight = 0.0;
390 float sum_positive_scores = 0.0;
// Position and score of the NONE category if it is among the kept candidates.
391 absl::optional<std::pair<size_t, float>> none_idx_and_weight;
392 std::vector<std::pair<int32_t, float>> categories;
393 categories.reserve(max_categories);
// Keep the first |max_categories| candidates in their current order.
394 for (size_t i = 0; i < category_candidates.size() && i < max_categories;
396 std::pair<int32_t, float> candidate = category_candidates[i];
397 categories.push_back(candidate);
398 total_weight += candidate.second;
400 if (candidate.second > 0) {
401 sum_positive_scores += candidate.second;
404 if (candidate.first == kNoneCategoryId) {
405 none_idx_and_weight = std::make_pair(i, candidate.second);
409 // Prune out categories that do not meet the minimum threshold.
// The threshold is only applied when it is configured as a positive value.
410 if (category_params.min_category_weight() > 0) {
411 base::EraseIf(categories, [&](const std::pair<int32_t, float>& category) {
412 return category.second < category_params.min_category_weight();
416 // Prune out none weights.
// An exactly-zero total also guards the division below.
417 if (total_weight == 0) {
418 return absl::nullopt;
420 if (none_idx_and_weight) {
// Share of the NONE score in the total kept weight, compared with the
// configured |min_none_weight|.
421 if ((none_idx_and_weight->second / total_weight) >
422 category_params.min_none_weight()) {
423 // None weight is too strong.
424 return absl::nullopt;
426 // None weight doesn't matter, so prune it out. Note that it may have
427 // already been removed above if its weight was below the category min.
428 base::EraseIf(categories, [&](const std::pair<int32_t, float>& category) {
429 return category.first == kNoneCategoryId;
433 // Normalize category weights.
// Falls back to a factor of 1.0 when no kept score was positive, which also
// avoids dividing by zero.
434 float normalization_factor =
435 sum_positive_scores > 0 ? sum_positive_scores : 1.0;
// Drop categories whose normalized weight is under the configured minimum.
436 base::EraseIf(categories, [&](const std::pair<int32_t, float>& category) {
437 return (category.second / normalization_factor) <
438 category_params.min_normalized_weight_within_top_n();
// Only the IDs are returned; the weights are discarded after the DCHECK.
441 std::vector<int32_t> final_categories;
442 final_categories.reserve(categories.size());
443 for (const auto& category : categories) {
444 // We expect the weight to be between 0 and 1.
445 DCHECK(category.second >= 0.0 && category.second <= 1.0);
446 final_categories.emplace_back(category.first);
448 DCHECK_LE(final_categories.size(), max_categories);
450 return final_categories;
// Unloads the model through the BertModelHandler implementation and also
// drops the cached override list. |override_list_file_path_| is left
// untouched, so BatchAnnotate() will reload the list on the next batch if a
// path is still set.
453 void AnnotatorImpl::UnloadModel() {
454 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
456 optimization_guide::BertModelHandler::UnloadModel();
457 override_list_ = absl::nullopt;
// Handles a model update notification.
//   optimization_target: the target the update is for; anything other than
//       OPTIMIZATION_TARGET_PAGE_TOPICS_V2 is checked for below.
//   model_info: the new model's info, possibly empty.
// After forwarding to the BertModelHandler implementation, the metadata is
// parsed and its taxonomy version validated. For a supported model this
// records the model version, resets the override list state, picks up the
// override list file from the model's additional files, and finally notifies
// the callbacks queued by NotifyWhenModelAvailable().
// NOTE(review): the bodies of the guard branches below are not visible in
// this view; presumably each one returns early. Confirm against the full
// file.
460 void AnnotatorImpl::OnModelUpdated(
461 optimization_guide::proto::OptimizationTarget optimization_target,
462 base::optional_ref<const optimization_guide::ModelInfo> model_info) {
463 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
464 // First invoke parent to update internal status.
465 optimization_guide::BertModelHandler::OnModelUpdated(optimization_target,
468 if (optimization_target !=
469 optimization_guide::proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2) {
473 if (!model_info.has_value() || !model_info->GetModelMetadata()) {
477 absl::optional<optimization_guide::proto::PageTopicsModelMetadata>
478 model_metadata = optimization_guide::ParsedAnyMetadata<
479 optimization_guide::proto::PageTopicsModelMetadata>(
480 *model_info->GetModelMetadata());
481 if (!model_metadata) {
485 if (!IsModelTaxonomyVersionSupported(model_metadata->taxonomy_version())) {
486 // Also clear the model in the underlying model executor code so that it
487 // cannot be accidentally called on the wrong taxonomy version.
// NOTE(review): the second argument of this call is not visible here.
488 optimization_guide::BertModelHandler::OnModelUpdated(
489 optimization_guide::proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2,
// Read by PreprocessHost() to decide which prefix stripping to apply.
493 version_ = model_metadata->version();
495 // New model, new override list.
496 override_list_file_path_ = absl::nullopt;
497 override_list_ = absl::nullopt;
// Look for the override list among the model's additional files by base
// name. The file itself is loaded lazily by BatchAnnotate().
498 for (const base::FilePath& path : model_info->GetAdditionalFiles()) {
499 DCHECK(path.IsAbsolute());
500 if (path.BaseName() == base::FilePath(kOverrideListBasePath)) {
501 override_list_file_path_ = path;
506 // Run any callbacks that were waiting for an updated model.
508 // This should always be the last statement in this method, after all internal
509 // state has been updated because these callbacks may trigger an immediate
510 // annotation request.
511 model_available_callbacks_.Notify();
514 } // namespace browsing_topics