// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/base_predictor.h"

#include "base/feature_list.h"
#include "base/logging.h"
#include "base/stl_util.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
#include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/ranker_example_util.h"
#include "components/assist_ranker/ranker_model.h"
#include "services/metrics/public/cpp/ukm_entry_builder.h"
#include "services/metrics/public/cpp/ukm_recorder.h"
#include "url/gurl.h"
16 namespace assist_ranker {
18 BasePredictor::BasePredictor(const PredictorConfig& config) : config_(config) {
19 // TODO(chrome-ranker-team): validate config.
20 if (config_.field_trial) {
21 is_query_enabled_ = base::FeatureList::IsEnabled(*config_.field_trial);
23 DVLOG(0) << "No field trial specified";
27 BasePredictor::~BasePredictor() {}
29 void BasePredictor::LoadModel(std::unique_ptr<RankerModelLoader> model_loader) {
30 if (!is_query_enabled_)
34 DVLOG(0) << "This predictor already has a model loader.";
37 // Take ownership of the model loader.
38 model_loader_ = std::move(model_loader);
39 // Kick off the initial model load.
40 model_loader_->NotifyOfRankerActivity();
43 void BasePredictor::OnModelAvailable(
44 std::unique_ptr<assist_ranker::RankerModel> model) {
45 ranker_model_ = std::move(model);
46 is_ready_ = Initialize();
49 bool BasePredictor::IsReady() {
50 if (!is_ready_ && model_loader_)
51 model_loader_->NotifyOfRankerActivity();
56 void BasePredictor::LogFeatureToUkm(const std::string& feature_name,
57 const Feature& feature,
58 ukm::UkmEntryBuilder* ukm_builder) {
61 if (!base::ContainsKey(*config_.feature_whitelist, feature_name)) {
62 DVLOG(1) << "Feature not whitelisted: " << feature_name;
66 switch (feature.feature_type_case()) {
67 case Feature::kBoolValue:
68 case Feature::kFloatValue:
69 case Feature::kInt32Value:
70 case Feature::kStringValue: {
71 int64_t feature_int64_value = -1;
72 FeatureToInt64(feature, &feature_int64_value);
73 DVLOG(3) << "Logging: " << feature_name << ": " << feature_int64_value;
74 ukm_builder->SetMetric(feature_name, feature_int64_value);
77 case Feature::kStringList: {
78 for (int i = 0; i < feature.string_list().string_value_size(); ++i) {
79 int64_t feature_int64_value = -1;
80 FeatureToInt64(feature, &feature_int64_value, i);
81 DVLOG(3) << "Logging: " << feature_name << ": " << feature_int64_value;
82 ukm_builder->SetMetric(feature_name, feature_int64_value);
87 DVLOG(0) << "Could not convert feature to int: " << feature_name;
91 void BasePredictor::LogExampleToUkm(const RankerExample& example,
92 ukm::SourceId source_id) {
93 if (config_.log_type != LOG_UKM) {
94 DVLOG(0) << "Wrong log type in predictor config: " << config_.log_type;
98 if (!config_.feature_whitelist) {
99 DVLOG(0) << "No whitelist specified.";
102 if (config_.feature_whitelist->empty()) {
103 DVLOG(0) << "Empty whitelist, examples will not be logged.";
107 ukm::UkmEntryBuilder builder(source_id, config_.logging_name);
108 for (const auto& feature_kv : example.features()) {
109 LogFeatureToUkm(feature_kv.first, feature_kv.second, &builder);
111 builder.Record(ukm::UkmRecorder::Get());
114 std::string BasePredictor::GetModelName() const {
115 return config_.model_name;
118 GURL BasePredictor::GetModelUrl() const {
119 if (!config_.field_trial_url_param) {
120 DVLOG(1) << "No URL specified.";
124 return GURL(config_.field_trial_url_param->Get());
127 RankerExample BasePredictor::PreprocessExample(const RankerExample& example) {
128 if (ranker_model_->proto().has_metadata() &&
129 ranker_model_->proto().metadata().input_features_names_are_hex_hashes()) {
130 return HashExampleFeatureNames(example);
135 } // namespace assist_ranker