blob: 21180db0e132cffa5df4b5e77ec6c63da50a7dd4 [file] [log] [blame]
// Copyright (C) 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
#define ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/index/embed/embedding-hit.h"
#include "icing/index/embed/embedding-index.h"
#include "icing/index/embed/embedding-query-results.h"
#include "icing/index/embed/embedding-scorer.h"
#include "icing/index/embed/posting-list-embedding-hit-accessor.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/index/iterator/section-restrict-data.h"
#include "icing/proto/search.pb.h"
#include "icing/schema/section.h"
namespace icing {
namespace lib {
class DocHitInfoIteratorEmbedding : public DocHitInfoLeafIterator {
public:
// Create a DocHitInfoIterator for iterating through all docs which have an
// embedding matched with the provided query with a score in the range of
// [score_low, score_high], using the provided metric_type.
//
// The iterator will store the matched embedding scores in score_map to
// prepare for scoring.
//
// The iterator will handle the section restriction logic internally by the
// provided section_restrict_data.
//
// Returns:
// - a DocHitInfoIteratorEmbedding instance on success.
// - Any error from posting lists.
static libtextclassifier3::StatusOr<
std::unique_ptr<DocHitInfoIteratorEmbedding>>
Create(const PropertyProto::VectorProto* query,
std::unique_ptr<SectionRestrictData> section_restrict_data,
SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
double score_low, double score_high,
EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
const EmbeddingIndex* embedding_index);
libtextclassifier3::Status Advance() override;
// The iterator will internally handle the section restriction logic by itself
// to have better control, so that it is able to filter out embedding hits
// from unwanted sections to avoid retrieving unnecessary vectors and
// calculate scores for them.
bool full_section_restriction_applied() const override { return true; }
libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override {
return absl_ports::InvalidArgumentError(
"Query suggestions for the semanticSearch function are not supported");
}
CallStats GetCallStats() const override {
return CallStats(
/*num_leaf_advance_calls_lite_index_in=*/num_advance_calls_,
/*num_leaf_advance_calls_main_index_in=*/0,
/*num_leaf_advance_calls_integer_index_in=*/0,
/*num_leaf_advance_calls_no_index_in=*/0,
/*num_blocks_inspected_in=*/0);
}
std::string ToString() const override { return "embedding_iterator"; }
// PopulateMatchedTermsStats is not applicable to embedding search.
void PopulateMatchedTermsStats(
std::vector<TermMatchInfo>* matched_terms_stats,
SectionIdMask filtering_section_mask) const override {}
private:
explicit DocHitInfoIteratorEmbedding(
const PropertyProto::VectorProto* query,
std::unique_ptr<SectionRestrictData> section_restrict_data,
SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
std::unique_ptr<EmbeddingScorer> embedding_scorer, double score_low,
double score_high,
EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
const EmbeddingIndex* embedding_index,
std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor)
: query_(*query),
section_restrict_data_(std::move(section_restrict_data)),
metric_type_(metric_type),
embedding_scorer_(std::move(embedding_scorer)),
score_low_(score_low),
score_high_(score_high),
score_map_(*score_map),
embedding_index_(*embedding_index),
posting_list_accessor_(std::move(posting_list_accessor)),
cached_embedding_hits_idx_(0),
current_allowed_sections_mask_(kSectionIdMaskAll),
no_more_hit_(false),
num_advance_calls_(0) {}
// Advance to the next embedding hit of the current document. If the current
// document id is kInvalidDocumentId, the method will advance to the first
// embedding hit of the next document and update doc_hit_info_.
//
// This method also properly updates cached_embedding_hits_,
// cached_embedding_hits_idx_, current_allowed_sections_mask_, and
// no_more_hit_ to reflect the current state.
//
// Returns:
// - a const pointer to the next embedding hit on success.
// - nullptr, if there is no more hit for the current document, or no more
// hit in general if the current document id is kInvalidDocumentId.
// - Any error from posting lists.
libtextclassifier3::StatusOr<const EmbeddingHit*> AdvanceToNextEmbeddingHit();
// Query information
const PropertyProto::VectorProto& query_; // Does not own
std::unique_ptr<SectionRestrictData> section_restrict_data_; // Nullable.
// Scoring arguments
SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
std::unique_ptr<EmbeddingScorer> embedding_scorer_;
double score_low_;
double score_high_;
// Score map
EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map_; // Does not own
// Access to embeddings index data
const EmbeddingIndex& embedding_index_;
std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor_;
// Cached data from the embeddings index
std::vector<EmbeddingHit> cached_embedding_hits_;
int cached_embedding_hits_idx_;
SectionIdMask current_allowed_sections_mask_;
bool no_more_hit_;
int num_advance_calls_;
};
} // namespace lib
} // namespace icing
#endif // ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_