# tests/sqlalchemy/test_sqlalchemy.py
import pytest
import numpy as np
import sqlalchemy
from sqlalchemy import URL, create_engine, Column, Integer, select
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.exc import OperationalError
from tidb_vector.sqlalchemy import VectorType, VectorAdaptor, VectorIndex
import tidb_vector
from ..config import TestConfig

# NOTE(review): ``engine``, ``Session`` and ``Item2Model`` are used below but
# are defined in a part of this test module that is not shown here, together
# with the module's earlier test classes.


class TestSQLAlchemyVectorIndex:
    """Tests for ``VectorIndex``: generated DDL and queries on an indexed table."""

    def setup_class(self):
        # Start from a freshly created table, dropping any leftover first.
        Item2Model.__table__.drop(bind=engine, checkfirst=True)
        Item2Model.__table__.create(bind=engine)

    def teardown_class(self):
        Item2Model.__table__.drop(bind=engine, checkfirst=True)

    @staticmethod
    def check_indexes(table, expect_indexes_name):
        """Assert that every expected index name is attached to ``table``.

        This inspects the SQLAlchemy ``Table`` metadata only; it does not
        query the database. Other tests in this class attach indexes with the
        same names to the same ``Table`` object, so only membership is
        checked, not the total number of indexes.
        """
        attached = {index.name for index in table.indexes}
        missing = [name for name in expect_indexes_name if name not in attached]
        assert not missing, f"indexes not found on {table.name}: {missing}"

    def test_create_vector_index_statement(self):
        """Check the CREATE INDEX text for vector indexes and a plain index."""
        table = Item2Model.__table__

        l2_index = VectorIndex(
            "idx_embedding_l2",
            sqlalchemy.func.vec_l2_distance(table.c.embedding),
        )
        compiled = sqlalchemy.schema.CreateIndex(l2_index).compile(
            dialect=engine.dialect
        )
        assert compiled.string == (
            "CREATE VECTOR INDEX idx_embedding_l2 ON sqlalchemy_item2 "
            "((vec_l2_distance(embedding))) ADD_TIFLASH_ON_DEMAND"
        )

        cos_index = VectorIndex(
            "idx_embedding_cos",
            sqlalchemy.func.vec_cosine_distance(table.c.embedding),
        )
        compiled = sqlalchemy.schema.CreateIndex(cos_index).compile(
            dialect=engine.dialect
        )
        assert compiled.string == (
            "CREATE VECTOR INDEX idx_embedding_cos ON sqlalchemy_item2 "
            "((vec_cosine_distance(embedding))) ADD_TIFLASH_ON_DEMAND"
        )

        # A non-vector index must keep its ordinary DDL, with no extra suffix.
        normal_index = sqlalchemy.schema.Index("idx_unique", table.c.id, unique=True)
        compiled = sqlalchemy.schema.CreateIndex(normal_index).compile(
            dialect=engine.dialect
        )
        assert compiled.string == (
            "CREATE UNIQUE INDEX idx_unique ON sqlalchemy_item2 (id)"
        )

    def test_query_with_index(self):
        """Create both vector indexes, then run filter and order-by queries."""
        table = Item2Model.__table__

        l2_index = VectorIndex(
            "idx_embedding_l2",
            sqlalchemy.func.vec_l2_distance(table.c.embedding),
        )
        l2_index.create(engine)
        cos_index = VectorIndex(
            "idx_embedding_cos",
            sqlalchemy.func.vec_cosine_distance(table.c.embedding),
        )
        cos_index.create(engine)

        self.check_indexes(table, ["idx_embedding_l2", "idx_embedding_cos"])

        with Session() as session:
            session.add_all(
                [Item2Model(embedding=[1, 2, 3]), Item2Model(embedding=[1, 2, 3.2])]
            )
            session.commit()

            # l2 distance: both rows are within 0.2 of the probe vector.
            result_l2 = session.scalars(
                select(Item2Model).filter(
                    Item2Model.embedding.l2_distance([1, 2, 3.1]) < 0.2
                )
            ).all()
            assert len(result_l2) == 2

            distance_l2 = Item2Model.embedding.l2_distance([1, 2, 3])
            items_l2 = (
                session.query(Item2Model.id, distance_l2.label("distance"))
                .order_by(distance_l2)
                .limit(5)
                .all()
            )
            assert len(items_l2) == 2
            # The closest row is the exact match, hence distance zero.
            assert items_l2[0].distance == 0.0

            # cosine distance
            result_cos = session.scalars(
                select(Item2Model).filter(
                    Item2Model.embedding.cosine_distance([1, 2, 3.1]) < 0.2
                )
            ).all()
            assert len(result_cos) == 2

            distance_cos = Item2Model.embedding.cosine_distance([1, 2, 3])
            items_cos = (
                session.query(Item2Model.id, distance_cos.label("distance"))
                .order_by(distance_cos)
                .limit(5)
                .all()
            )
            assert len(items_cos) == 2
            assert items_cos[0].distance == 0.0

        # drop indexes
        l2_index.drop(engine)
        cos_index.drop(engine)


