// blob: 687180a78e9fafe0655116dbba424b3d2ed61f33 [file] [log] [blame]
// Copyright (C) 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "icing/scoring/advanced_scoring/score-expression.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/join/join-children-fetcher.h"
#include "icing/schema/section.h"
#include "icing/scoring/bm25f-calculator.h"
#include "icing/scoring/scored-document-hit.h"
#include "icing/scoring/section-weights.h"
#include "icing/store/document-associated-score-data.h"
#include "icing/store/document-filter-data.h"
#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
#include "icing/util/embedding-util.h"
#include "icing/util/logging.h"
#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
namespace {
// Validates that every element of `children` is non-null.
//
// Returns OK when all elements are non-null; otherwise returns the error
// produced by ICING_RETURN_ERROR_IF_NULL for the first null element found.
// Every factory in this file that dereferences its argument list is expected
// to call this first.
libtextclassifier3::Status CheckChildrenNotNull(
    const std::vector<std::unique_ptr<ScoreExpression>>& children) {
  for (const auto& child : children) {
    // NOTE(review): the macro may stringify its argument into the error
    // message, so keep the loop variable spelled `child` — confirm in
    // icing/util/status-macros.h before renaming.
    ICING_RETURN_ERROR_IF_NULL(child);
  }
  return libtextclassifier3::Status::OK;
}
} // namespace
// Builds an arithmetic operator node over `children`.
//
// Every operand must be non-null and of double type; the unary negative
// operator additionally requires exactly one operand. When all operands are
// constants, the node is evaluated immediately and a ConstantScoreExpression
// holding the result (or the evaluation error) is returned instead.
libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
OperatorScoreExpression::Create(
    OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children) {
  if (children.empty()) {
    return absl_ports::InvalidArgumentError(
        "OperatorScoreExpression must have at least one argument.");
  }
  ICING_RETURN_IF_ERROR(CheckChildrenNotNull(children));

  // Type-check each operand in order while tracking whether the whole
  // expression is foldable into a constant.
  bool every_operand_constant = true;
  for (const std::unique_ptr<ScoreExpression>& operand : children) {
    if (operand->type() != ScoreExpressionType::kDouble) {
      return absl_ports::InvalidArgumentError(
          "Operators are only supported for double type.");
    }
    if (!operand->is_constant()) {
      every_operand_constant = false;
    }
  }

  if (op == OperatorType::kNegative && children.size() != 1) {
    return absl_ports::InvalidArgumentError(
        "Negative operator must have only 1 argument.");
  }

  std::unique_ptr<ScoreExpression> expression =
      std::unique_ptr<OperatorScoreExpression>(
          new OperatorScoreExpression(op, std::move(children)));
  if (!every_operand_constant) {
    return expression;
  }
  // With only constant operands the result cannot depend on the DocHitInfo or
  // query iterator, so evaluate once against placeholders and fold it.
  return ConstantScoreExpression::Create(
      expression->eval(DocHitInfo(), /*query_it=*/nullptr));
}
// Evaluates the operator for `hit_info`.
//
// kNegative negates its single operand. Binary operators fold left-to-right
// over all operands: ((c0 op c1) op c2) ... Any non-finite intermediate result
// (e.g. division by zero) is reported as INVALID_ARGUMENT.
libtextclassifier3::StatusOr<double> OperatorScoreExpression::eval(
    const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
  // Create() rejects empty operand lists, so index 0 always exists.
  ICING_ASSIGN_OR_RETURN(double accumulator,
                         children_.at(0)->eval(hit_info, query_it));
  if (op_ == OperatorType::kNegative) {
    return -accumulator;
  }

  for (size_t idx = 1; idx < children_.size(); ++idx) {
    ICING_ASSIGN_OR_RETURN(double operand,
                           children_[idx]->eval(hit_info, query_it));
    switch (op_) {
      case OperatorType::kPlus:
        accumulator += operand;
        break;
      case OperatorType::kMinus:
        accumulator -= operand;
        break;
      case OperatorType::kTimes:
        accumulator *= operand;
        break;
      case OperatorType::kDiv:
        accumulator /= operand;
        break;
      case OperatorType::kNegative:
        // Handled before the loop; kept so the switch stays exhaustive.
        return absl_ports::InternalError("Should never reach here.");
    }
    if (!std::isfinite(accumulator)) {
      return absl_ports::InvalidArgumentError(
          "Got a non-finite value while evaluating operator score expression.");
    }
  }
  return accumulator;
}
// Lookup table from a math function's name to the FunctionType it denotes.
// Entries are kept in their historical order.
const std::unordered_map<std::string, MathFunctionScoreExpression::FunctionType>
    MathFunctionScoreExpression::kFunctionNames = {
        {"log", FunctionType::kLog},
        {"pow", FunctionType::kPow},
        {"max", FunctionType::kMax},
        {"min", FunctionType::kMin},
        {"len", FunctionType::kLen},
        {"sum", FunctionType::kSum},
        {"avg", FunctionType::kAvg},
        {"sqrt", FunctionType::kSqrt},
        {"abs", FunctionType::kAbs},
        {"sin", FunctionType::kSin},
        {"cos", FunctionType::kCos},
        {"tan", FunctionType::kTan},
};
// Math functions that accept any number of double arguments. These are also
// the only functions Create() accepts with a single list-typed argument.
const std::unordered_set<MathFunctionScoreExpression::FunctionType>
    MathFunctionScoreExpression::kVariableArgumentsFunctions = {
        FunctionType::kMax,
        FunctionType::kMin,
        FunctionType::kLen,
        FunctionType::kSum,
        FunctionType::kAvg,
};
// Builds a math function node.
//
// Accepted shapes:
//   * a single list-typed argument, only for kVariableArgumentsFunctions;
//     such a node is never constant-folded;
//   * one or more double arguments whose count matches the function's arity.
// When every double argument is constant, the function is evaluated now and a
// ConstantScoreExpression holding the result (or error) is returned.
libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
MathFunctionScoreExpression::Create(
    FunctionType function_type,
    std::vector<std::unique_ptr<ScoreExpression>> args) {
  if (args.empty()) {
    return absl_ports::InvalidArgumentError(
        "Math functions must have at least one argument.");
  }
  ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));

  const bool is_single_list_arg =
      args.size() == 1 && args[0]->type() == ScoreExpressionType::kDoubleList;
  if (is_single_list_arg) {
    if (kVariableArgumentsFunctions.count(function_type) == 0) {
      return absl_ports::InvalidArgumentError(
          "Received an unsupported list type argument in the math function.");
    }
    return std::unique_ptr<MathFunctionScoreExpression>(
        new MathFunctionScoreExpression(function_type, std::move(args)));
  }

  bool all_args_constant = true;
  for (const std::unique_ptr<ScoreExpression>& arg : args) {
    if (arg->type() != ScoreExpressionType::kDouble) {
      return absl_ports::InvalidArgumentError(
          "Got an invalid type for the math function. Should expect a double "
          "type argument.");
    }
    if (!arg->is_constant()) {
      all_args_constant = false;
    }
  }

  // Pick the arity error (if any) for fixed-arity functions; the
  // variable-argument functions accept any non-zero count.
  const size_t arg_count = args.size();
  const char* arity_error = nullptr;
  switch (function_type) {
    case FunctionType::kLog:
      if (arg_count != 1 && arg_count != 2) {
        arity_error = "log must have 1 or 2 arguments.";
      }
      break;
    case FunctionType::kPow:
      if (arg_count != 2) arity_error = "pow must have 2 arguments.";
      break;
    case FunctionType::kSqrt:
      if (arg_count != 1) arity_error = "sqrt must have 1 argument.";
      break;
    case FunctionType::kAbs:
      if (arg_count != 1) arity_error = "abs must have 1 argument.";
      break;
    case FunctionType::kSin:
      if (arg_count != 1) arity_error = "sin must have 1 argument.";
      break;
    case FunctionType::kCos:
      if (arg_count != 1) arity_error = "cos must have 1 argument.";
      break;
    case FunctionType::kTan:
      if (arg_count != 1) arity_error = "tan must have 1 argument.";
      break;
    case FunctionType::kMax:
    case FunctionType::kMin:
    case FunctionType::kLen:
    case FunctionType::kSum:
    case FunctionType::kAvg:
      break;
  }
  if (arity_error != nullptr) {
    return absl_ports::InvalidArgumentError(arity_error);
  }

  std::unique_ptr<ScoreExpression> expression =
      std::unique_ptr<MathFunctionScoreExpression>(
          new MathFunctionScoreExpression(function_type, std::move(args)));
  if (!all_args_constant) {
    return expression;
  }
  // Constant arguments make the result independent of the DocHitInfo and
  // query iterator, so evaluate once against placeholders and fold it.
  return ConstantScoreExpression::Create(
      expression->eval(DocHitInfo(), /*query_it=*/nullptr));
}
// Evaluates the math function for `hit_info`.
//
// The numeric inputs are either the elements of the single list-typed
// argument or the values of each double argument, in order. Returns
// INVALID_ARGUMENT when max/min/avg receive no inputs, or when the result is
// not finite (e.g. log of a non-positive number, or a log base of 1).
libtextclassifier3::StatusOr<double> MathFunctionScoreExpression::eval(
    const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
  std::vector<double> values;
  if (args_.at(0)->type() == ScoreExpressionType::kDoubleList) {
    ICING_ASSIGN_OR_RETURN(values, args_.at(0)->eval_list(hit_info, query_it));
  } else {
    values.reserve(args_.size());
    for (const auto& child : args_) {
      ICING_ASSIGN_OR_RETURN(double v, child->eval(hit_info, query_it));
      values.push_back(v);
    }
  }
  double res = 0;
  switch (function_type_) {
    case FunctionType::kLog:
      if (values.size() == 1) {
        res = std::log(values[0]);
      } else {
        // values[0] is the log base, values[1] is the operand.
        res = std::log(values[1]) / std::log(values[0]);
      }
      break;
    case FunctionType::kPow:
      res = std::pow(values[0], values[1]);
      break;
    case FunctionType::kMax:
      if (values.empty()) {
        return absl_ports::InvalidArgumentError(
            "Got an empty parameter set in max function");
      }
      res = *std::max_element(values.begin(), values.end());
      break;
    case FunctionType::kMin:
      if (values.empty()) {
        return absl_ports::InvalidArgumentError(
            "Got an empty parameter set in min function");
      }
      res = *std::min_element(values.begin(), values.end());
      break;
    case FunctionType::kLen:
      res = static_cast<double>(values.size());
      break;
    case FunctionType::kSum:
      res = std::reduce(values.begin(), values.end());
      break;
    case FunctionType::kAvg:
      if (values.empty()) {
        return absl_ports::InvalidArgumentError(
            "Got an empty parameter set in avg function.");
      }
      res = std::reduce(values.begin(), values.end()) / values.size();
      break;
    case FunctionType::kSqrt:
      res = std::sqrt(values[0]);
      break;
    case FunctionType::kAbs:
      // Must be the floating-point overload: an unqualified `abs` can resolve
      // to C's `int abs(int)` and silently truncate (abs(-2.5) -> 2).
      res = std::fabs(values[0]);
      break;
    case FunctionType::kSin:
      res = std::sin(values[0]);
      break;
    case FunctionType::kCos:
      res = std::cos(values[0]);
      break;
    case FunctionType::kTan:
      res = std::tan(values[0]);
      break;
  }
  if (!std::isfinite(res)) {
    return absl_ports::InvalidArgumentError(
        "Got a non-finite value while evaluating math function score "
        "expression.");
  }
  return res;
}
// Lookup table from a document-based function's name to the FunctionType it
// denotes. Create() requires every one of these to take "this" as its first
// argument.
const std::unordered_map<std::string,
                         DocumentFunctionScoreExpression::FunctionType>
    DocumentFunctionScoreExpression::kFunctionNames = {
        {"documentScore", FunctionType::kDocumentScore},
        {"creationTimestamp", FunctionType::kCreationTimestamp},
        {"usageCount", FunctionType::kUsageCount},
        {"usageLastUsedTimestamp", FunctionType::kUsageLastUsedTimestamp}};
