WIP: Configure tsvector experiment (search only) #247

Open · wants to merge 9 commits into master
8 changes: 8 additions & 0 deletions engine/clients/client_factory.py
@@ -25,6 +25,11 @@
)
from engine.clients.qdrant import QdrantConfigurator, QdrantSearcher, QdrantUploader
from engine.clients.redis import RedisConfigurator, RedisSearcher, RedisUploader
from engine.clients.tsvector import (
    TsVectorConfigurator,
    TsVectorSearcher,
    TsVectorUploader,
)
from engine.clients.weaviate import (
    WeaviateConfigurator,
    WeaviateSearcher,
@@ -39,6 +44,7 @@
"opensearch": OpenSearchConfigurator,
"redis": RedisConfigurator,
"pgvector": PgVectorConfigurator,
"tsvector": TsVectorConfigurator,
}

ENGINE_UPLOADERS = {
@@ -49,6 +55,7 @@
"opensearch": OpenSearchUploader,
"redis": RedisUploader,
"pgvector": PgVectorUploader,
"tsvector": TsVectorUploader,
}

ENGINE_SEARCHERS = {
@@ -59,6 +66,7 @@
"opensearch": OpenSearchSearcher,
"redis": RedisSearcher,
"pgvector": PgVectorSearcher,
"tsvector": TsVectorSearcher,
}


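Note: the three registries above are keyed by the `engine` field of an experiment configuration, so registering `tsvector` in all three is what makes the new client selectable. A minimal sketch of resolving the entry (the `get_engine_classes` helper is hypothetical, not part of this PR):

```python
from engine.clients.client_factory import (
    ENGINE_CONFIGURATORS,
    ENGINE_SEARCHERS,
    ENGINE_UPLOADERS,
)


def get_engine_classes(engine_name: str):
    # Hypothetical helper: look up the classes registered under an engine name
    # such as "tsvector"; raises KeyError if the engine is not registered.
    return (
        ENGINE_CONFIGURATORS[engine_name],
        ENGINE_UPLOADERS[engine_name],
        ENGINE_SEARCHERS[engine_name],
    )


configurator_cls, uploader_cls, searcher_cls = get_engine_classes("tsvector")
```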
6 changes: 4 additions & 2 deletions engine/clients/qdrant/configure.py
@@ -4,7 +4,7 @@
from benchmark.dataset import Dataset
from engine.base_client.configure import BaseConfigurator
from engine.base_client.distances import Distance
from engine.clients.qdrant.config import QDRANT_COLLECTION_NAME, QDRANT_API_KEY
from engine.clients.qdrant.config import QDRANT_API_KEY, QDRANT_COLLECTION_NAME


class QdrantConfigurator(BaseConfigurator):
@@ -32,7 +32,9 @@ class QdrantConfigurator(BaseConfigurator):
    def __init__(self, host, collection_params: dict, connection_params: dict):
        super().__init__(host, collection_params, connection_params)

        self.client = QdrantClient(url=host, api_key=QDRANT_API_KEY, **connection_params)
        self.client = QdrantClient(
            url=host, api_key=QDRANT_API_KEY, **connection_params
        )

    def clean(self):
        self.client.delete_collection(collection_name=QDRANT_COLLECTION_NAME)
10 changes: 5 additions & 5 deletions engine/clients/qdrant/search.py
@@ -2,13 +2,12 @@
from typing import List, Tuple

import httpx
from qdrant_client import QdrantClient
from qdrant_client import QdrantClient, models
from qdrant_client._pydantic_compat import construct
from qdrant_client import models

from dataset_reader.base_reader import Query
from engine.base_client.search import BaseSearcher
from engine.clients.qdrant.config import QDRANT_COLLECTION_NAME, QDRANT_API_KEY
from engine.clients.qdrant.config import QDRANT_API_KEY, QDRANT_COLLECTION_NAME
from engine.clients.qdrant.parser import QdrantConditionParser


@@ -48,7 +47,6 @@ def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
                values=query.sparse_vector.values,
            )

        prefetch = cls.search_params.get("prefetch")

        if prefetch:
@@ -65,7 +63,9 @@ def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
                query=query_vector,
                query_filter=cls.parser.parse(query.meta_conditions),
                limit=top,
                search_params=models.SearchParams(**cls.search_params.get("config", {})),
                search_params=models.SearchParams(
                    **cls.search_params.get("config", {})
                ),
                with_payload=cls.search_params.get("with_payload", False),
            )
        except Exception as ex:
6 changes: 4 additions & 2 deletions engine/clients/qdrant/upload.py
@@ -13,7 +13,7 @@

from dataset_reader.base_reader import Record
from engine.base_client.upload import BaseUploader
from engine.clients.qdrant.config import QDRANT_COLLECTION_NAME, QDRANT_API_KEY
from engine.clients.qdrant.config import QDRANT_API_KEY, QDRANT_COLLECTION_NAME


class QdrantUploader(BaseUploader):
@@ -24,7 +24,9 @@ class QdrantUploader(BaseUploader):
    def init_client(cls, host, distance, connection_params, upload_params):
        os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "true"
        os.environ["GRPC_POLL_STRATEGY"] = "epoll,poll"
        cls.client = QdrantClient(url=host, prefer_grpc=True, api_key=QDRANT_API_KEY, **connection_params)
        cls.client = QdrantClient(
            url=host, prefer_grpc=True, api_key=QDRANT_API_KEY, **connection_params
        )
        cls.upload_params = upload_params

    @classmethod
9 changes: 9 additions & 0 deletions engine/clients/tsvector/__init__.py
@@ -0,0 +1,9 @@
from engine.clients.tsvector.configure import TsVectorConfigurator
from engine.clients.tsvector.search import TsVectorSearcher
from engine.clients.tsvector.upload import TsVectorUploader

__all__ = [
"TsVectorConfigurator",
"TsVectorSearcher",
"TsVectorUploader",
]
18 changes: 18 additions & 0 deletions engine/clients/tsvector/config.py
@@ -0,0 +1,18 @@
import os

TSVECTOR_PORT = int(os.getenv("TSVECTOR_PORT", 5432))
TSVECTOR_DB = os.getenv("TSVECTOR_DB", "ann")
TSVECTOR_USER = os.getenv("TSVECTOR_USER", "ann")
TSVECTOR_PASSWORD = os.getenv("TSVECTOR_PASSWORD", "ann")


def get_db_config(host, connection_params):
    return {
        "host": host or "localhost",
        "port": TSVECTOR_PORT,
        "dbname": TSVECTOR_DB,
        "user": TSVECTOR_USER,
        "password": TSVECTOR_PASSWORD,
        "autocommit": True,
        **connection_params,
    }
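For reference, the dict returned by `get_db_config` is passed straight into `psycopg.connect` by the configurator, searcher, and uploader below. A minimal connectivity check, assuming a local Postgres that matches the env defaults above (port 5432, database/user/password `ann`):

```python
import psycopg

from engine.clients.tsvector.config import get_db_config

# Assumption: a local Postgres reachable with the defaults above; override
# TSVECTOR_DB / TSVECTOR_USER / TSVECTOR_PASSWORD / TSVECTOR_PORT otherwise.
conn = psycopg.connect(**get_db_config("localhost", {"connect_timeout": 5}))
with conn.cursor() as cur:
    cur.execute("select version();")
    print(cur.fetchone()[0])
conn.close()
```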
27 changes: 27 additions & 0 deletions engine/clients/tsvector/configure.py
@@ -0,0 +1,27 @@
import pgvector.psycopg
import psycopg

from benchmark.dataset import Dataset
from engine.base_client.configure import BaseConfigurator
from engine.clients.tsvector.config import get_db_config


class TsVectorConfigurator(BaseConfigurator):
    def __init__(self, host, collection_params: dict, connection_params: dict):
        super().__init__(host, collection_params, connection_params)
        self.conn = psycopg.connect(**get_db_config(host, connection_params))
        print("configure connection created")
        self.conn.execute("CREATE EXTENSION IF NOT EXISTS pg_prewarm;")
        self.conn.execute("create extension if not exists timescaledb;")
        self.conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        self.conn.execute("create extension if not exists vectorscale cascade;")
        pgvector.psycopg.register_vector(self.conn)

    def clean(self):
        print("Make sure we don't remove the table")

    def recreate(self, dataset: Dataset, collection_params):
        print("Make sure we don't remove the table")

    def delete_client(self):
        self.conn.close()
46 changes: 46 additions & 0 deletions engine/clients/tsvector/parser.py
@@ -0,0 +1,46 @@
import json
from typing import Any, List, Optional

from engine.base_client import IncompatibilityError
from engine.base_client.parser import BaseConditionParser, FieldValue


class TsVectorConditionParser(BaseConditionParser):
    def build_condition(
        self, and_subfilters: Optional[List[Any]], or_subfilters: Optional[List[Any]]
    ) -> Optional[Any]:
        clauses = []
        if or_subfilters is not None and len(or_subfilters) > 0:
            clauses.append(f"( {' OR '.join(or_subfilters)} )")
        if and_subfilters is not None and len(and_subfilters) > 0:
            clauses.append(f"( {' AND '.join(and_subfilters)} )")

        return " AND ".join(clauses)

    def build_exact_match_filter(self, field_name: str, value: FieldValue) -> Any:
        # Postgres equality is a single "=" (the original "==" is not valid SQL)
        return f"{field_name} = {json.dumps(value)}"

    def build_range_filter(
        self,
        field_name: str,
        lt: Optional[FieldValue],
        gt: Optional[FieldValue],
        lte: Optional[FieldValue],
        gte: Optional[FieldValue],
    ) -> Any:
        clauses = []
        if lt is not None:
            clauses.append(f"{field_name} < {lt}")
        if gt is not None:
            clauses.append(f"{field_name} > {gt}")
        if lte is not None:
            clauses.append(f"{field_name} <= {lte}")
        if gte is not None:
            clauses.append(f"{field_name} >= {gte}")
        return f"( {' AND '.join(clauses)} )"

    def build_geo_filter(
        self, field_name: str, lat: float, lon: float, radius: float
    ) -> Any:
        # TODO: Implement this
        raise IncompatibilityError
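To show the fragments this parser produces, a small hedged example (the field names and values are made up). Note that string values pass through `json.dumps`, so they come out double-quoted, which Postgres would parse as identifiers rather than string literals; worth double-checking before the filtered datasets are enabled:

```python
from engine.clients.tsvector.parser import TsVectorConditionParser

parser = TsVectorConditionParser()

# Assumed example inputs, built with the helpers above.
ors = [
    parser.build_exact_match_filter("color", "red"),
    parser.build_exact_match_filter("color", "blue"),
]
ands = [parser.build_range_filter("price", lt=100, gt=None, lte=None, gte=10)]

print(parser.build_condition(and_subfilters=ands, or_subfilters=ors))
# ( color = "red" OR color = "blue" ) AND ( ( price < 100 AND price >= 10 ) )
```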
96 changes: 96 additions & 0 deletions engine/clients/tsvector/search.py
@@ -0,0 +1,96 @@
from typing import List, Tuple

import numpy as np
import psycopg
from pgvector.psycopg import register_vector

from dataset_reader.base_reader import Query
from engine.base_client.distances import Distance
from engine.base_client.search import BaseSearcher
from engine.clients.tsvector.config import get_db_config
from engine.clients.tsvector.parser import TsVectorConditionParser

CONNECTION_SETTINGS = [
    "set work_mem = '2GB';",
    "set maintenance_work_mem = '8GB';",
    "set max_parallel_workers_per_gather = 0;",
    "set enable_seqscan=0;",
    "set jit = 'off';",
]


class TsVectorSearcher(BaseSearcher):
    conn = None
    cur = None
    distance = None
    search_params = {}
    parser = TsVectorConditionParser()

    @classmethod
    def init_client(cls, host, distance, connection_params: dict, search_params: dict):
        cls.distance = distance
        cls.conn = psycopg.connect(**get_db_config(host, connection_params))
        register_vector(cls.conn)
        cls.cur = cls.conn.cursor()

        if distance == Distance.COSINE:
            cls.query = "SELECT id, embedding <=> %s AS _score FROM items ORDER BY _score LIMIT %s"
        elif distance == Distance.L2:
            cls.query = "SELECT id, embedding <-> %s AS _score FROM items ORDER BY _score LIMIT %s"
        else:
            raise NotImplementedError(f"Unsupported distance metric {cls.distance}")

        cls.cur.execute(
            "set diskann.query_search_list_size = %d"
            % search_params["query_search_list_size"]
        )
        print(
            "set diskann.query_search_list_size = %d"
            % search_params["query_search_list_size"]
        )
        cls.cur.execute(
            "set diskann.query_rescore = %d" % search_params["query_rescore"]
        )
        print("set diskann.query_rescore = %d" % search_params["query_rescore"])

        for setting in CONNECTION_SETTINGS:
            cls.cur.execute(setting)

        print("Prewarming...")
        cls.cur.execute(
            "select format($$%I.%I$$, chunk_schema, chunk_name) from timescaledb_information.chunks k where hypertable_name = 'items'"
        )
        chunks = [row[0] for row in cls.cur]
        for chunk in chunks:
            print(f"prewarming chunk heap {chunk}")
            cls.cur.execute(f"select pg_prewarm('{chunk}'::regclass, mode=>'buffer')")
            cls.cur.fetchall()

        cls.cur.execute(
            """
            select format($$%I.%I$$, x.schemaname, x.indexname)
            from timescaledb_information.chunks k
            inner join pg_catalog.pg_indexes x on (k.chunk_schema = x.schemaname and k.chunk_name = x.tablename)
            where x.indexname ilike '%_embedding_%'
            and k.hypertable_name = 'items'"""
        )
        chunks = [row[0] for row in cls.cur]
        for chunk_index in chunks:
            print(f"prewarming chunk index {chunk_index}")
            cls.cur.execute(
                f"select pg_prewarm('{chunk_index}'::regclass, mode=>'buffer')"
            )
            cls.cur.fetchall()

    @classmethod
    def search_one(cls, query: Query, top) -> List[Tuple[int, float]]:
        # TODO: Use query.meta_conditions for datasets with filtering
        cls.cur.execute(
            cls.query, (np.array(query.vector), top), binary=True, prepare=True
        )
        return cls.cur.fetchall()

    @classmethod
    def delete_client(cls):
        if cls.cur:
            cls.cur.close()
            cls.conn.close()
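As a sanity check outside the harness, the search path above reduces to the two diskann GUCs plus one parametrized statement. A hedged, hand-run equivalent; the host, the 3-dimensional vector, and the GUC values (taken from the sample config below) are assumptions for illustration:

```python
import numpy as np
import psycopg
from pgvector.psycopg import register_vector

from engine.clients.tsvector.config import get_db_config

# Assumption: a local instance with a populated items(id, embedding) table.
conn = psycopg.connect(**get_db_config("localhost", {}))
register_vector(conn)
with conn.cursor() as cur:
    cur.execute("set diskann.query_search_list_size = 75")
    cur.execute("set diskann.query_rescore = 400")
    cur.execute(
        "SELECT id, embedding <=> %s AS _score FROM items ORDER BY _score LIMIT %s",
        (np.array([0.1, 0.2, 0.3]), 10),
    )
    print(cur.fetchall())  # [(id, cosine_distance), ...]
conn.close()
```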
41 changes: 41 additions & 0 deletions engine/clients/tsvector/upload.py
@@ -0,0 +1,41 @@
from typing import List

import psycopg
from pgvector.psycopg import register_vector

from dataset_reader.base_reader import Record
from engine.base_client.distances import Distance
from engine.base_client.upload import BaseUploader
from engine.clients.tsvector.config import get_db_config


class TsVectorUploader(BaseUploader):
    DISTANCE_MAPPING = {
        Distance.L2: "vector_l2_ops",
        Distance.COSINE: "vector_cosine_ops",
    }
    conn = None
    cur = None
    upload_params = {}

    @classmethod
    def init_client(cls, host, distance, connection_params, upload_params):
        cls.conn = psycopg.connect(**get_db_config(host, connection_params))
        register_vector(cls.conn)
        cls.cur = cls.conn.cursor()
        cls.upload_params = upload_params

    @classmethod
    def upload_batch(cls, batch: List[Record]):
        print("Data upload already done")

    @classmethod
    def post_upload(cls, distance):
        print("Already done")
        return {}

    @classmethod
    def delete_client(cls):
        if cls.cur:
            cls.cur.close()
            cls.conn.close()
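Since `upload_batch` and `post_upload` are deliberate no-ops (loading and indexing happen outside the harness), `DISTANCE_MAPPING` above mainly documents the intended pgvector operator classes. A hedged sketch of the index DDL it seems to anticipate; the `items` table and `embedding` column are inferred from search.py, the `diskann` index type comes from pgvectorscale, and none of this is created by the PR itself:

```python
import psycopg

from engine.base_client.distances import Distance
from engine.clients.tsvector.config import get_db_config
from engine.clients.tsvector.upload import TsVectorUploader

opclass = TsVectorUploader.DISTANCE_MAPPING[Distance.COSINE]  # "vector_cosine_ops"

# Assumption: the items table is already loaded out of band; this statement only
# mirrors the pgvectorscale index examples and is not executed anywhere in this PR.
with psycopg.connect(**get_db_config("localhost", {})) as conn:
    conn.execute(
        f"CREATE INDEX IF NOT EXISTS items_embedding_idx "
        f"ON items USING diskann (embedding {opclass});"
    )
```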
12 changes: 12 additions & 0 deletions experiments/configurations/tsvector-single-node.json
@@ -0,0 +1,12 @@
[
  {
    "name": "tsvector-q-search-75-rescore-400",
    "engine": "tsvector",
    "connection_params": {},
    "collection_params": {},
    "search_params": [
      { "parallel": 8, "config": { "hnsw_ef": 128 }, "query_search_list_size": 75, "query_rescore": 400 }
    ],
    "upload_params": { "parallel": 16, "batch_size": 1024, "hnsw_config": { "m": 16, "ef_construct": 128 } }
  }
]
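For completeness, the experiment file is a plain JSON list; a short sketch of reading it and mapping `search_params` onto the diskann settings that search.py applies (the relative path assumes the repository root as the working directory):

```python
import json

with open("experiments/configurations/tsvector-single-node.json") as f:
    experiments = json.load(f)

for exp in experiments:
    for sp in exp["search_params"]:
        print(
            exp["name"],
            f"-> diskann.query_search_list_size={sp['query_search_list_size']},",
            f"diskann.query_rescore={sp['query_rescore']}",
        )
```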