1313"""
1414CREATE EXTENSION IF NOT EXISTS vector;
1515CREATE TABLE IF NOT EXISTS document_pieces (
16- id VARCHAR PRIMARY KEY,
16+ id TEXT PRIMARY KEY,
17+ tenant_id TEXT NOT NULL,
1718 namespace TEXT,
1819 name TEXT,
1920 text TEXT,
2021 embedding vector NOT NULL
2122);
2223CREATE TABLE IF NOT EXISTS documents (
23- id VARCHAR PRIMARY KEY,
24+ id TEXT PRIMARY KEY,
25+ tenant_id TEXT NOT NULL,
2426 namespace TEXT,
2527 name TEXT,
2628 text TEXT
2729);
30+
31+ ALTER TABLE document_pieces ENABLE ROW LEVEL SECURITY;
32+ ALTER TABLE documents ENABLE ROW LEVEL SECURITY;
33+
34+ CREATE POLICY document_pieces_policy ON document_pieces USING (tenant_id = current_user);
35+ CREATE POLICY documents_policy ON documents USING (tenant_id = current_user);
2836"""
2937
3038
@@ -33,12 +41,14 @@ class Database(Storage):
3341 user : Optional [str ]
3442 password : Optional [str ]
3543 connection : Optional [psycopg .AsyncConnection ]
44+ tenant_id : Optional [str ]
3645
3746 def __init__ (self , connection_url : str , user : Optional [str ] = None , password : Optional [str ] = None ):
3847 self .connection_url = connection_url
3948 self .user = user
4049 self .password = password
4150 self .connection = None
51+ self .tenant_id = None
4252
4353 async def initialize (self ) -> None :
4454 """
@@ -55,6 +65,13 @@ async def initialize(self) -> None:
5565 else :
5666 self .connection = await psycopg .AsyncConnection .connect (self .connection_url )
5767
68+ async with self .connection .cursor () as cursor :
69+ await cursor .execute ("SELECT \" current_user\" ()" )
70+ result = await cursor .fetchone ()
71+ if result is None :
72+ raise Exception ("Failed to connect to the database. No current user found." )
73+ self .tenant_id = result [0 ]
74+
5875 async def delete_all (self ):
5976 """For testing purposes, delete all entries"""
6077 await self .initialize ()
@@ -69,8 +86,8 @@ async def store_document(self, namespace: str, id: str, name: str, text: str) ->
6986
7087 async with self .connection .cursor () as cursor :
7188 await cursor .execute (
72- "INSERT INTO documents (namespace, id, name, text) VALUES (%s, %s, %s, %s)" ,
73- (namespace , id , name , text ),
89+ "INSERT INTO documents (tenant_id, namespace, id, name, text) VALUES (%s, %s, %s, %s, %s)" ,
90+ (self . tenant_id , namespace , id , name , text ),
7491 )
7592 await self .connection .commit ()
7693
@@ -80,16 +97,16 @@ async def documents(self, namespace: str, id: Optional[str] = None) -> List[Sear
8097 async with self .connection .cursor () as cursor :
8198 if id :
8299 await cursor .execute (
83- "SELECT id, name, text FROM documents WHERE namespace = %s AND document_id = %s" ,
84- (namespace , id ),
100+ "SELECT id, name, text FROM documents WHERE tenant_id = %s AND namespace = %s AND document_id = %s" ,
101+ (self . tenant_id , namespace , id ),
85102 )
86103 result : Optional [Row ] = await cursor .fetchone ()
87104 if result :
88105 return [SearchHit (result [0 ], result [1 ])]
89106 else :
90107 raise Exception (f"Document { id } not found." )
91108 else :
92- await cursor .execute ("SELECT id, name, text FROM documents WHERE namespace = %s" , (namespace ,))
109+ await cursor .execute ("SELECT id, name, text FROM documents WHERE tenant_id = %s AND namespace = %s" , (self . tenant_id , namespace ,))
93110 results : List [Row ] = await cursor .fetchall ()
94111 return [SearchHit (result [0 ], result [1 ], result [2 ]) for result in results ]
95112
@@ -99,8 +116,8 @@ async def store_text_piece(self, namespace: str, id: str, name: str, pieces: Lis
99116 async with self .connection .cursor () as cursor :
100117 for piece in pieces :
101118 await cursor .execute (
102- "INSERT INTO document_pieces (namespace, id, name, text, embedding) VALUES (%s, %s, %s, %s, %s)" ,
103- (namespace , id , name , piece .text , piece .embedding ),
119+ "INSERT INTO document_pieces (tenant_id, namespace, id, name, text, embedding) VALUES (%s, %s, %s, %s, %s, %s)" ,
120+ (self . tenant_id , namespace , id , name , piece .text , piece .embedding ),
104121 )
105122 await self .connection .commit ()
106123
@@ -111,10 +128,10 @@ async def search_text(self, namespace: str, embedding: List[float], max_results:
111128 await cursor .execute (
112129 """
113130 SELECT id, name, text, (embedding <=> %s::sparsevec) as distance FROM document_pieces
114- WHERE namespace = %s AND (embedding <=> %s::sparsevec) <= %s
131+ WHERE tenant_id = %s AND namespace = %s AND (embedding <=> %s::sparsevec) <= %s
115132 ORDER BY distance LIMIT %s
116133 """ ,
117- (embedding , namespace , embedding , max_distance * 2.0 , max_results ),
134+ (embedding , self . tenant_id , namespace , embedding , max_distance * 2.0 , max_results ),
118135 )
119136 results = await cursor .fetchall ()
120137 # returned cosine distance is in the range [0, 2]
0 commit comments