// Builds a document-based function node.
//
// The first argument must be "this" (a document-typed expression).
// documentScore/creationTimestamp take nothing else; usageCount and
// usageLastUsedTimestamp take exactly one additional double argument holding
// the usage type. `document_store`, `default_score` and `current_time_ms` are
// forwarded to the node unchanged.
libtextclassifier3::StatusOr<std::unique_ptr<DocumentFunctionScoreExpression>>
DocumentFunctionScoreExpression::Create(
    FunctionType function_type,
    std::vector<std::unique_ptr<ScoreExpression>> args,
    const DocumentStore* document_store, double default_score,
    int64_t current_time_ms) {
  if (args.empty()) {
    return absl_ports::InvalidArgumentError(
        "Document-based functions must have at least one argument.");
  }
  ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
  if (args.front()->type() != ScoreExpressionType::kDocument) {
    return absl_ports::InvalidArgumentError(
        "The first parameter of document-based functions must be \"this\".");
  }

  const bool takes_only_this =
      function_type == FunctionType::kDocumentScore ||
      function_type == FunctionType::kCreationTimestamp;
  if (takes_only_this && args.size() != 1) {
    return absl_ports::InvalidArgumentError(
        "DocumentScore/CreationTimestamp must have 1 argument.");
  }

  const bool takes_usage_type =
      function_type == FunctionType::kUsageCount ||
      function_type == FunctionType::kUsageLastUsedTimestamp;
  if (takes_usage_type &&
      (args.size() != 2 || args[1]->type() != ScoreExpressionType::kDouble)) {
    return absl_ports::InvalidArgumentError(
        "UsageCount/UsageLastUsedTimestamp must have 2 arguments. The "
        "first argument should be \"this\", and the second argument "
        "should be the usage type.");
  }

  return std::unique_ptr<DocumentFunctionScoreExpression>(
      new DocumentFunctionScoreExpression(function_type, std::move(args),
                                          document_store, default_score,
                                          current_time_ms));
}
// Evaluates the document-based function for the document in `hit_info`.
//
// documentScore / creationTimestamp read the document's associated score
// data. usageCount / usageLastUsedTimestamp read the usage scores for the
// usage type (an integer 1..3) given by the second argument; a document with
// no usage entry is treated as having default (zero-initialized) scores. Last-
// used timestamps are converted from seconds to milliseconds.
libtextclassifier3::StatusOr<double> DocumentFunctionScoreExpression::eval(
    const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
  switch (function_type_) {
    case FunctionType::kDocumentScore:
      [[fallthrough]];
    case FunctionType::kCreationTimestamp: {
      // NOTE(review): the third macro argument is presumably what eval()
      // returns if the score-data lookup fails — confirm in status-macros.h.
      ICING_ASSIGN_OR_RETURN(DocumentAssociatedScoreData score_data,
                             document_store_.GetDocumentAssociatedScoreData(
                                 hit_info.document_id()),
                             default_score_);
      if (function_type_ == FunctionType::kDocumentScore) {
        return static_cast<double>(score_data.document_score());
      }
      return static_cast<double>(score_data.creation_timestamp_ms());
    }
    case FunctionType::kUsageCount:
      [[fallthrough]];
    case FunctionType::kUsageLastUsedTimestamp: {
      ICING_ASSIGN_OR_RETURN(double raw_usage_type,
                             args_[1]->eval(hit_info, query_it));
      // Validate on the double itself before converting: casting a NaN or
      // out-of-range double to int is undefined behavior. The negated range
      // test also rejects NaN, since every comparison with NaN is false.
      if (!(raw_usage_type >= 1 && raw_usage_type <= 3) ||
          raw_usage_type != std::trunc(raw_usage_type)) {
        return absl_ports::InvalidArgumentError(
            "Usage type must be an integer from 1 to 3");
      }
      const int usage_type = static_cast<int>(raw_usage_type);
      std::optional<UsageStore::UsageScores> usage_scores =
          document_store_.GetUsageScores(hit_info.document_id(),
                                         current_time_ms_);
      if (!usage_scores) {
        // If there's no UsageScores entry present for this doc, then just
        // treat it as a default instance.
        usage_scores = UsageStore::UsageScores();
      }
      if (function_type_ == FunctionType::kUsageCount) {
        if (usage_type == 1) {
          return usage_scores->usage_type1_count;
        } else if (usage_type == 2) {
          return usage_scores->usage_type2_count;
        } else {
          return usage_scores->usage_type3_count;
        }
      }
      // Stored last-used timestamps are in seconds; report milliseconds.
      if (usage_type == 1) {
        return usage_scores->usage_type1_last_used_timestamp_s * 1000.0;
      } else if (usage_type == 2) {
        return usage_scores->usage_type2_last_used_timestamp_s * 1000.0;
      } else {
        return usage_scores->usage_type3_last_used_timestamp_s * 1000.0;
      }
    }
  }
  // Unreachable while every FunctionType is handled above; prevents flowing
  // off the end of a non-void function (undefined behavior) if a new enum
  // value is added without a matching case.
  return absl_ports::InternalError("Should never reach here.");
}
// Builds a relevanceScore(this) node.
//
// Requires exactly one non-null, document-typed argument. The calculator and
// default score are handed to the node unchanged.
libtextclassifier3::StatusOr<
    std::unique_ptr<RelevanceScoreFunctionScoreExpression>>
