From 67bf47058c6248dcbb3fee76662c4d64a88fdfac Mon Sep 17 00:00:00 2001 From: JaySon-Huang Date: Thu, 7 Nov 2024 18:50:24 +0800 Subject: [PATCH 1/2] Support ADD_TIFLASH_ON_DEMAND option Signed-off-by: JaySon-Huang --- tests/sqlalchemy/test_sqlalchemy.py | 92 ++++++++++++++++++++++++++++- tidb_vector/sqlalchemy/__init__.py | 3 +- tidb_vector/sqlalchemy/index.py | 29 +++++++++ 3 files changed, 122 insertions(+), 2 deletions(-) create mode 100644 tidb_vector/sqlalchemy/index.py diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index b59c87b..4cffaf4 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -1,9 +1,10 @@ import pytest import numpy as np +import sqlalchemy from sqlalchemy import URL, create_engine, Column, Integer, select from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.exc import OperationalError -from tidb_vector.sqlalchemy import VectorType, VectorAdaptor +from tidb_vector.sqlalchemy import VectorType, VectorAdaptor, VectorIndex import tidb_vector from ..config import TestConfig @@ -385,3 +386,92 @@ def test_index_and_search(self): ) assert len(items) == 2 assert items[0].distance == 0.0 + + +class TestSQLAlchemyVectorIndex: + + def setup_class(self): + Item2Model.__table__.drop(bind=engine, checkfirst=True) + Item2Model.__table__.create(bind=engine) + + def teardown_class(self): + Item2Model.__table__.drop(bind=engine, checkfirst=True) + + def test_create_vector_index_statement(self): + from sqlalchemy.sql.ddl import CreateIndex + l2_index = VectorIndex( + "idx_embedding_l2", + sqlalchemy.func.vec_l2_distance(Item2Model.__table__.c.embedding), + ) + compiled = CreateIndex(l2_index).compile(dialect=engine.dialect) + assert compiled.string == "CREATE VECTOR INDEX idx_embedding_l2 ON sqlalchemy_item2 ((vec_l2_distance(embedding))) ADD_TIFLASH_ON_DEMAND" + + cos_index = VectorIndex( + "idx_embedding_cos", + sqlalchemy.func.vec_cosine_distance(Item2Model.__table__.c.embedding), + ) + compiled = CreateIndex(cos_index).compile(dialect=engine.dialect) + assert compiled.string == "CREATE VECTOR INDEX idx_embedding_cos ON sqlalchemy_item2 ((vec_cosine_distance(embedding))) ADD_TIFLASH_ON_DEMAND" + + def test_query_with_index(self): + # indexes + l2_index = VectorIndex( + "idx_embedding_l2", + sqlalchemy.func.vec_l2_distance(Item2Model.__table__.c.embedding), + ) + l2_index.create(engine) + cos_index = VectorIndex( + "idx_embedding_cos", + sqlalchemy.func.vec_cosine_distance(Item2Model.__table__.c.embedding), + ) + cos_index.create(engine) + + self.check_indexes( + Item2Model.__table__, ["idx_embedding_l2", "idx_embedding_cos"] + ) + + with Session() as session: + session.add_all( + [Item2Model(embedding=[1, 2, 3]), Item2Model(embedding=[1, 2, 3.2])] + ) + session.commit() + + # l2 distance + result_l2 = session.scalars( + select(Item2Model).filter( + Item2Model.embedding.l2_distance([1, 2, 3.1]) < 0.2 + ) + ).all() + assert len(result_l2) == 2 + + distance_l2 = Item2Model.embedding.l2_distance([1, 2, 3]) + items_l2 = ( + session.query(Item2Model.id, distance_l2.label("distance")) + .order_by(distance_l2) + .limit(5) + .all() + ) + assert len(items_l2) == 2 + assert items_l2[0].distance == 0.0 + + # cosine distance + result_cos = session.scalars( + select(Item2Model).filter( + Item2Model.embedding.cosine_distance([1, 2, 3.1]) < 0.2 + ) + ).all() + assert len(result_cos) == 2 + + distance_cos = Item2Model.embedding.cosine_distance([1, 2, 3]) + items_cos = ( + session.query(Item2Model.id, distance_cos.label("distance")) + .order_by(distance_cos) + .limit(5) + .all() + ) + assert len(items_cos) == 2 + assert items_cos[0].distance == 0.0 + + # drop indexes + l2_index.drop(engine) + cos_index.drop(engine) diff --git a/tidb_vector/sqlalchemy/__init__.py b/tidb_vector/sqlalchemy/__init__.py index 17579f2..eda69ce 100644 --- a/tidb_vector/sqlalchemy/__init__.py +++ b/tidb_vector/sqlalchemy/__init__.py @@ -1,4 +1,5 @@ from .vector_type import VectorType from .adaptor import VectorAdaptor +from .index import VectorIndex -__all__ = ["VectorType", "VectorAdaptor"] +__all__ = ["VectorType", "VectorAdaptor", "VectorIndex"] diff --git a/tidb_vector/sqlalchemy/index.py b/tidb_vector/sqlalchemy/index.py new file mode 100644 index 0000000..67ec14d --- /dev/null +++ b/tidb_vector/sqlalchemy/index.py @@ -0,0 +1,29 @@ +from typing import Optional, Any + +import sqlalchemy + +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.schema import Index + +class VectorIndex(Index): + def __init__( + self, + name: Optional[str], + *expressions, # _DDLColumnArgument + _table: Optional[Any] = None, + **dialect_kw: Any, + ): + super().__init__(name, *expressions, unique=False, _table=_table, **dialect_kw) + self.dialect_options["mysql"]["prefix"] = "VECTOR" + # add tiflash automatically when creating vector index + self.dialect_options["mysql"]["add_tiflash_on_demand"] = True + +# VectorIndex.argument_for("mysql", "add_tiflash_on_demand", None) + +@compiles(sqlalchemy.schema.CreateIndex) +def compile_create_vector_index(create_index_elem: sqlalchemy.sql.ddl.CreateIndex, compiler: sqlalchemy.sql.compiler.DDLCompiler, **kw): + text = compiler.visit_create_index(create_index_elem, **kw) + index_elem = create_index_elem.element + if index_elem.dialect_options.get("mysql", {}).get("add_tiflash_on_demand"): + text += " ADD_TIFLASH_ON_DEMAND" + return text From 521837f86b924c8e1209e8b4f4a905265be1bcdf Mon Sep 17 00:00:00 2001 From: JaySon-Huang Date: Thu, 7 Nov 2024 18:59:04 +0800 Subject: [PATCH 2/2] normal index Signed-off-by: JaySon-Huang --- tests/sqlalchemy/test_sqlalchemy.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index 4cffaf4..393179f 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -413,6 +413,11 @@ def test_create_vector_index_statement(self): compiled = CreateIndex(cos_index).compile(dialect=engine.dialect) assert compiled.string == "CREATE VECTOR INDEX idx_embedding_cos ON sqlalchemy_item2 ((vec_cosine_distance(embedding))) ADD_TIFLASH_ON_DEMAND" + # non-vector index + normal_index = sqlalchemy.schema.Index("idx_unique", Item2Model.__table__.c.id, unique=True) + compiled = CreateIndex(normal_index).compile(dialect=engine.dialect) + assert compiled.string == "CREATE UNIQUE INDEX idx_unique ON sqlalchemy_item2 (id)" + def test_query_with_index(self): # indexes l2_index = VectorIndex(