Skip to content

Commit

Permalink
Merge pull request #160 from milderhc/jdbc
Browse files Browse the repository at this point in the history
Add JDBC Vector Store
  • Loading branch information
milderhc authored Aug 6, 2024
2 parents 57db64f + c2ac675 commit b5b873b
Show file tree
Hide file tree
Showing 28 changed files with 2,295 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
/**
* Provides OpenAi implementation of audio to text service.
*/
public class OpenAiAudioToTextService extends OpenAiService<OpenAIAsyncClient> implements AudioToTextService {
public class OpenAiAudioToTextService extends OpenAiService<OpenAIAsyncClient>
implements AudioToTextService {

private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiAudioToTextService.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
/**
* Provides OpenAi implementation of text to audio service.
*/
public class OpenAiTextToAudioService extends OpenAiService<OpenAIAsyncClient> implements TextToAudioService {
public class OpenAiTextToAudioService extends OpenAiService<OpenAIAsyncClient>
implements TextToAudioService {

private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiTextToAudioService.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@
/**
* OpenAI chat completion service.
*/
public class OpenAIChatCompletion extends OpenAiService<OpenAIAsyncClient> implements ChatCompletionService {
public class OpenAIChatCompletion extends OpenAiService<OpenAIAsyncClient>
implements ChatCompletionService {

private static final Logger LOGGER = LoggerFactory.getLogger(OpenAIChatCompletion.class);

Expand Down Expand Up @@ -1055,7 +1056,8 @@ static ChatRequestMessage getChatRequestMessage(
/**
* Builder for creating a new instance of {@link OpenAIChatCompletion}.
*/
public static class Builder extends OpenAiServiceBuilder<OpenAIAsyncClient, OpenAIChatCompletion, Builder> {
public static class Builder
extends OpenAiServiceBuilder<OpenAIAsyncClient, OpenAIChatCompletion, Builder> {

@Override
public OpenAIChatCompletion build() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
/**
* An OpenAI implementation of a {@link TextGenerationService}.
*/
public class OpenAITextGenerationService extends OpenAiService<OpenAIAsyncClient> implements TextGenerationService {
public class OpenAITextGenerationService extends OpenAiService<OpenAIAsyncClient>
implements TextGenerationService {

private static final Logger LOGGER = LoggerFactory.getLogger(OpenAITextGenerationService.class);

Expand Down
6 changes: 3 additions & 3 deletions api-test/integration-tests/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@
<version>3.44.1.0</version>
</dependency>
<dependency>
<groupId>com.mysql</groupId>
<artifactId>mysql-connector-j</artifactId>
<version>8.2.0</version>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>8.0.33</version>
<scope>test</scope>
</dependency>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
package com.microsoft.semantickernel.tests.connectors.memory.jdbc;

import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollection;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.mysql.cj.jdbc.MysqlDataSource;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.MySQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

import javax.annotation.Nonnull;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;

/**
 * Integration tests for {@link JDBCVectorStoreRecordCollection} backed by a MySQL instance that
 * is started through Testcontainers.
 *
 * <p>Every test builds its own record collection under a distinct collection name, so the tests
 * do not read each other's records even though they share one data source.
 */
@Testcontainers
public class JDBCVectorStoreRecordCollectionTest {
    @Container
    private static final MySQLContainer<?> CONTAINER = new MySQLContainer<>("mysql:5.7.34");
    private static MysqlDataSource dataSource;

    /**
     * Points the shared data source at the running container.
     *
     * <p>The credentials are read from the container rather than hard-coded, so they stay
     * correct if the container is ever configured with a different user or password.
     */
    @BeforeAll
    static void setup() {
        dataSource = new MysqlDataSource();
        dataSource.setUrl(CONTAINER.getJdbcUrl());
        dataSource.setUser(CONTAINER.getUsername());
        dataSource.setPassword(CONTAINER.getPassword());
    }

    /**
     * Creates a record collection for {@link Hotel} records, prepares it and creates the
     * underlying collection if it does not exist yet.
     *
     * @param collectionName name of the collection to build
     * @return the ready-to-use record collection
     */
    private JDBCVectorStoreRecordCollection<Hotel> buildRecordCollection(
        @Nonnull String collectionName) {
        JDBCVectorStoreRecordCollection<Hotel> recordCollection =
            new JDBCVectorStoreRecordCollection<>(
                dataSource,
                collectionName,
                JDBCVectorStoreRecordCollectionOptions.<Hotel>builder()
                    .withRecordClass(Hotel.class)
                    .withQueryProvider(MySQLVectorStoreQueryProvider.builder()
                        .withDataSource(dataSource)
                        .build())
                    .build());

        recordCollection.prepareAsync().block();
        recordCollection.createCollectionIfNotExistsAsync().block();
        return recordCollection;
    }

    /** Verifies that a record collection can be built at all. */
    @Test
    public void buildRecordCollection() {
        assertNotNull(buildRecordCollection("buildTest"));
    }

    /**
     * Returns the fixed set of five hotels (ids {@code id_1} .. {@code id_5}) used as test data.
     */
    private List<Hotel> getHotels() {
        return List.of(
            new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description",
                Arrays.asList(1.0f, 2.0f, 3.0f), 4.0),
            new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description",
                Arrays.asList(1.0f, 2.0f, 3.0f), 3.0),
            new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description",
                Arrays.asList(1.0f, 2.0f, 3.0f), 5.0),
            new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description",
                Arrays.asList(1.0f, 2.0f, 3.0f), 4.0),
            new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description",
                Arrays.asList(1.0f, 2.0f, 3.0f), 5.0));
    }

    /**
     * Collects the ids of the given hotels, preserving their order.
     *
     * @param hotels hotels whose ids are wanted
     * @return a new mutable list holding one id per hotel
     */
    private static List<String> keysOf(List<Hotel> hotels) {
        List<String> keys = new ArrayList<>();
        for (Hotel hotel : hotels) {
            keys.add(hotel.getId());
        }
        return keys;
    }

    /** Upserts hotels one by one and checks that each can be read back by its id. */
    @Test
    public void upsertAndGetRecordAsync() {
        String collectionName = "upsertAndGetRecordAsync";
        JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

        List<Hotel> hotels = getHotels();
        for (Hotel hotel : hotels) {
            recordStore.upsertAsync(hotel, null).block();
        }

        for (Hotel hotel : hotels) {
            Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block();
            assertNotNull(retrievedHotel);
            assertEquals(hotel.getId(), retrievedHotel.getId());
        }
    }

    /** Upserts hotels one by one and checks that a batch get returns as many records. */
    @Test
    public void getBatchAsync() {
        String collectionName = "getBatchAsync";
        JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

        List<Hotel> hotels = getHotels();
        for (Hotel hotel : hotels) {
            recordStore.upsertAsync(hotel, null).block();
        }

        List<String> keys = keysOf(hotels);

        List<Hotel> retrievedHotels = recordStore.getBatchAsync(keys, null).block();
        assertNotNull(retrievedHotels);
        assertEquals(hotels.size(), retrievedHotels.size());
    }

    /** Upserts hotels as a batch and checks that a batch get returns as many records. */
    @Test
    public void upsertBatchAndGetBatchAsync() {
        String collectionName = "upsertBatchAndGetBatchAsync";
        JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

        List<Hotel> hotels = getHotels();
        recordStore.upsertBatchAsync(hotels, null).block();

        List<String> keys = keysOf(hotels);

        List<Hotel> retrievedHotels = recordStore.getBatchAsync(keys, null).block();
        assertNotNull(retrievedHotels);
        assertEquals(hotels.size(), retrievedHotels.size());
    }

    /**
     * Upserts the same batch three times and checks that the number of stored records still
     * equals the number of hotels, i.e. repeated upserts do not duplicate records.
     */
    @Test
    public void insertAndReplaceAsync() {
        String collectionName = "insertAndReplaceAsync";
        JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

        List<Hotel> hotels = getHotels();
        recordStore.upsertBatchAsync(hotels, null).block();
        recordStore.upsertBatchAsync(hotels, null).block();
        recordStore.upsertBatchAsync(hotels, null).block();

        List<String> keys = keysOf(hotels);

        List<Hotel> retrievedHotels = recordStore.getBatchAsync(keys, null).block();
        assertNotNull(retrievedHotels);
        assertEquals(hotels.size(), retrievedHotels.size());
    }

    /** Deletes hotels one by one and checks that each is gone right after its deletion. */
    @Test
    public void deleteRecordAsync() {
        String collectionName = "deleteRecordAsync";
        JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

        List<Hotel> hotels = getHotels();
        recordStore.upsertBatchAsync(hotels, null).block();

        for (Hotel hotel : hotels) {
            recordStore.deleteAsync(hotel.getId(), null).block();
            Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block();
            assertNull(retrievedHotel);
        }
    }

    /** Deletes all hotels in one batch and checks that none of them can be read afterwards. */
    @Test
    public void deleteBatchAsync() {
        String collectionName = "deleteBatchAsync";
        JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

        List<Hotel> hotels = getHotels();
        recordStore.upsertBatchAsync(hotels, null).block();

        List<String> keys = keysOf(hotels);

        recordStore.deleteBatchAsync(keys, null).block();

        for (String key : keys) {
            Hotel retrievedHotel = recordStore.getAsync(key, null).block();
            assertNull(retrievedHotel);
        }
    }

    /**
     * Checks single-record reads with {@code includeVectors(false)} (embedding must be null)
     * and with {@code includeVectors(true)} (embedding must be present).
     */
    @Test
    public void getWithNoVectors() {
        String collectionName = "getWithNoVectors";
        JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

        List<Hotel> hotels = getHotels();
        recordStore.upsertBatchAsync(hotels, null).block();

        GetRecordOptions options = GetRecordOptions.builder()
            .includeVectors(false)
            .build();

        for (Hotel hotel : hotels) {
            Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block();
            assertNotNull(retrievedHotel);
            assertEquals(hotel.getId(), retrievedHotel.getId());
            assertNull(retrievedHotel.getDescriptionEmbedding());
        }

        options = GetRecordOptions.builder()
            .includeVectors(true)
            .build();

        for (Hotel hotel : hotels) {
            Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block();
            assertNotNull(retrievedHotel);
            assertEquals(hotel.getId(), retrievedHotel.getId());
            assertNotNull(retrievedHotel.getDescriptionEmbedding());
        }
    }

    /**
     * Checks batch reads with {@code includeVectors(false)} (embeddings must be null) and with
     * {@code includeVectors(true)} (embeddings must be present).
     */
    @Test
    public void getBatchWithNoVectors() {
        String collectionName = "getBatchWithNoVectors";
        JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

        List<Hotel> hotels = getHotels();
        recordStore.upsertBatchAsync(hotels, null).block();

        GetRecordOptions options = GetRecordOptions.builder()
            .includeVectors(false)
            .build();

        List<String> keys = keysOf(hotels);

        List<Hotel> retrievedHotels = recordStore.getBatchAsync(keys, options).block();
        assertNotNull(retrievedHotels);
        assertEquals(hotels.size(), retrievedHotels.size());

        for (Hotel hotel : retrievedHotels) {
            assertNull(hotel.getDescriptionEmbedding());
        }

        options = GetRecordOptions.builder()
            .includeVectors(true)
            .build();

        retrievedHotels = recordStore.getBatchAsync(keys, options).block();
        assertNotNull(retrievedHotels);
        assertEquals(hotels.size(), retrievedHotels.size());

        for (Hotel hotel : retrievedHotels) {
            assertNotNull(hotel.getDescriptionEmbedding());
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package com.microsoft.semantickernel.tests.connectors.memory.jdbc;

import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions;
import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.mysql.cj.jdbc.MysqlDataSource;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.MySQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Integration test for {@link JDBCVectorStore} backed by a MySQL instance that is started
 * through Testcontainers.
 */
@Testcontainers
public class JDBCVectorStoreTest {
    @Container
    private static final MySQLContainer<?> CONTAINER = new MySQLContainer<>("mysql:5.7.34");
    private static MysqlDataSource dataSource;

    /**
     * Points the shared data source at the running container.
     *
     * <p>The credentials are read from the container rather than hard-coded, so they stay
     * correct if the container is ever configured with a different user or password.
     */
    @BeforeAll
    static void setup() {
        dataSource = new MysqlDataSource();
        dataSource.setUrl(CONTAINER.getJdbcUrl());
        dataSource.setUser(CONTAINER.getUsername());
        dataSource.setPassword(CONTAINER.getPassword());
    }

    /**
     * Creates three collections through the vector store and checks that
     * {@code getCollectionNamesAsync()} reports exactly that many names, including each of the
     * created ones.
     */
    @Test
    public void getCollectionNamesAsync() {
        MySQLVectorStoreQueryProvider queryProvider = MySQLVectorStoreQueryProvider.builder()
            .withDataSource(dataSource)
            .build();

        JDBCVectorStore vectorStore = JDBCVectorStore.builder()
            .withDataSource(dataSource)
            .withOptions(
                JDBCVectorStoreOptions.builder()
                    .withQueryProvider(queryProvider)
                    .build())
            .build();

        // Result intentionally discarded; the call is issued before any collection exists.
        // NOTE(review): presumably this exercises the empty-store path - confirm intent.
        vectorStore.getCollectionNamesAsync().block();

        List<String> collectionNames = Arrays.asList("collection1", "collection2", "collection3");

        for (String collectionName : collectionNames) {
            vectorStore.getCollection(collectionName, Hotel.class, null)
                .createCollectionAsync()
                .block();
        }

        List<String> retrievedCollectionNames = vectorStore.getCollectionNamesAsync().block();
        assertNotNull(retrievedCollectionNames);
        assertEquals(collectionNames.size(), retrievedCollectionNames.size());
        for (String collectionName : collectionNames) {
            assertTrue(retrievedCollectionNames.contains(collectionName),
                "Missing collection name: " + collectionName);
        }
    }
}
Loading

0 comments on commit b5b873b

Please sign in to comment.