RelevanceScoreFunctionScoreExpression::Create(
    std::vector<std::unique_ptr<ScoreExpression>> args,
    Bm25fCalculator* bm25f_calculator, double default_score) {
  if (args.size() != 1) {
    return absl_ports::InvalidArgumentError(
        "relevanceScore must have 1 argument.");
  }
  ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
  const bool argument_is_this =
      args.front()->type() == ScoreExpressionType::kDocument;
  if (!argument_is_this) {
    return absl_ports::InvalidArgumentError(
        "relevanceScore must take \"this\" as its argument.");
  }
  std::unique_ptr<RelevanceScoreFunctionScoreExpression> relevance_expression(
      new RelevanceScoreFunctionScoreExpression(bm25f_calculator,
                                                default_score));
  return relevance_expression;
}
// Returns the BM25F score of `hit_info` against the query. Scoring needs the
// query iterator; when none is supplied, the configured default score is
// returned instead.
libtextclassifier3::StatusOr<double>
RelevanceScoreFunctionScoreExpression::eval(
    const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
  if (query_it != nullptr) {
    return static_cast<double>(
        bm25f_calculator_.ComputeScore(query_it, hit_info, default_score_));
  }
  return default_score_;
}
// Builds a childrenRankingSignals(this) node.
//
// Requires exactly one non-null, document-typed argument and a non-null
// JoinChildrenFetcher (this function is only meaningful for joined queries).
libtextclassifier3::StatusOr<
    std::unique_ptr<ChildrenRankingSignalsFunctionScoreExpression>>
