Skip to content

Commit a42a17f

Browse files
authored
feat: wildcard pattern added for fields to find all non-null values (#4941)
fixed: #4937
1 parent 954e940 commit a42a17f

File tree

8 files changed

+502
-4
lines changed

8 files changed

+502
-4
lines changed

src/core/search/ast_expr.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ struct AstNode;
2222
// Matches all documents
2323
struct AstStarNode {};
2424

25+
// Matches all documents where the enclosing field has a non-null value (query form `@field:*`).
// Carries no data of its own.
26+
struct AstStarFieldNode {};
27+
2528
// Matches terms in text fields
2629
struct AstTermNode {
2730
explicit AstTermNode(std::string term);
@@ -108,9 +111,9 @@ struct AstKnnNode {
108111
std::optional<float> ef_runtime;
109112
};
110113

111-
using NodeVariants =
112-
std::variant<std::monostate, AstStarNode, AstTermNode, AstPrefixNode, AstRangeNode,
113-
AstNegateNode, AstLogicalNode, AstFieldNode, AstTagsNode, AstKnnNode>;
114+
// Closed set of node kinds an AstNode can hold; std::monostate is the first alternative.
// NOTE(review): AstStarFieldNode is inserted before AstTermNode, which shifts index() of every
// later alternative - confirm nothing relies on the numeric variant index.
using NodeVariants = std::variant<std::monostate, AstStarNode, AstStarFieldNode, AstTermNode,
115+
AstPrefixNode, AstRangeNode, AstNegateNode, AstLogicalNode,
116+
AstFieldNode, AstTagsNode, AstKnnNode>;
114117

115118
struct AstNode : public NodeVariants {
116119
using variant::variant;

src/core/search/base.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ struct BaseIndex {
9090
// Returns true if the document was added / indexed
9191
virtual bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
9292
virtual void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
93+
94+
// Returns the documents that have a non-null value for this field (used for @field:* queries).
// The base implementation returns std::nullopt; an index type that is able to enumerate its
// documents overrides this and returns the list of ids.
95+
virtual std::optional<std::vector<DocId>> GetAllResults() const {
96+
// Default: no result available from this index type.
return std::nullopt;
97+
}
9398
};
9499

95100
// Base class for type-specific sorting indices.

src/core/search/indices.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ struct NumericIndex : public BaseIndex {
3535

3636
std::vector<DocId> Range(double l, double r) const;
3737

38+
// Returns every document stored in this numeric index, expressed as a Range() query
// spanning -infinity to +infinity.
std::optional<std::vector<DocId>> GetAllResults() const override {
39+
return Range(-std::numeric_limits<double>::infinity(), std::numeric_limits<double>::infinity());
40+
}
41+
3842
private:
3943
using Entry = std::pair<double, DocId>;
4044
absl::btree_set<Entry, std::less<Entry>, PMR_NS::polymorphic_allocator<Entry>> entries_;
@@ -58,6 +62,20 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
5862
// Returns all the terms that appear as keys in the reverse index.
5963
std::vector<std::string> GetTerms() const;
6064

65+
// Returns every document id found in any per-term container of entries_,
// de-duplicated and sorted in ascending order.
std::optional<std::vector<DocId>> GetAllResults() const override {
66+
// A set is used because the same id can be inserted more than once
// (presumably a document listed under several terms).
absl::flat_hash_set<DocId> unique_docs;
67+
68+
for (const auto& [term, container] : entries_) {
69+
for (const DocId& id : container) {
70+
unique_docs.insert(id);
71+
}
72+
}
73+
74+
// The hash set has no defined iteration order, so the ids are sorted explicitly.
auto result = std::vector<DocId>(unique_docs.begin(), unique_docs.end());
75+
std::sort(result.begin(), result.end());
76+
return result;
77+
}
78+
6179
protected:
6280
using StringList = DocumentAccessor::StringList;
6381

@@ -133,6 +151,33 @@ struct FlatVectorIndex : public BaseVectorIndex {
133151

134152
const float* Get(DocId doc) const;
135153

154+
// Returns, in ascending order, every document id whose vector in this index has at least
// one non-zero component.
155+
std::optional<std::vector<DocId>> GetAllResults() const override {
156+
std::vector<DocId> result;
157+
// Number of vector slots; assumes entries_ holds dim_ floats per document - TODO confirm.
size_t num_vectors = entries_.size() / dim_;
158+
result.reserve(num_vectors);
159+
160+
for (DocId id = 0; id < num_vectors; ++id) {
161+
// Skip ids whose vector consists only of zeros.
162+
// TODO: A valid vector can be all zeros and would be wrongly excluded here;
// a better way to detect the presence of a vector is needed.
163+
const float* vec = Get(id);
164+
bool is_zero_vector = true;
165+
166+
for (size_t i = 0; i < dim_; ++i) {
167+
if (vec[i] != 0.0f) {
168+
is_zero_vector = false;
169+
break;
170+
}
171+
}
172+
173+
if (!is_zero_vector) {
174+
result.push_back(id);
175+
}
176+
}
177+
178+
return result;
179+
}
180+
136181
protected:
137182
void AddVector(DocId id, const VectorPtr& vector) override;
138183

@@ -142,6 +187,10 @@ struct FlatVectorIndex : public BaseVectorIndex {
142187

143188
struct HnswlibAdapter;
144189

190+
// This index doesn't have a GetAllResults method
191+
// because it's not possible to get all vectors from the index
192+
// It depends on the Hnswlib implementation
193+
// TODO: Consider adding GetAllResults method in the future
145194
struct HnswVectorIndex : public BaseVectorIndex {
146195
HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
147196
~HnswVectorIndex();

src/core/search/parser.y

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
%define api.value.type variant
1111
%define api.parser.class {Parser}
1212
%define parse.assert
13+
%define api.value.automove true
1314

1415
// Added to header file before parser declaration.
1516
%code requires {
@@ -136,6 +137,7 @@ search_unary_expr:
136137
field_cond:
137138
TERM { $$ = AstTermNode(std::move($1)); }
138139
| UINT32 { $$ = AstTermNode(std::move($1)); }
140+
| STAR { $$ = AstStarFieldNode(); }
139141
| NOT_OP field_cond { $$ = AstNegateNode(std::move($2)); }
140142
| LPAREN field_cond_expr RPAREN { $$ = std::move($2); }
141143
| LBRACKET numeric_filter_expr RBRACKET { $$ = std::move($2); }
@@ -168,7 +170,7 @@ field_or_expr:
168170

169171
field_unary_expr:
170172
LPAREN field_cond_expr RPAREN { $$ = std::move($2); }
171-
| NOT_OP field_unary_expr { $$ = AstNegateNode(std::move($2)); };
173+
| NOT_OP field_unary_expr { $$ = AstNegateNode(std::move($2)); }
172174
| TERM { $$ = AstTermNode(std::move($1)); }
173175
| UINT32 { $$ = AstTermNode(std::move($1)); }
174176

src/core/search/search.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ struct ProfileBuilder {
143143
[](const AstKnnNode& n) { return absl::StrCat("KNN{l=", n.limit, "}"); },
144144
[](const AstNegateNode& n) { return absl::StrCat("Negate{}"); },
145145
[](const AstStarNode& n) { return absl::StrCat("Star{}"); },
146+
[](const AstStarFieldNode& n) { return absl::StrCat("StarField{}"); },
146147
};
147148
return visit(node_info, node.Variant());
148149
}
@@ -302,6 +303,32 @@ struct BasicSearch {
302303
return UnifyResults(GetSubResults(selected_indices, mapping), LogicOp::OR);
303304
}
304305

306+
// Evaluates a field wildcard node against `active_field`.
// Asks the field's sort index first and then its regular index; the first one whose
// GetAllResults() yields a value provides the result. If neither does, error_ is set and an
// empty IndexResult is returned. `node` itself is not read.
IndexResult Search(const AstStarFieldNode& node, string_view active_field) {
307+
// Try to get a sort index first, as `@field:*` might imply wanting sortable behavior
308+
BaseSortIndex* sort_index = indices_->GetSortIndex(active_field);
309+
if (sort_index) {
310+
if (auto result = sort_index->GetAllResults()) {
311+
return std::move(*result);
312+
}
313+
}
314+
315+
// If sort index doesn't exist or doesn't support GetAllResults, try regular index
316+
BaseIndex* base_index = indices_->GetIndex(active_field);
317+
if (base_index) {
318+
if (auto result = base_index->GetAllResults()) {
319+
return std::move(*result);
320+
}
321+
}
322+
323+
// If we get here, neither index could handle the request:
// - no index of either kind exists for the field -> "Invalid field"
// - an index exists but its GetAllResults() returned no value -> "Wrong access type"
324+
if (!base_index && !sort_index) {
325+
error_ = absl::StrCat("Invalid field: ", active_field);
326+
} else {
327+
error_ = absl::StrCat("Wrong access type for field: ", active_field);
328+
}
329+
return IndexResult{};
330+
}
331+
305332
IndexResult Search(const AstPrefixNode& node, string_view active_field) {
306333
vector<TextIndex*> indices;
307334
if (!active_field.empty()) {

src/core/search/search_test.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,45 @@ static void BM_VectorSearch(benchmark::State& state) {
805805

806806
BENCHMARK(BM_VectorSearch)->Args({120, 10'000});
807807

808+
// Checks the `@field:*` query for TEXT, TAG and NUMERIC fields: documents that define the
// queried field are expected to match; documents that only define other fields, or no
// fields at all, are expected not to match.
TEST_F(SearchTest, MatchNonNullField) {
809+
PrepareSchema({{"text_field", SchemaField::TEXT},
810+
{"tag_field", SchemaField::TAG},
811+
{"num_field", SchemaField::NUMERIC}});
812+
813+
// TEXT field
{
814+
PrepareQuery("@text_field:*");
815+
816+
ExpectAll(Map{{"text_field", "any value"}}, Map{{"text_field", "another value"}},
817+
Map{{"text_field", "third"}, {"tag_field", "tag1"}});
818+
819+
ExpectNone(Map{{"tag_field", "wrong field"}}, Map{{"num_field", "123"}}, Map{});
820+
821+
EXPECT_TRUE(Check()) << GetError();
822+
}
823+
824+
// TAG field
{
825+
PrepareQuery("@tag_field:*");
826+
827+
ExpectAll(Map{{"tag_field", "tag1"}}, Map{{"tag_field", "tag2"}},
828+
Map{{"text_field", "value"}, {"tag_field", "tag3"}});
829+
830+
ExpectNone(Map{{"text_field", "wrong field"}}, Map{{"num_field", "456"}}, Map{});
831+
832+
EXPECT_TRUE(Check()) << GetError();
833+
}
834+
835+
// NUMERIC field
{
836+
PrepareQuery("@num_field:*");
837+
838+
ExpectAll(Map{{"num_field", "123"}}, Map{{"num_field", "456"}},
839+
Map{{"text_field", "value"}, {"num_field", "789"}});
840+
841+
ExpectNone(Map{{"text_field", "wrong field"}}, Map{{"tag_field", "tag1"}}, Map{});
842+
843+
EXPECT_TRUE(Check()) << GetError();
844+
}
845+
}
846+
808847
} // namespace search
809848

810849
} // namespace dfly

src/core/search/sort_indices.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,23 @@ template <typename T> struct SimpleValueSortIndex : public BaseSortIndex {
4141
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
4242
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
4343

44+
// Returns, in ascending order, every document id in [0, values_.size()) that is not listed in
// null_values_ and whose stored value differs from a default-constructed T.
45+
std::optional<std::vector<DocId>> GetAllResults() const override {
46+
std::vector<DocId> result;
47+
48+
for (DocId id = 0; id < values_.size(); ++id) {
49+
// Check if id is not present in null_values_
50+
// Also need to handle deleted elements - in them T should be empty
51+
// Different types of T have different "empty" values, but we can check
52+
// if this value is the default for the given type
// NOTE(review): a document whose real value equals T{} (for example numeric 0 or an
// empty string) is excluded by this check as well - confirm this is intended.
53+
if (!null_values_.contains(id) && !(values_[id] == T{})) {
54+
result.push_back(id);
55+
}
56+
}
57+
58+
return result;
59+
}
60+
4461
protected:
4562
virtual ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field_value) = 0;
4663

0 commit comments

Comments
 (0)