# tidb_vector/sqlalchemy/__init__.py now also exports VectorIndex:
#   from .vector_type import VectorType
#   from .adaptor import VectorAdaptor
#   from .index import VectorIndex
#   __all__ = ["VectorType", "VectorAdaptor", "VectorIndex"]
# tidb_vector/sqlalchemy/index.py
"""SQLAlchemy ``Index`` subclass and DDL hook for TiDB vector indexes."""
from typing import Any, Optional

import sqlalchemy
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Index


class VectorIndex(Index):
    """Index whose MySQL-dialect DDL is ``CREATE VECTOR INDEX ...``.

    The index is always non-unique. Two ``mysql`` dialect options are set on
    every instance: ``prefix`` (``"VECTOR"``) and ``add_tiflash_on_demand``
    (``True``), which ``compile_create_vector_index`` reads to append the
    ``ADD_TIFLASH_ON_DEMAND`` suffix.
    """

    def __init__(
        self,
        name: Optional[str],
        *expressions,  # _DDLColumnArgument
        _table: Optional[Any] = None,
        **dialect_kw: Any,
    ):
        """Build the index.

        Args:
            name: Index name, forwarded to ``Index``.
            *expressions: Column expressions to index, forwarded to ``Index``.
            _table: Forwarded unchanged to ``Index``.
            **dialect_kw: Dialect-specific keyword arguments, forwarded to
                ``Index``.
        """
        super().__init__(name, *expressions, unique=False, _table=_table, **dialect_kw)
        mysql_options = self.dialect_options["mysql"]
        mysql_options["prefix"] = "VECTOR"
        # add tiflash automatically when creating vector index
        mysql_options["add_tiflash_on_demand"] = True


# Registered for the "mysql" dialect only: the suffix is driven by a ``mysql``
# dialect option, so it must not leak into DDL compiled for other dialects.
@compiles(sqlalchemy.schema.CreateIndex, "mysql")
def compile_create_vector_index(
    create_index_elem: sqlalchemy.sql.ddl.CreateIndex,
    compiler: sqlalchemy.sql.compiler.DDLCompiler,
    **kw,
) -> str:
    """Compile ``CREATE INDEX`` and add the TiFlash suffix for vector indexes.

    The statement text comes from the compiler's own ``visit_create_index``.
    ``" ADD_TIFLASH_ON_DEMAND"`` is appended only when the index carries a
    truthy ``add_tiflash_on_demand`` option under ``mysql``; any other index
    compiles exactly as before.
    """
    text = compiler.visit_create_index(create_index_elem, **kw)
    mysql_options = create_index_elem.element.dialect_options.get("mysql", {})
    if mysql_options.get("add_tiflash_on_demand"):
        text += " ADD_TIFLASH_ON_DEMAND"
    return text