ChildrenRankingSignalsFunctionScoreExpression::Create(
    std::vector<std::unique_ptr<ScoreExpression>> args,
    const JoinChildrenFetcher* join_children_fetcher) {
  if (args.size() != 1) {
    return absl_ports::InvalidArgumentError(
        "childrenRankingSignals must have 1 argument.");
  }
  ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
  const bool argument_is_this =
      args.front()->type() == ScoreExpressionType::kDocument;
  if (!argument_is_this) {
    return absl_ports::InvalidArgumentError(
        "childrenRankingSignals must take \"this\" as its argument.");
  }
  const bool has_fetcher = join_children_fetcher != nullptr;
  if (!has_fetcher) {
    return absl_ports::InvalidArgumentError(
        "childrenRankingSignals must only be used with join, but "
        "JoinChildrenFetcher "
        "is not provided.");
  }
  std::unique_ptr<ChildrenRankingSignalsFunctionScoreExpression>
      children_expression(new ChildrenRankingSignalsFunctionScoreExpression(
          *join_children_fetcher));
  return children_expression;
}
// Returns the scores of the joined children of the document in `hit_info`,
// in the order the JoinChildrenFetcher returns them. `query_it` is unused.
libtextclassifier3::StatusOr<std::vector<double>>
ChildrenRankingSignalsFunctionScoreExpression::eval_list(
    const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
  ICING_ASSIGN_OR_RETURN(
      std::vector<ScoredDocumentHit> children_hits,
      join_children_fetcher_.GetChildren(hit_info.document_id()));
  std::vector<double> scores(children_hits.size());
  std::transform(children_hits.begin(), children_hits.end(), scores.begin(),
                 [](const ScoredDocumentHit& child_hit) -> double {
                   return child_hit.score();
                 });
  return scores;
}
// Builds a propertyWeights(this) node.
//
// Requires exactly one non-null, document-typed argument. The document store,
// section weights and current time are handed to the node unchanged.
libtextclassifier3::StatusOr<
    std::unique_ptr<PropertyWeightsFunctionScoreExpression>>
