Skip to content

Commit b5b873b

Browse files
authored
Merge pull request #160 from milderhc/jdbc
Add JDBC Vector Store
2 parents 57db64f + c2ac675 commit b5b873b

File tree

28 files changed

+2295
-94
lines changed

28 files changed

+2295
-94
lines changed

aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
/**
1919
* Provides OpenAi implementation of audio to text service.
2020
*/
21-
public class OpenAiAudioToTextService extends OpenAiService<OpenAIAsyncClient> implements AudioToTextService {
21+
public class OpenAiAudioToTextService extends OpenAiService<OpenAIAsyncClient>
22+
implements AudioToTextService {
2223

2324
private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiAudioToTextService.class);
2425

aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
/**
1818
* Provides OpenAi implementation of text to audio service.
1919
*/
20-
public class OpenAiTextToAudioService extends OpenAiService<OpenAIAsyncClient> implements TextToAudioService {
20+
public class OpenAiTextToAudioService extends OpenAiService<OpenAIAsyncClient>
21+
implements TextToAudioService {
2122

2223
private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiTextToAudioService.class);
2324

aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@
7979
/**
8080
* OpenAI chat completion service.
8181
*/
82-
public class OpenAIChatCompletion extends OpenAiService<OpenAIAsyncClient> implements ChatCompletionService {
82+
public class OpenAIChatCompletion extends OpenAiService<OpenAIAsyncClient>
83+
implements ChatCompletionService {
8384

8485
private static final Logger LOGGER = LoggerFactory.getLogger(OpenAIChatCompletion.class);
8586

@@ -1055,7 +1056,8 @@ static ChatRequestMessage getChatRequestMessage(
10551056
/**
10561057
* Builder for creating a new instance of {@link OpenAIChatCompletion}.
10571058
*/
1058-
public static class Builder extends OpenAiServiceBuilder<OpenAIAsyncClient, OpenAIChatCompletion, Builder> {
1059+
public static class Builder
1060+
extends OpenAiServiceBuilder<OpenAIAsyncClient, OpenAIChatCompletion, Builder> {
10591061

10601062
@Override
10611063
public OpenAIChatCompletion build() {

aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
/**
3131
* An OpenAI implementation of a {@link TextGenerationService}.
3232
*/
33-
public class OpenAITextGenerationService extends OpenAiService<OpenAIAsyncClient> implements TextGenerationService {
33+
public class OpenAITextGenerationService extends OpenAiService<OpenAIAsyncClient>
34+
implements TextGenerationService {
3435

3536
private static final Logger LOGGER = LoggerFactory.getLogger(OpenAITextGenerationService.class);
3637

api-test/integration-tests/pom.xml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@
6868
<version>3.44.1.0</version>
6969
</dependency>
7070
<dependency>
71-
<groupId>com.mysql</groupId>
72-
<artifactId>mysql-connector-j</artifactId>
73-
<version>8.2.0</version>
71+
<groupId>mysql</groupId>
72+
<artifactId>mysql-connector-java</artifactId>
73+
<version>8.0.33</version>
7474
<scope>test</scope>
7575
</dependency>
7676

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
package com.microsoft.semantickernel.tests.connectors.memory.jdbc;
2+
3+
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollection;
4+
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions;
5+
import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
6+
import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions;
7+
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
8+
import com.mysql.cj.jdbc.MysqlDataSource;
9+
import org.junit.jupiter.api.BeforeAll;
10+
import org.junit.jupiter.api.Test;
11+
import org.testcontainers.containers.MySQLContainer;
12+
import org.testcontainers.junit.jupiter.Container;
13+
import org.testcontainers.junit.jupiter.Testcontainers;
14+
15+
import javax.annotation.Nonnull;
16+
import javax.sql.DataSource;
17+
import java.sql.Connection;
18+
import java.sql.DriverManager;
19+
import java.sql.SQLException;
20+
import java.util.ArrayList;
21+
import java.util.Arrays;
22+
import java.util.List;
23+
24+
import static org.junit.jupiter.api.Assertions.assertEquals;
25+
import static org.junit.jupiter.api.Assertions.assertNotNull;
26+
import static org.junit.jupiter.api.Assertions.assertNull;
27+
28+
@Testcontainers
29+
public class JDBCVectorStoreRecordCollectionTest {
30+
@Container
31+
private static final MySQLContainer<?> CONTAINER = new MySQLContainer<>("mysql:5.7.34");
32+
private static final String MYSQL_USER = "test";
33+
private static final String MYSQL_PASSWORD = "test";
34+
private static MysqlDataSource dataSource;
35+
@BeforeAll
36+
static void setup() {
37+
dataSource = new MysqlDataSource();
38+
dataSource.setUrl(CONTAINER.getJdbcUrl());
39+
dataSource.setUser(MYSQL_USER);
40+
dataSource.setPassword(MYSQL_PASSWORD);
41+
}
42+
43+
private JDBCVectorStoreRecordCollection<Hotel> buildRecordCollection(@Nonnull String collectionName) {
44+
JDBCVectorStoreRecordCollection<Hotel> recordCollection = new JDBCVectorStoreRecordCollection<>(
45+
dataSource,
46+
collectionName,
47+
JDBCVectorStoreRecordCollectionOptions.<Hotel>builder()
48+
.withRecordClass(Hotel.class)
49+
.withQueryProvider(MySQLVectorStoreQueryProvider.builder()
50+
.withDataSource(dataSource)
51+
.build())
52+
.build());
53+
54+
recordCollection.prepareAsync().block();
55+
recordCollection.createCollectionIfNotExistsAsync().block();
56+
return recordCollection;
57+
}
58+
59+
@Test
60+
public void buildRecordCollection() {
61+
assertNotNull(buildRecordCollection("buildTest"));
62+
}
63+
64+
private List<Hotel> getHotels() {
65+
return List.of(
66+
new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0),
67+
new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(1.0f, 2.0f, 3.0f), 3.0),
68+
new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0),
69+
new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0),
70+
new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0)
71+
);
72+
}
73+
74+
@Test
75+
public void upsertAndGetRecordAsync() {
76+
String collectionName = "upsertAndGetRecordAsync";
77+
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);
78+
79+
List<Hotel> hotels = getHotels();
80+
for (Hotel hotel : hotels) {
81+
recordStore.upsertAsync(hotel, null).block();
82+
}
83+
84+
for (Hotel hotel : hotels) {
85+
Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block();
86+
assertNotNull(retrievedHotel);
87+
assertEquals(hotel.getId(), retrievedHotel.getId());
88+
}
89+
}
90+
91+
@Test
92+
public void getBatchAsync() {
93+
String collectionName = "getBatchAsync";
94+
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);
95+
96+
List<Hotel> hotels = getHotels();
97+
for (Hotel hotel : hotels) {
98+
recordStore.upsertAsync(hotel, null).block();
99+
}
100+
101+
List<String> keys = new ArrayList<>();
102+
for (Hotel hotel : hotels) {
103+
keys.add(hotel.getId());
104+
}
105+
106+
List<Hotel> retrievedHotels = recordStore.getBatchAsync(keys, null).block();
107+
assertNotNull(retrievedHotels);
108+
assertEquals(hotels.size(), retrievedHotels.size());
109+
}
110+
111+
@Test
112+
public void upsertBatchAndGetBatchAsync() {
113+
String collectionName = "upsertBatchAndGetBatchAsync";
114+
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);
115+
116+
List<Hotel> hotels = getHotels();
117+
recordStore.upsertBatchAsync(hotels, null).block();
118+
119+
List<String> keys = new ArrayList<>();
120+
for (Hotel hotel : hotels) {
121+
keys.add(hotel.getId());
122+
}
123+
124+
List<Hotel> retrievedHotels = recordStore.getBatchAsync(keys, null).block();
125+
assertNotNull(retrievedHotels);
126+
assertEquals(hotels.size(), retrievedHotels.size());
127+
}
128+
129+
@Test
130+
public void insertAndReplaceAsync() {
131+
String collectionName = "insertAndReplaceAsync";
132+
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);
133+
134+
List<Hotel> hotels = getHotels();
135+
recordStore.upsertBatchAsync(hotels, null).block();
136+
recordStore.upsertBatchAsync(hotels, null).block();
137+
recordStore.upsertBatchAsync(hotels, null).block();
138+
139+
List<String> keys = new ArrayList<>();
140+
for (Hotel hotel : hotels) {
141+
keys.add(hotel.getId());
142+
}
143+
144+
List<Hotel> retrievedHotels = recordStore.getBatchAsync(keys, null).block();
145+
assertNotNull(retrievedHotels);
146+
assertEquals(hotels.size(), retrievedHotels.size());
147+
}
148+
149+
@Test
150+
public void deleteRecordAsync() {
151+
String collectionName = "deleteRecordAsync";
152+
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);
153+
154+
List<Hotel> hotels = getHotels();
155+
recordStore.upsertBatchAsync(hotels, null).block();
156+
157+
for (Hotel hotel : hotels) {
158+
recordStore.deleteAsync(hotel.getId(), null).block();
159+
Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block();
160+
assertNull(retrievedHotel);
161+
}
162+
}
163+
164+
@Test
165+
public void deleteBatchAsync() {
166+
String collectionName = "deleteBatchAsync";
167+
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);
168+
169+
List<Hotel> hotels = getHotels();
170+
recordStore.upsertBatchAsync(hotels, null).block();
171+
172+
List<String> keys = new ArrayList<>();
173+
for (Hotel hotel : hotels) {
174+
keys.add(hotel.getId());
175+
}
176+
177+
recordStore.deleteBatchAsync(keys, null).block();
178+
179+
for (String key : keys) {
180+
Hotel retrievedHotel = recordStore.getAsync(key, null).block();
181+
assertNull(retrievedHotel);
182+
}
183+
}
184+
185+
@Test
186+
public void getWithNoVectors() {
187+
String collectionName = "getWithNoVectors";
188+
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);
189+
190+
List<Hotel> hotels = getHotels();
191+
recordStore.upsertBatchAsync(hotels, null).block();
192+
193+
GetRecordOptions options = GetRecordOptions.builder()
194+
.includeVectors(false)
195+
.build();
196+
197+
for (Hotel hotel : hotels) {
198+
Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block();
199+
assertNotNull(retrievedHotel);
200+
assertEquals(hotel.getId(), retrievedHotel.getId());
201+
assertNull(retrievedHotel.getDescriptionEmbedding());
202+
}
203+
204+
options = GetRecordOptions.builder()
205+
.includeVectors(true)
206+
.build();
207+
208+
for (Hotel hotel : hotels) {
209+
Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block();
210+
assertNotNull(retrievedHotel);
211+
assertEquals(hotel.getId(), retrievedHotel.getId());
212+
assertNotNull(retrievedHotel.getDescriptionEmbedding());
213+
}
214+
}
215+
216+
@Test
217+
public void getBatchWithNoVectors() {
218+
String collectionName = "getBatchWithNoVectors";
219+
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);
220+
221+
List<Hotel> hotels = getHotels();
222+
recordStore.upsertBatchAsync(hotels, null).block();
223+
224+
GetRecordOptions options = GetRecordOptions.builder()
225+
.includeVectors(false)
226+
.build();
227+
228+
List<String> keys = new ArrayList<>();
229+
for (Hotel hotel : hotels) {
230+
keys.add(hotel.getId());
231+
}
232+
233+
List<Hotel> retrievedHotels = recordStore.getBatchAsync(keys, options).block();
234+
assertNotNull(retrievedHotels);
235+
assertEquals(hotels.size(), retrievedHotels.size());
236+
237+
for (Hotel hotel : retrievedHotels) {
238+
assertNull(hotel.getDescriptionEmbedding());
239+
}
240+
241+
options = GetRecordOptions.builder()
242+
.includeVectors(true)
243+
.build();
244+
245+
retrievedHotels = recordStore.getBatchAsync(keys, options).block();
246+
assertNotNull(retrievedHotels);
247+
assertEquals(hotels.size(), retrievedHotels.size());
248+
249+
for (Hotel hotel : retrievedHotels) {
250+
assertNotNull(hotel.getDescriptionEmbedding());
251+
}
252+
}
253+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package com.microsoft.semantickernel.tests.connectors.memory.jdbc;
2+
3+
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore;
4+
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions;
5+
import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
6+
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
7+
import com.mysql.cj.jdbc.MysqlDataSource;
8+
import org.junit.jupiter.api.BeforeAll;
9+
import org.junit.jupiter.api.Test;
10+
import org.testcontainers.containers.MySQLContainer;
11+
import org.testcontainers.junit.jupiter.Container;
12+
import org.testcontainers.junit.jupiter.Testcontainers;
13+
14+
import java.sql.Connection;
15+
import java.sql.DriverManager;
16+
import java.sql.SQLException;
17+
import java.util.Arrays;
18+
import java.util.List;
19+
20+
import static org.junit.jupiter.api.Assertions.assertEquals;
21+
import static org.junit.jupiter.api.Assertions.assertNotNull;
22+
import static org.junit.jupiter.api.Assertions.assertTrue;
23+
24+
@Testcontainers
25+
public class JDBCVectorStoreTest {
26+
@Container
27+
private static final MySQLContainer<?> CONTAINER = new MySQLContainer<>("mysql:5.7.34");
28+
private static final String MYSQL_USER = "test";
29+
private static final String MYSQL_PASSWORD = "test";
30+
private static MysqlDataSource dataSource;
31+
32+
@BeforeAll
33+
static void setup() {
34+
dataSource = new MysqlDataSource();
35+
dataSource.setUrl(CONTAINER.getJdbcUrl());
36+
dataSource.setUser(MYSQL_USER);
37+
dataSource.setPassword(MYSQL_PASSWORD);
38+
}
39+
40+
@Test
41+
public void getCollectionNamesAsync() {
42+
MySQLVectorStoreQueryProvider queryProvider = MySQLVectorStoreQueryProvider.builder()
43+
.withDataSource(dataSource)
44+
.build();
45+
46+
JDBCVectorStore vectorStore = JDBCVectorStore.builder()
47+
.withDataSource(dataSource)
48+
.withOptions(
49+
JDBCVectorStoreOptions.builder()
50+
.withQueryProvider(queryProvider)
51+
.build()
52+
)
53+
.build();
54+
55+
vectorStore.getCollectionNamesAsync().block();
56+
57+
List<String> collectionNames = Arrays.asList("collection1", "collection2", "collection3");
58+
59+
for (String collectionName : collectionNames) {
60+
vectorStore.getCollection(collectionName, Hotel.class, null).createCollectionAsync().block();
61+
}
62+
63+
List<String> retrievedCollectionNames = vectorStore.getCollectionNamesAsync().block();
64+
assertNotNull(retrievedCollectionNames);
65+
assertEquals(collectionNames.size(), retrievedCollectionNames.size());
66+
for (String collectionName : collectionNames) {
67+
assertTrue(retrievedCollectionNames.contains(collectionName));
68+
}
69+
}
70+
}

0 commit comments

Comments
 (0)