Skip to content

Apache Cassandra® engine addition #236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ NOTES.md

results/*
tools/custom/data.json

.venv
*.DS_Store
55 changes: 55 additions & 0 deletions engine/clients/cassandra/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Apache Cassandra®

## Setup
### Prerequisites
Adjust the configuration file [`cassandra-single-node.json`](../../../experiments/configurations/cassandra-single-node.json) to your needs. The default configuration is set to use the default Apache Cassandra® settings.
### Start up the server
Run the following command to start the server (alternatively, run `docker compose up -d` to run in detached mode):
```bash
$ docker compose up

[+] Running 1/1
✔ cassandra Pulled 1.4s
[+] Running 1/1
✔ Container cassandra-benchmark Recreated 0.1s
Attaching to cassandra-benchmark
cassandra-benchmark | CompileCommand: dontinline org/apache/cassandra/db/Columns$Serializer.deserializeLargeSubset(Lorg/apache/cassandra/io/util/DataInputPlus;Lorg/apache/cassandra/db/Columns;I)Lorg/apache/cassandra/db/Columns; bool dontinline = true
...
cassandra-benchmark | INFO [main] 2025-04-04 22:27:38,592 StorageService.java:957 - Cassandra version: 5.0.3
cassandra-benchmark | INFO [main] 2025-04-04 22:27:38,592 StorageService.java:958 - Git SHA: b0226c8ea122c3e5ea8680efb0744d33924fd732
cassandra-benchmark | INFO [main] 2025-04-04 22:27:38,592 StorageService.java:959 - CQL version: 3.4.7
...
cassandra-benchmark | INFO [main] 2025-04-04 22:28:25,091 StorageService.java:3262 - Node /172.18.0.2:7000 state jump to NORMAL
```
> [!TIP]
> Other helpful commands:
> - Run `docker exec -it cassandra-benchmark cqlsh` to access the Cassandra shell.
> - Run `docker compose logs --follow --tail 10` to view & follow the logs of the container running Cassandra.

### Start up the client benchmark
Run the following command to start the client benchmark using `glove-25-angular` dataset as an example:
```bash
$ python3 -m run --engines cassandra-single-node --datasets glove-25-angular
```
and you'll see the following output:
```bash
Running experiment: cassandra-single-node - glove-25-angular
Downloading http://ann-benchmarks.com/glove-25-angular.hdf5...
...
Experiment stage: Configure
Experiment stage: Upload
1183514it [mm:ss, <...>it/s]
Upload time: <...>
Index is not ready yet, sleeping for 2 minutes...
...
Total import time: <...>
Experiment stage: Search
10000it [mm:ss, <...>it/s]
10000it [mm:ss, <...>it/s]
Experiment stage: Done
Results saved to: /path/to/repository/results
...
```

> [!TIP]
> If you want to see the detailed results, set environment variable `DETAILED_RESULTS=1` before running the benchmark.
5 changes: 5 additions & 0 deletions engine/clients/cassandra/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from engine.clients.cassandra.configure import CassandraConfigurator
from engine.clients.cassandra.search import CassandraSearcher
from engine.clients.cassandra.upload import CassandraUploader

# Public API of the Cassandra client package.
__all__ = ["CassandraConfigurator", "CassandraSearcher", "CassandraUploader"]
7 changes: 7 additions & 0 deletions engine/clients/cassandra/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os

# Keyspace and table used for the benchmark schema; override via environment
# variables of the same names.
CASSANDRA_KEYSPACE = os.getenv("CASSANDRA_KEYSPACE", "benchmark")
CASSANDRA_TABLE = os.getenv("CASSANDRA_TABLE", "vectors")
# DataStax Astra connection settings. NOTE(review): not referenced by the
# configure/search clients in this package — presumably consumed elsewhere
# (e.g. upload or a hosted-Astra variant); confirm before removing.
ASTRA_API_ENDPOINT = os.getenv("ASTRA_API_ENDPOINT", None)
ASTRA_API_KEY = os.getenv("ASTRA_API_KEY", None)
ASTRA_SCB_PATH = os.getenv("ASTRA_SCB_PATH", None)
102 changes: 102 additions & 0 deletions engine/clients/cassandra/configure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from cassandra import ConsistencyLevel, ProtocolVersion
from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile
from cassandra.policies import (
DCAwareRoundRobinPolicy,
ExponentialReconnectionPolicy,
TokenAwarePolicy,
)

from benchmark.dataset import Dataset
from engine.base_client.configure import BaseConfigurator
from engine.base_client.distances import Distance
from engine.clients.cassandra.config import CASSANDRA_KEYSPACE, CASSANDRA_TABLE


class CassandraConfigurator(BaseConfigurator):
    """Prepare an Apache Cassandra keyspace/table for vector-search benchmarks.

    Creates (or drops) the benchmark keyspace and a table with a
    ``vector<float, N>`` column, plus a storage-attached index (SAI)
    configured with the dataset's similarity function.
    """

    SPARSE_VECTOR_SUPPORT = False
    # Benchmark distance -> Cassandra SAI similarity_function name.
    DISTANCE_MAPPING = {
        Distance.L2: "euclidean",
        Distance.COSINE: "cosine",
        Distance.DOT: "dot_product",
    }

    def __init__(self, host, collection_params: dict, connection_params: dict):
        """Open a cluster connection to *host* for schema (DDL) operations."""
        super().__init__(host, collection_params, connection_params)

        # LOCAL_QUORUM so schema/DDL statements are acknowledged by a quorum
        # of the local DC before the benchmark proceeds.
        profile = ExecutionProfile(
            load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
            consistency_level=ConsistencyLevel.LOCAL_QUORUM,
            request_timeout=60,
        )

        self.cluster = Cluster(
            contact_points=[host],
            execution_profiles={EXEC_PROFILE_DEFAULT: profile},
            protocol_version=ProtocolVersion.V4,
            reconnection_policy=ExponentialReconnectionPolicy(
                base_delay=1, max_delay=60
            ),
            **connection_params,
        )
        self.session = self.cluster.connect()

    def clean(self):
        """Drop the benchmark keyspace (and everything in it) if it exists."""
        # NOTE: identifiers come from the environment and are interpolated
        # directly into CQL (bind markers are not allowed for identifiers);
        # they must be valid CQL identifiers, never untrusted input.
        self.session.execute(f"DROP KEYSPACE IF EXISTS {CASSANDRA_KEYSPACE}")

    def recreate(self, dataset: Dataset, collection_params):
        """Create the keyspace, vector table and SAI index for *dataset*.

        Returns *collection_params* unchanged, per the configurator contract.
        """
        # Single-node benchmark setup, hence SimpleStrategy / RF=1.
        self.session.execute(
            f"""CREATE KEYSPACE IF NOT EXISTS {CASSANDRA_KEYSPACE}
            WITH REPLICATION = {{ 'class': 'SimpleStrategy', 'replication_factor': 1 }}"""
        )

        # Make subsequent unqualified statements target the benchmark keyspace.
        self.session.execute(f"USE {CASSANDRA_KEYSPACE}")

        distance_metric = self.DISTANCE_MAPPING.get(dataset.config.distance)
        vector_size = dataset.config.vector_size

        # One row per vector; payload fields are kept in a text->text map
        # rather than materialized as separate columns (see parser.py, which
        # builds all filters against this ``metadata`` map).
        self.session.execute(
            f"""CREATE TABLE IF NOT EXISTS {CASSANDRA_TABLE} (
                id int PRIMARY KEY,
                embedding vector<float, {vector_size}>,
                metadata map<text, text>
            )"""
        )

        # SAI index enables `ORDER BY embedding ANN OF ?` queries with the
        # dataset's similarity function.
        self.session.execute(
            f"""CREATE CUSTOM INDEX IF NOT EXISTS vector_index ON {CASSANDRA_TABLE}(embedding)
            USING 'StorageAttachedIndex'
            WITH OPTIONS = {{ 'similarity_function': '{distance_metric}' }}"""
        )

        return collection_params

    def execution_params(self, distance, vector_size) -> dict:
        """Ask the framework to normalize vectors for cosine distance."""
        return {"normalize": distance == Distance.COSINE}

    def delete_client(self):
        """Close the Cassandra session and cluster connection."""
        if hasattr(self, "session") and self.session:
            self.session.shutdown()
        if hasattr(self, "cluster") and self.cluster:
            self.cluster.shutdown()
92 changes: 92 additions & 0 deletions engine/clients/cassandra/parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import Any, List, Optional

from engine.base_client.parser import BaseConditionParser, FieldValue


class CassandraConditionParser(BaseConditionParser):
    """Translate benchmark filter conditions into CQL ``WHERE`` fragments.

    All payload fields live in the ``metadata map<text, text>`` column, so
    every comparison is rendered against ``metadata['<field>']`` as text.

    NOTE(review): values are interpolated directly into the CQL text rather
    than bound as parameters, so a value containing a single quote would
    break (or inject into) the query. Acceptable for curated benchmark
    datasets; never use with untrusted input.
    """

    def build_condition(
        self, and_subfilters: Optional[List[Any]], or_subfilters: Optional[List[Any]]
    ) -> Optional[Any]:
        """Combine AND / OR sub-conditions into one CQL boolean expression.

        Returns ``None`` when there is nothing to filter on.
        """
        parts = []

        if and_subfilters:
            joined = " AND ".join(f"({cond})" for cond in and_subfilters if cond)
            if joined:
                parts.append(f"({joined})")

        if or_subfilters:
            joined = " OR ".join(f"({cond})" for cond in or_subfilters if cond)
            if joined:
                parts.append(f"({joined})")

        return " AND ".join(parts) if parts else None

    def build_exact_match_filter(self, field_name: str, value: FieldValue) -> Any:
        """Equality test against the ``metadata`` map (values stored as text)."""
        # f-string interpolation stringifies non-str values automatically, so
        # the previous isinstance(value, str) branching was redundant: both
        # branches produced the identical string.
        return f"metadata['{field_name}'] = '{value}'"

    def build_range_filter(
        self,
        field_name: str,
        lt: Optional[FieldValue],
        gt: Optional[FieldValue],
        lte: Optional[FieldValue],
        gte: Optional[FieldValue],
    ) -> Any:
        """Range comparisons against the ``metadata`` map, joined with AND.

        Bounds that are ``None`` are omitted.
        """
        bounds = (("<", lt), (">", gt), ("<=", lte), (">=", gte))
        return " AND ".join(
            f"metadata['{field_name}'] {op} '{value}'"
            for op, value in bounds
            if value is not None
        )

    def build_geo_filter(
        self, field_name: str, lat: float, lon: float, radius: float
    ) -> Any:
        """Placeholder geo filter.

        Plain Cassandra has no built-in geo predicates; without an extension
        this returns an always-true condition so queries still execute
        (i.e. the geo constraint is NOT applied).
        """
        return "1=1"  # Always true condition as a placeholder
123 changes: 123 additions & 0 deletions engine/clients/cassandra/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import multiprocessing as mp
from typing import List, Tuple

from cassandra import ConsistencyLevel, ProtocolVersion
from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile
from cassandra.policies import (
DCAwareRoundRobinPolicy,
ExponentialReconnectionPolicy,
TokenAwarePolicy,
)

from dataset_reader.base_reader import Query
from engine.base_client.distances import Distance
from engine.base_client.search import BaseSearcher
from engine.clients.cassandra.config import CASSANDRA_KEYSPACE, CASSANDRA_TABLE
from engine.clients.cassandra.parser import CassandraConditionParser


class CassandraSearcher(BaseSearcher):
    """Vector-similarity searcher for Apache Cassandra (SAI ANN queries).

    State is class-level because the benchmark framework fans search work
    out to worker processes: each process calls ``init_client`` once and
    then reuses the session and prepared statements.
    """

    search_params = {}
    session = None
    cluster = None
    parser = CassandraConditionParser()
    # Prepared statements for filtered queries, keyed by rendered CQL text,
    # so each distinct filter shape is prepared only once per process.
    _filtered_stmt_cache = {}

    @classmethod
    def init_client(cls, host, distance, connection_params: dict, search_params: dict):
        """Connect to the cluster and prepare the ANN search statements."""
        # LOCAL_ONE favors read latency over consistency — appropriate for a
        # read-only benchmark against a single-node cluster.
        profile = ExecutionProfile(
            load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
            consistency_level=ConsistencyLevel.LOCAL_ONE,
            request_timeout=60,
        )

        cls.cluster = Cluster(
            contact_points=[host],
            execution_profiles={EXEC_PROFILE_DEFAULT: profile},
            reconnection_policy=ExponentialReconnectionPolicy(
                base_delay=1, max_delay=60
            ),
            protocol_version=ProtocolVersion.V4,
            **connection_params,
        )
        cls.session = cls.cluster.connect(CASSANDRA_KEYSPACE)
        cls.search_params = search_params

        cls.update_prepared_statements(distance)

    @classmethod
    def get_mp_start_method(cls):
        """Prefer fork so workers inherit the prepared class-level state."""
        return "fork" if "fork" in mp.get_all_start_methods() else "spawn"

    @classmethod
    def update_prepared_statements(cls, distance):
        """(Re)build the prepared ANN statement and the filtered-query template."""
        limit = cls.search_params.get("top", 10)

        if distance == Distance.COSINE:
            similarity_func = "similarity_cosine"
        elif distance == Distance.L2:
            similarity_func = "similarity_euclidean"
        elif distance == Distance.DOT:
            similarity_func = "similarity_dot_product"
        else:
            raise ValueError(f"Unsupported distance metric: {distance}")

        # Two bind markers: one inside the similarity function (to report the
        # distance) and one in the ANN ORDER BY clause.
        cls.ann_search_stmt = cls.session.prepare(
            f"""SELECT id, {similarity_func}(embedding, ?) as distance
            FROM {CASSANDRA_TABLE}
            ORDER BY embedding ANN OF ?
            LIMIT {limit}"""
        )

        # Filter conditions are only known per-query, so filtered searches
        # are prepared lazily from this template (see search_one).
        cls.filtered_search_query_template = f"""SELECT id, {similarity_func}(embedding, ?) as distance
            FROM {CASSANDRA_TABLE}
            WHERE {{conditions}}
            ORDER BY embedding ANN OF ?
            LIMIT {limit}"""

    @classmethod
    def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
        """Run one ANN query and return ``[(id, distance), ...]``.

        NOTE(review): ``top`` is not used here — the LIMIT was baked into the
        prepared statement from ``search_params['top']`` in
        ``update_prepared_statements``; confirm callers always pass the same
        value, otherwise the limit should become a bind parameter.
        """
        # numpy arrays expose tolist(); plain lists pass through unchanged.
        query_vector = (
            query.vector.tolist() if hasattr(query.vector, "tolist") else query.vector
        )

        filter_conditions = cls.parser.parse(query.meta_conditions)

        try:
            if filter_conditions:
                cql = cls.filtered_search_query_template.format(
                    conditions=filter_conditions
                )
                # Prepare each distinct filtered query only once per process
                # instead of re-preparing on every call (a server round-trip).
                stmt = cls._filtered_stmt_cache.get(cql)
                if stmt is None:
                    stmt = cls.session.prepare(cql)
                    cls._filtered_stmt_cache[cql] = stmt
                results = cls.session.execute(stmt, (query_vector, query_vector))
            else:
                results = cls.session.execute(
                    cls.ann_search_stmt, (query_vector, query_vector)
                )

            return [(row.id, row.distance) for row in results]

        except Exception as ex:
            print(f"Error during Cassandra vector search: {ex}")
            # Bare raise preserves the original traceback (raise ex would not).
            raise

    @classmethod
    def delete_client(cls):
        """Close the Cassandra session and cluster connection."""
        if cls.session:
            cls.session.shutdown()
        if cls.cluster:
            cls.cluster.shutdown()
Loading