PropertyWeightsFunctionScoreExpression::Create(
    std::vector<std::unique_ptr<ScoreExpression>> args,
    const DocumentStore* document_store, const SectionWeights* section_weights,
    int64_t current_time_ms) {
  if (args.size() != 1) {
    return absl_ports::InvalidArgumentError(
        "propertyWeights must have 1 argument.");
  }
  ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
  const bool argument_is_this =
      args.front()->type() == ScoreExpressionType::kDocument;
  if (!argument_is_this) {
    return absl_ports::InvalidArgumentError(
        "propertyWeights must take \"this\" as its argument.");
  }
  std::unique_ptr<PropertyWeightsFunctionScoreExpression> weights_expression(
      new PropertyWeightsFunctionScoreExpression(
          document_store, section_weights, current_time_ms));
  return weights_expression;
}
// Returns the normalized weight of every section that `hit_info` hit, ordered
// by ascending section id, for the schema type of the hit document.
libtextclassifier3::StatusOr<std::vector<double>>
PropertyWeightsFunctionScoreExpression::eval_list(
    const DocHitInfo& hit_info, const DocHitInfoIterator*) const {
  const SchemaTypeId schema_type_id = GetSchemaTypeId(hit_info.document_id());
  std::vector<double> weights;
  // Consume the set bits of the hit mask from the lowest section id upward.
  for (SectionIdMask remaining = hit_info.hit_section_ids_mask();
       remaining != 0;) {
    const SectionId section_id = __builtin_ctzll(remaining);
    remaining &= ~(UINT64_C(1) << section_id);
    weights.push_back(section_weights_.GetNormalizedSectionWeight(
        schema_type_id, section_id));
  }
  return weights;
}
// Returns the schema type id of `document_id`, or kInvalidSchemaTypeId (after
// logging a warning) when the document store has no alive filter data for it.
SchemaTypeId PropertyWeightsFunctionScoreExpression::GetSchemaTypeId(
    DocumentId document_id) const {
  auto filter_data =
      document_store_.GetAliveDocumentFilterData(document_id, current_time_ms_);
  if (filter_data) {
    return filter_data.value().schema_type_id();
  }
  // Not expected in practice: the only failure case for
  // GetAliveDocumentFilterData is a document_id outside the allocated range,
  // and this id comes from the posting lists.
  ICING_LOG(WARNING) << "No document filter data for document ["
                     << document_id << "]";
  return kInvalidSchemaTypeId;
}
// Builds a node that selects an embedding query vector by index.
//
// Requires exactly one non-null, double-typed argument (the query index).
// When that argument is constant, the node is evaluated now and folded into a
// ConstantScoreExpression of this node's type, so an invalid constant index is
// reported at creation time.
libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
GetSearchSpecEmbeddingFunctionScoreExpression::Create(
    std::vector<std::unique_ptr<ScoreExpression>> args) {
  if (args.size() != 1) {
    return absl_ports::InvalidArgumentError(
        absl_ports::StrCat(kFunctionName, " must have 1 argument."));
  }
  // Every other factory in this file null-checks its arguments before
  // dereferencing them; without this, a null argument crashes below.
  ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
  if (args[0]->type() != ScoreExpressionType::kDouble) {
    return absl_ports::InvalidArgumentError(
        absl_ports::StrCat(kFunctionName, " got invalid argument type."));
  }
  bool is_constant = args[0]->is_constant();
  std::unique_ptr<ScoreExpression> expression =
      std::unique_ptr<GetSearchSpecEmbeddingFunctionScoreExpression>(
          new GetSearchSpecEmbeddingFunctionScoreExpression(
              std::move(args[0])));
  if (is_constant) {
    return ConstantScoreExpression::Create(
        expression->eval(DocHitInfo(), /*query_it=*/nullptr),
        expression->type());
  }
  return expression;
}
// Evaluates the query-index argument and returns it as a double.
//
// Returns INVALID_ARGUMENT if the value is not an integer (including NaN), or
// if it is an integer outside [0, UINT32_MAX] (including +/-infinity).
libtextclassifier3::StatusOr<double>
GetSearchSpecEmbeddingFunctionScoreExpression::eval(
    const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
  ICING_ASSIGN_OR_RETURN(double raw_query_index,
                         arg_->eval(hit_info, query_it));
  // trunc(NaN) is NaN, which never compares equal, so NaN is rejected here.
  if (std::trunc(raw_query_index) != raw_query_index) {
    return absl_ports::InvalidArgumentError(
        "The index of an embedding query must be an integer.");
  }
  // Range-check before converting: converting a negative or too-large double
  // to uint32_t is undefined behavior.
  if (raw_query_index < 0 || raw_query_index > UINT32_MAX) {
    return absl_ports::InvalidArgumentError(
        "The index of an embedding query must be a non-negative 32-bit "
        "integer.");
  }
  uint32_t query_index = static_cast<uint32_t>(raw_query_index);
  return query_index;
}
// Builds a matchedSemanticScores(this, vector_index[, metric]) node.
//
// Argument 0 must be "this", argument 1 a vector-index expression, and the
// optional argument 2 a constant string naming the embedding metric. Without
// argument 2, `default_metric_type` is used. `embedding_query_results` must be
// non-null.
libtextclassifier3::StatusOr<
    std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>>
