Skip to content

Commit

Permalink
Merge pull request #1 from majamil16/feat/add-schema-on-collection
Browse files Browse the repository at this point in the history
update - move support for schema onto collection instead of client
  • Loading branch information
majamil16 authored Jan 15, 2024
2 parents bc28230 + 0745093 commit fed1fce
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 33 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ repos:
hooks:
- id: autoflake
args: ['--in-place', '--remove-all-unused-imports']
language_version: python3.9

- repo: https://github.com/ambv/black
rev: 22.10.0
Expand Down
4 changes: 3 additions & 1 deletion src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
PYTEST_DB = "postgresql://postgres:password@localhost:5611/vecs_db"
PYTEST_SCHEMA = "test_schema"


@pytest.fixture(scope="session")
def maybe_start_pg() -> Generator[None, None, None]:
"""Creates a postgres 15 docker container that can be connected
Expand Down Expand Up @@ -94,12 +95,13 @@ def maybe_start_pg() -> Generator[None, None, None]:
def clean_db(maybe_start_pg: None) -> Generator[str, None, None]:
eng = create_engine(PYTEST_DB)
with eng.begin() as connection:
connection.execute(text("drop schema if exists vecs cascade;"))
connection.execute(text(f"drop schema if exists {PYTEST_SCHEMA} cascade;"))
yield PYTEST_DB
eng.dispose()


@pytest.fixture(scope="function")
def client(clean_db: str) -> Generator[vecs.Client, None, None]:
client_ = vecs.create_client(clean_db, PYTEST_SCHEMA)
client_ = vecs.create_client(clean_db)
yield client_
13 changes: 7 additions & 6 deletions src/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import pytest
import vecs

def test_create_client(clean_db) -> None:
client = vecs.create_client(clean_db)
assert client.schema == "vecs"
import vecs

client = vecs.create_client(clean_db, "my_schema")
assert client.schema == "my_schema"

def test_extracts_vector_version(client: vecs.Client) -> None:
# pgvector version is sucessfully extracted
Expand Down Expand Up @@ -34,11 +29,17 @@ def test_get_collection(client: vecs.Client) -> None:


def test_list_collections(client: vecs.Client) -> None:
"""
Test list_collections returns appropriate results for default schema (vecs) and custom schema
"""
assert len(client.list_collections()) == 0
client.get_or_create_collection(name="docs", dimension=384)
client.get_or_create_collection(name="books", dimension=1586)
client.get_or_create_collection(name="movies", schema="test_schema", dimension=384)
collections = client.list_collections()
collections_test_schema = client.list_collections(schema="test_schema")
assert len(collections) == 2
assert len(collections_test_schema) == 1


def test_delete_collection(client: vecs.Client) -> None:
Expand Down
66 changes: 66 additions & 0 deletions src/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,3 +815,69 @@ def test_hnsw_unavailable_error(client: vecs.Client) -> None:
bar = client.get_or_create_collection(name="bar", dimension=dim)
with pytest.raises(ArgError):
bar.create_index(method=IndexMethod.hnsw)


def test_get_or_create_with_schema(client: vecs.Client):
"""
Test that get_or_create_collection works when specifying custom schema
"""

dim = 384

collection_1 = client.get_or_create_collection(
name="collection_1", schema="test_schema", dimension=dim
)
collection_2 = client.get_or_create_collection(
name="collection_1", schema="test_schema", dimension=dim
)

assert collection_1.schema == "test_schema"
assert collection_1.schema == collection_2.schema
assert collection_1.name == collection_2.name


def test_upsert_with_schema(client: vecs.Client) -> None:
n_records = 100
dim = 384

movies1 = client.get_or_create_collection(
name="ping", schema="test_schema", dimension=dim
)
movies2 = client.get_or_create_collection(name="ping", schema="vecs", dimension=dim)

# collection initially empty
assert len(movies1) == 0
assert len(movies2) == 0

records = [
(
f"vec{ix}",
vec,
{
"genre": random.choice(["action", "rom-com", "drama"]),
"year": int(50 * random.random()) + 1970,
},
)
for ix, vec in enumerate(np.random.random((n_records, dim)))
]

# insert works
movies1.upsert(records)
assert len(movies1) == n_records

movies2.upsert(records)
assert len(movies2) == n_records

# upserting overwrites
new_record = ("vec0", np.zeros(384), {})
movies1.upsert([new_record])
db_record = movies1["vec0"]
db_record[0] == new_record[0]
db_record[1] == new_record[1]
db_record[2] == new_record[2]

movies2.upsert([new_record])
db_record = movies2["vec0"]
db_record[0] == new_record[0]
db_record[1] == new_record[1]
db_record[2] == new_record[2]
9 changes: 6 additions & 3 deletions src/vecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
]


def create_client(connection_string: str, schema: str="vecs") -> Client:
"""Creates a client from a Postgres connection string"""
return Client(connection_string=connection_string, schema=schema)
def create_client(connection_string: str) -> Client:
"""
Creates a client from a Postgres connection string and optional schema.
Defaults to `vecs` schema.
"""
return Client(connection_string=connection_string)
25 changes: 12 additions & 13 deletions src/vecs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import TYPE_CHECKING, List, Optional

from deprecated import deprecated
from sqlalchemy import MetaData, create_engine, text
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

from vecs.adapter import Adapter
Expand Down Expand Up @@ -47,24 +47,21 @@ class Client:
vx.disconnect()
"""

