
Commit 44c35bb

feat: Make timescale_vector pgai compatible (mostly)

1 parent f4a0aa3

13 files changed: +2833 −55 lines


tests/async_client_test.py

Lines changed: 13 additions & 4 deletions
@@ -19,7 +19,9 @@
 @pytest.mark.asyncio
 @pytest.mark.parametrize("schema", ["temp", None])
 async def test_vector(service_url: str, schema: str) -> None:
-    vec = Async(service_url, "data_table", 2, schema_name=schema)
+    vec = Async(
+        service_url, "data_table", 2, schema_name=schema, embedding_table_name="data_table", id_column_name="id"
+    )
     await vec.drop_table()
     await vec.create_tables()
     empty = await vec.table_is_empty()
@@ -118,7 +120,7 @@ async def test_vector(service_url: str, schema: str) -> None:
 
     assert isinstance(rec[0][SEARCH_RESULT_METADATA_IDX], dict)
     assert isinstance(rec[0]["metadata"], dict)
-    assert rec[0]["contents"] == "the brown fox"
+    assert rec[0]["chunk"] == "the brown fox"
 
     rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(("key", "val2")))
     assert len(rec) == 1
@@ -256,7 +258,7 @@ async def test_vector(service_url: str, schema: str) -> None:
     await vec.drop_table()
     await vec.close()
 
-    vec = Async(service_url, "data_table", 2, id_type="TEXT")
+    vec = Async(service_url, "data_table", 2, id_type="TEXT", embedding_table_name="data_table", id_column_name="id")
     await vec.create_tables()
     empty = await vec.table_is_empty()
     assert empty
@@ -269,7 +271,14 @@ async def test_vector(service_url: str, schema: str) -> None:
     await vec.drop_table()
     await vec.close()
 
-    vec = Async(service_url, "data_table", 2, time_partition_interval=timedelta(seconds=60))
+    vec = Async(
+        service_url,
+        "data_table",
+        2,
+        time_partition_interval=timedelta(seconds=60),
+        embedding_table_name="data_table",
+        id_column_name="id",
+    )
     await vec.create_tables()
     empty = await vec.table_is_empty()
     assert empty
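
For reference, a minimal standalone sketch of the updated constructor call, based only on the arguments added in this test; the connection string is a placeholder:

from datetime import timedelta

from timescale_vector.client import Async

# Sketch: the new embedding_table_name and id_column_name keyword arguments
# point the client at an explicit embeddings table and id column, as the
# updated tests above do. The service URL below is a placeholder.
vec = Async(
    "postgres://user:password@localhost:5432/postgres",  # placeholder service URL
    "data_table",  # table name
    2,  # embedding dimensions
    time_partition_interval=timedelta(seconds=60),
    embedding_table_name="data_table",
    id_column_name="id",
)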

tests/compatability_test.py

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
+import uuid
+from collections.abc import Generator
+
+import numpy
+import psycopg2
+import pytest
+from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
+
+from tests.mocks import embeddings
+from tests.utils import test_file_path
+from timescale_vector import client
+
+# To generate a new dump in blog.sql, go through the quickstart in
+# https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md
+# and run the following command:
+# docker compose exec db pg_dump \
+#     -t public.blog \
+#     -t public.blog_contents_embeddings_store \
+#     -t public.blog_contents_embeddings \
+#     --inserts \
+#     --section=data \
+#     --section=pre-data \
+#     --no-table-access-method \
+#     postgres > blog.sql
+
+
+@pytest.fixture(scope="module")
+def quickstart(service_url: str) -> Generator[None, None, None]:
+    conn = psycopg2.connect(service_url)
+    conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
+
+    with conn.cursor() as cursor:
+        cursor.execute("CREATE EXTENSION IF NOT EXISTS ai CASCADE;")
+        cursor.execute("DROP VIEW IF EXISTS blog_contents_embeddings;")
+        cursor.execute("DROP TABLE IF EXISTS blog_contents_embeddings_store;")
+        cursor.execute("DROP TABLE IF EXISTS blog;")
+
+        with open(test_file_path + "/sample_tables/blog.sql") as f:
+            sql = f.read()
+            cursor.execute(sql)
+
+    yield  # Run the tests
+
+    conn = psycopg2.connect(service_url)
+    conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
+
+    with conn.cursor() as cursor:
+        cursor.execute("DROP VIEW IF EXISTS blog_contents_embeddings;")
+        cursor.execute("DROP TABLE IF EXISTS blog_contents_embeddings_store;")
+        cursor.execute("DROP TABLE IF EXISTS blog;")
+
+    conn.close()
+
+
+def format_array_for_pg(array: list[float]) -> str:
+    formatted_values = [f"{x:g}" for x in array]
+
+    return f"ARRAY[{','.join(formatted_values)}]::vector"
+
+
+def test_semantic_search(quickstart: None, service_url: str):  # noqa: ARG001
+    conn = psycopg2.connect(service_url)
+    conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
+
+    with conn.cursor() as cursor:
+        cursor.execute(f"""
+            SELECT
+                title,
+                chunk,
+                embedding <=> {format_array_for_pg(embeddings["artificial intelligence"])} as distance
+            FROM blog_contents_embeddings
+            ORDER BY distance
+            LIMIT 3;
+        """)
+
+        results = cursor.fetchall()
+
+        assert len(results) == 3
+        assert "Artificial Intelligence" in results[0][0]  # First result should be the AI article
+
+        cursor.execute(f"""
+            SELECT
+                title,
+                chunk,
+                embedding <=> {format_array_for_pg(embeddings["database technology"])} as distance
+            FROM blog_contents_embeddings
+            ORDER BY distance
+            LIMIT 3;
+        """)
+
+        results = cursor.fetchall()
+
+        # Verify that the PostgreSQL article comes first
+        assert len(results) == 3
+        assert "PostgreSQL" in results[0][0]
+
+    conn.close()
+
+
+def test_metadata_filtered_search(quickstart: None, service_url: str):  # noqa: ARG001
+    conn = psycopg2.connect(service_url)
+    conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
+
+    with conn.cursor() as cursor:
+        cursor.execute(f"""
+            SELECT
+                title,
+                chunk,
+                metadata->>'read_time' as read_time,
+                embedding <=> {format_array_for_pg(embeddings["technology"])} as distance
+            FROM blog_contents_embeddings
+            WHERE metadata->'tags' ? 'technology'
+            ORDER BY distance
+            LIMIT 2;
+        """)
+
+        results = cursor.fetchall()
+
+        assert len(results) > 0
+        titles = [row[0] for row in results]
+        assert any("Artificial Intelligence" in title for title in titles)
+        assert any("Cloud Computing" in title for title in titles)
+
+    conn.close()
+
+
+@pytest.fixture(scope="function")
+def sync_client(service_url: str) -> client.Sync:
+    return client.Sync(service_url, "blog_contents_embeddings", 768)
+
+
+def test_basic_similarity_search(sync_client: client.Sync, quickstart: None):  # noqa: ARG001
+    results = sync_client.search(embeddings["artificial intelligence"], limit=3)
+
+    assert len(results) == 3
+    # Verify the most relevant result is AI-related
+    assert "AI" in results[0]["metadata"]["tags"]
+    # Verify basic result structure
+    assert all(isinstance(r["embedding_uuid"], uuid.UUID) for r in results)
+    assert all(isinstance(r["chunk"], str) for r in results)
+    assert all(isinstance(r["metadata"], dict) for r in results)
+    assert all(isinstance(r["embedding"], numpy.ndarray) for r in results)
+    assert all(isinstance(r["distance"], float) for r in results)
+
+
+def test_metadata_filter_search(sync_client: client.Sync, quickstart: None):  # noqa: ARG001
+    results = sync_client.search(
+        embeddings["technology"],
+        limit=2,
+        filter={"read_time": 12},  # matches read_time exactly
+    )
+
+    assert len(results) > 0
+    assert all(result["metadata"]["read_time"] == 12 for result in results)
+
+    results = sync_client.search(
+        embeddings["technology"],
+        limit=3,
+        filter=[{"read_time": 5}, {"read_time": 8}],  # matches either read_time
+    )
+
+    assert len(results) == 2
+    assert all(result["metadata"]["read_time"] in [5, 8] for result in results)
+
+    results = sync_client.search(embeddings["technology"], limit=2, filter={"published_date": "2024-04-01"})
+
+    assert len(results) > 0
+    assert all(result["metadata"]["published_date"] == "2024-04-01" for result in results)
+
+
+def test_predicate_search(sync_client: client.Sync, quickstart: None):  # noqa: ARG001
+    results = sync_client.search(embeddings["technology"], limit=2, predicates=client.Predicates("read_time", ">", 5))
+
+    assert len(results) > 0
+    assert all(float(result["metadata"]["read_time"]) > 5 for result in results)
+
+    combined_results = sync_client.search(
+        embeddings["technology"],
+        limit=2,
+        predicates=(client.Predicates("read_time", ">", 5) & client.Predicates("read_time", "<", 15)),
+    )
+
+    assert len(combined_results) > 0
+    assert all(5 < float(r["metadata"]["read_time"]) < 15 for r in combined_results)
+
+
+@pytest.mark.skip(
+    "hard to make work because pgai has a foreign key to the original data, which we don't pass in upsert at the moment"
+)
+def test_upsert_and_retrieve(sync_client: client.Sync, quickstart: None):  # noqa: ARG001
+    test_id = uuid.uuid1()
+    test_content = "This is a test article about Python programming."
+    test_embedding = [0.1] * 768
+
+    # TODO: this breaks right now, but users shouldn't have to manage embeddings manually anyway
+    sync_client.upsert([(test_id, test_content, test_embedding)])
+    results = sync_client.search(test_embedding, limit=1, filter={"tags": "test"})
+
+    assert len(results) == 1
+    assert results[0]["id"] == test_id
+    assert results[0]["chunk"] == test_content
+
+    sync_client.delete_by_ids([test_id])
+
+
+def test_delete_operations(sync_client: client.Sync, quickstart: None):  # noqa: ARG001
+    initial_results = sync_client.search(embeddings["database technology"], limit=1, filter={"read_time": 5})
+    assert len(initial_results) > 0
+    record_to_delete = initial_results[0]
+
+    sync_client.delete_by_ids([record_to_delete["embedding_uuid"]])
+    results_after_delete = sync_client.search(embeddings["database technology"], limit=1, filter={"read_time": 5})
+    assert len(results_after_delete) == 0
+
+    initial_health_results = sync_client.search(
+        embeddings["artificial intelligence"], limit=1, filter={"read_time": 12}
+    )
+    assert len(initial_health_results) > 0
+
+    sync_client.delete_by_metadata({"read_time": 12})
+    results_after_metadata_delete = sync_client.search(
+        embeddings["artificial intelligence"], limit=1, filter={"read_time": 12}
+    )
+    assert len(results_after_metadata_delete) == 0
+
+
+@pytest.mark.skip("Makes no sense for the managed vector store?")
+def test_index_operations(sync_client: client.Sync, quickstart: None):  # noqa: ARG001
+    sync_client.create_embedding_index(client.DiskAnnIndex())
+
+    results = sync_client.search(
+        embeddings["database technology"], limit=3, query_params=client.DiskAnnIndexParams(rescore=50)
+    )
+
+    assert len(results) == 3
+    tags = [result["metadata"]["tags"] for result in results]
+    assert any("database" in t for t in tags)
+
+    results_with_params = sync_client.search(
+        embeddings["database technology"],
+        limit=3,
+        query_params=client.DiskAnnIndexParams(rescore=100, search_list_size=20),
+    )
+    assert len(results_with_params) == 3
+
+    sync_client.drop_embedding_index()
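
Taken together, the new tests exercise the client against the blog_contents_embeddings view that a pgai vectorizer maintains. A condensed usage sketch, assuming a database seeded as in the quickstart fixture above; the connection string and query vector are placeholders:

from timescale_vector import client

# Assumes a database prepared as in the quickstart fixture; the connection
# string and the query vector are placeholders.
sync_client = client.Sync(
    "postgres://user:password@localhost:5432/postgres",
    "blog_contents_embeddings",  # the view maintained by the pgai vectorizer
    768,  # embedding dimensions
)

query_vector = [0.0] * 768  # normally the embedding of the query text
results = sync_client.search(query_vector, limit=3, filter={"read_time": 12})

for row in results:
    # pgai-compatible rows expose the text under "chunk" and the id under "embedding_uuid"
    print(row["embedding_uuid"], row["chunk"], row["distance"])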