MatchedSemanticScoresFunctionScoreExpression::Create(
    std::vector<std::unique_ptr<ScoreExpression>> args,
    SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,
    const EmbeddingQueryResults* embedding_query_results) {
  ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
  ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
  const bool first_is_this =
      !args.empty() && args[0]->type() == ScoreExpressionType::kDocument;
  if (!first_is_this) {
    return absl_ports::InvalidArgumentError(
        absl_ports::StrCat(kFunctionName, " is not called with \"this\""));
  }
  const size_t arg_count = args.size();
  if (arg_count < 2 || arg_count > 3) {
    return absl_ports::InvalidArgumentError(
        absl_ports::StrCat(kFunctionName, " got invalid number of arguments."));
  }
  if (args[1]->type() != ScoreExpressionType::kVectorIndex) {
    return absl_ports::InvalidArgumentError(absl_ports::StrCat(
        kFunctionName, " got invalid argument type for embedding vector."));
  }

  SearchSpecProto::EmbeddingQueryMetricType::Code metric_type =
      default_metric_type;
  if (arg_count == 3) {
    const std::unique_ptr<ScoreExpression>& metric_arg = args[2];
    if (metric_arg->type() != ScoreExpressionType::kString) {
      return absl_ports::InvalidArgumentError(
          "Embedding metric can only be given as a string.");
    }
    if (!metric_arg->is_constant()) {
      return absl_ports::InvalidArgumentError(
          "Embedding metric can only be given as a constant string.");
    }
    ICING_ASSIGN_OR_RETURN(std::string_view metric, metric_arg->eval_string());
    ICING_ASSIGN_OR_RETURN(
        metric_type,
        embedding_util::GetEmbeddingQueryMetricTypeFromName(metric));
  }

  return std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>(
      new MatchedSemanticScoresFunctionScoreExpression(
          std::move(args), metric_type, *embedding_query_results));
}
// Returns (a copy of) the matched semantic scores of the document in
// `hit_info` for the selected embedding query and metric, or an empty vector
// when EmbeddingQueryResults has none recorded.
libtextclassifier3::StatusOr<std::vector<double>>
MatchedSemanticScoresFunctionScoreExpression::eval_list(
    const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
  ICING_ASSIGN_OR_RETURN(double raw_query_index,
                         args_[1]->eval(hit_info, query_it));
  const std::vector<double>* matched_scores =
      embedding_query_results_.GetMatchedScoresForDocument(
          static_cast<uint32_t>(raw_query_index), metric_type_,
          hit_info.document_id());
  return matched_scores != nullptr ? *matched_scores : std::vector<double>();
}
} // namespace lib
} // namespace icing