def __init__(self, connection_string: str, schema: str):
def __init__(self, connection_string: str):
"""
Initialize a Client instance.
Args:
connection_string (str): A string representing the database connection information.
schema (str): A string representing the database schema to connect to.
Returns:
None
"""
self.schema = schema
self.engine = create_engine(connection_string)
self.meta = MetaData(schema=self.schema)
self.Session = sessionmaker(self.engine)

with self.Session() as sess:
with sess.begin():
sess.execute(text(f"create schema if not exists {self.schema};"))
sess.execute(text("create schema if not exists vecs;"))
sess.execute(text("create extension if not exists vector;"))
self.vector_version: str = sess.execute(
text(
Expand All @@ -84,6 +81,7 @@ def _supports_hnsw(self):
def get_or_create_collection(
self,
name: str,
schema: str = "vecs",
*,
dimension: Optional[int] = None,
adapter: Optional[Adapter] = None,
Expand All @@ -106,14 +104,15 @@ def get_or_create_collection(
CollectionAlreadyExists: If a collection with the same name already exists
"""
from vecs.collection import Collection

adapter_dimension = adapter.exported_dimension if adapter else None

collection = Collection(
name=name,
dimension=dimension or adapter_dimension, # type: ignore
client=self,
adapter=adapter,
schema=schema,
)

return collection._create_if_not_exists()
Expand Down Expand Up @@ -163,7 +162,7 @@ def get_collection(self, name: str) -> Collection:
join pg_attribute pa
on pc.oid = pa.attrelid
where
pc.relnamespace = '{self.schema}'::regnamespace
pc.relnamespace = 'vecs'::regnamespace
and pc.relkind = 'r'
and pa.attname = 'vec'
and not pc.relname ^@ '_'
Expand All @@ -183,18 +182,18 @@ def get_collection(self, name: str) -> Collection:
self,
)

def list_collections(self) -> List["Collection"]:
def list_collections(self, schema: str = "vecs") -> List["Collection"]:
"""
List all vector collections.
List all vector collections by database schema.
Returns:
list[Collection]: A list of all collections.
"""
from vecs.collection import Collection

return Collection._list_collections(self)
return Collection._list_collections(self, schema)

def delete_collection(self, name: str) -> None:
def delete_collection(self, name: str, schema: str = "vecs") -> None:
"""
Delete a vector collection.
Expand All @@ -208,7 +207,7 @@ def delete_collection(self, name: str) -> None:
"""
from vecs.collection import Collection

Collection(name, -1, self)._drop()
Collection(name, -1, self, schema=schema)._drop()
return

def disconnect(self) -> None:
Expand Down
31 changes: 21 additions & 10 deletions src/vecs/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(
dimension: int,
client: Client,
adapter: Optional[Adapter] = None,
schema: Optional[str] = "vecs",
):
"""
Initializes a new instance of the `Collection` class.
Expand All @@ -174,7 +175,9 @@ def __init__(
self.client = client
self.name = name
self.dimension = dimension
self.table = build_table(name, client.meta, dimension)
self.schema = schema
self.meta = MetaData(schema=self.schema)
self.table = build_table(name, self.meta, dimension)
self._index: Optional[str] = None
self.adapter = adapter or Adapter(steps=[NoOp(dimension=dimension)])

Expand All @@ -195,6 +198,10 @@ def __init__(
"Dimensions reported by adapter, dimension, and collection do not match"
)

with self.client.Session() as sess:
with sess.begin():
sess.execute(text(f"create schema if not exists {self.schema};"))

def __repr__(self):
"""
Returns a string representation of the `Collection` instance.
Expand Down Expand Up @@ -235,7 +242,7 @@ def _create_if_not_exists(self):
join pg_attribute pa
on pc.oid = pa.attrelid
where
pc.relnamespace = '{self.client.schema}'::regnamespace
pc.relnamespace = '{self.schema}'::regnamespace
and pc.relkind = 'r'
and pa.attname = 'vec'
and not pc.relname ^@ '_'
Expand Down Expand Up @@ -285,11 +292,12 @@ def _create(self):

unique_string = str(uuid.uuid4()).replace("-", "_")[0:7]
with self.client.Session() as sess:
sess.execute(text(f"create schema if not exists {self.schema};"))
sess.execute(
text(
f"""
create index ix_meta_{unique_string}
on {self.client.schema}."{self.table.name}"
on {self.schema}."{self.table.name}"
using gin ( metadata jsonb_path_ops )
"""
)
Expand Down Expand Up @@ -562,17 +570,18 @@ def query(
return sess.execute(stmt).fetchall() or []

@classmethod
def _list_collections(cls, client: "Client") -> List["Collection"]:
def _list_collections(cls, client: "Client", schema: str) -> List["Collection"]:
"""
PRIVATE
Retrieves all collections from the database.
Args:
client (Client): The database client.
schema (str): The database schema to query.
Returns:
List[Collection]: A list of all existing collections.
List[Collection]: A list of all existing collections within the specified schema.
"""

query = text(
Expand All @@ -585,7 +594,7 @@ def _list_collections(cls, client: "Client") -> List["Collection"]:
join pg_attribute pa
on pc.oid = pa.attrelid
where
pc.relnamespace = '{client.schema}'::regnamespace
pc.relnamespace = '{schema}'::regnamespace
and pc.relkind = 'r'
and pa.attname = 'vec'
and not pc.relname ^@ '_'
Expand Down Expand Up @@ -642,7 +651,7 @@ def index(self) -> Optional[str]:
from
pg_class pc
where
pc.relnamespace = '{self.client.schema}'::regnamespace
pc.relnamespace = '{self.schema}'::regnamespace
and relname ilike 'ix_vector%'
and pc.relkind = 'i'
"""
Expand Down Expand Up @@ -760,7 +769,9 @@ def create_index(
with sess.begin():
if self.index is not None:
if replace:
sess.execute(text(f'drop index "{self.client.schema}"."{self.index}";'))
sess.execute(
text(f'drop index "{self.schema}"."{self.index}";')
)
self._index = None
else:
raise ArgError("replace is set to False but an index exists")
Expand All @@ -787,7 +798,7 @@ def create_index(
text(
f"""
create index ix_{ops}_ivfflat_nl{n_lists}_{unique_string}
on {self.client.schema}."{self.table.name}"
on {self.schema}."{self.table.name}"
using ivfflat (vec {ops}) with (lists={n_lists})
"""
)
Expand All @@ -806,7 +817,7 @@ def create_index(
text(
f"""
create index ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string}
on {self.client.schema}."{self.table.name}"
on {self.schema}."{self.table.name}"
using hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction});
"""
)
Expand Down

0 comments on commit fed1fce

Please sign in to comment.