diff --git a/sdk/ai/azure-ai-voicelive/CHANGELOG.md b/sdk/ai/azure-ai-voicelive/CHANGELOG.md
index 1fae3f3b5120..0f0827d6483a 100644
--- a/sdk/ai/azure-ai-voicelive/CHANGELOG.md
+++ b/sdk/ai/azure-ai-voicelive/CHANGELOG.md
@@ -4,6 +4,15 @@
 
 ### Features Added
 
+- Added `VoiceLiveRequestOptions` class for per-request customization:
+  - Supports custom query parameters via `addCustomQueryParameter(String key, String value)` method
+  - Supports custom headers via `addCustomHeader(String name, String value)` and `setCustomHeaders(HttpHeaders)` methods
+  - Custom parameters and headers can be passed to session creation methods
+- Enhanced session creation with new overloads:
+  - Added `startSession(String model, VoiceLiveRequestOptions requestOptions)` for model with custom options
+  - Added `startSession(VoiceLiveRequestOptions requestOptions)` for custom options without explicit model parameter
+  - Added `startSession()` for sessions without an explicit model; the original `startSession(String model)` method is preserved for backward compatibility
+
 ### Breaking Changes
 
 ### Bugs Fixed
diff --git a/sdk/ai/azure-ai-voicelive/src/main/java/com/azure/ai/voicelive/VoiceLiveAsyncClient.java b/sdk/ai/azure-ai-voicelive/src/main/java/com/azure/ai/voicelive/VoiceLiveAsyncClient.java
index e18bcdf824f4..d0b9fb5afaec 100644
--- a/sdk/ai/azure-ai-voicelive/src/main/java/com/azure/ai/voicelive/VoiceLiveAsyncClient.java
+++ b/sdk/ai/azure-ai-voicelive/src/main/java/com/azure/ai/voicelive/VoiceLiveAsyncClient.java
@@ -3,16 +3,20 @@
 
 package com.azure.ai.voicelive;
 
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import com.azure.ai.voicelive.models.VoiceLiveRequestOptions;
 import com.azure.core.annotation.ServiceClient;
 import com.azure.core.credential.KeyCredential;
 import com.azure.core.credential.TokenCredential;
 import com.azure.core.http.HttpHeaders;
 import com.azure.core.util.logging.ClientLogger;
-import reactor.core.publisher.Mono;
 
-import java.net.URI;
-import java.net.URISyntaxException;
-import java.util.Objects;
+import reactor.core.publisher.Mono;
 
 /**
  * The VoiceLiveAsyncClient provides methods to create and manage real-time voice communication sessions
@@ -66,6 +70,7 @@ public final class VoiceLiveAsyncClient {
      *
      * @param model The model to use for the session.
      * @return A Mono containing the connected VoiceLiveSessionAsyncClient.
+     * @throws NullPointerException if {@code model} is null.
      */
     public Mono<VoiceLiveSessionAsyncClient> startSession(String model) {
         Objects.requireNonNull(model, "'model' cannot be null");
@@ -81,6 +86,77 @@ public Mono<VoiceLiveSessionAsyncClient> startSession(String model) {
         });
     }
 
+    /**
+     * Starts a new VoiceLiveSessionAsyncClient for real-time voice communication without specifying a model.
+     * The model can be provided through the query string of the endpoint URL if required by the service.
+     *
+     * @return A Mono containing the connected VoiceLiveSessionAsyncClient.
+     */
+    public Mono<VoiceLiveSessionAsyncClient> startSession() {
+        return Mono.fromCallable(() -> convertToWebSocketEndpoint(endpoint, null)).flatMap(wsEndpoint -> {
+            VoiceLiveSessionAsyncClient session;
+            if (keyCredential != null) {
+                session = new VoiceLiveSessionAsyncClient(wsEndpoint, keyCredential);
+            } else {
+                session = new VoiceLiveSessionAsyncClient(wsEndpoint, tokenCredential);
+            }
+            return session.connect(additionalHeaders).thenReturn(session);
+        });
+    }
+
+    /**
+     * Starts a new VoiceLiveSessionAsyncClient for real-time voice communication with custom request options.
+     *
+     * @param model The model to use for the session.
+     * @param requestOptions Custom query parameters and headers for the request.
+     * @return A Mono containing the connected VoiceLiveSessionAsyncClient.
+     * @throws NullPointerException if {@code model} or {@code requestOptions} is null.
+     */
+    public Mono<VoiceLiveSessionAsyncClient> startSession(String model, VoiceLiveRequestOptions requestOptions) {
+        Objects.requireNonNull(model, "'model' cannot be null");
+        Objects.requireNonNull(requestOptions, "'requestOptions' cannot be null");
+
+        return Mono
+            .fromCallable(() -> convertToWebSocketEndpoint(endpoint, model, requestOptions.getCustomQueryParameters()))
+            .flatMap(wsEndpoint -> {
+                VoiceLiveSessionAsyncClient session;
+                if (keyCredential != null) {
+                    session = new VoiceLiveSessionAsyncClient(wsEndpoint, keyCredential);
+                } else {
+                    session = new VoiceLiveSessionAsyncClient(wsEndpoint, tokenCredential);
+                }
+                // Merge additional headers with custom headers from requestOptions
+                HttpHeaders mergedHeaders = mergeHeaders(additionalHeaders, requestOptions.getCustomHeaders());
+                return session.connect(mergedHeaders).thenReturn(session);
+            });
+    }
+
+    /**
+     * Starts a new VoiceLiveSessionAsyncClient for real-time voice communication with custom request options.
+     * The model can be provided via custom query parameters.
+     *
+     * @param requestOptions Custom query parameters and headers for the request.
+     * @return A Mono containing the connected VoiceLiveSessionAsyncClient.
+     * @throws NullPointerException if {@code requestOptions} is null.
+     */
+    public Mono<VoiceLiveSessionAsyncClient> startSession(VoiceLiveRequestOptions requestOptions) {
+        Objects.requireNonNull(requestOptions, "'requestOptions' cannot be null");
+
+        return Mono
+            .fromCallable(() -> convertToWebSocketEndpoint(endpoint, null, requestOptions.getCustomQueryParameters()))
+            .flatMap(wsEndpoint -> {
+                VoiceLiveSessionAsyncClient session;
+                if (keyCredential != null) {
+                    session = new VoiceLiveSessionAsyncClient(wsEndpoint, keyCredential);
+                } else {
+                    session = new VoiceLiveSessionAsyncClient(wsEndpoint, tokenCredential);
+                }
+                // Merge additional headers with custom headers from requestOptions
+                HttpHeaders mergedHeaders = mergeHeaders(additionalHeaders, requestOptions.getCustomHeaders());
+                return session.connect(mergedHeaders).thenReturn(session);
+            });
+    }
+
     /**
      * Gets the API version.
      *
@@ -90,6 +166,24 @@ String getApiVersion() {
         return apiVersion;
     }
 
+    /**
+     * Merges two HttpHeaders objects, with custom headers taking precedence.
+     *
+     * @param baseHeaders The base headers.
+     * @param customHeaders The custom headers to merge.
+     * @return The merged HttpHeaders.
+     */
+    private HttpHeaders mergeHeaders(HttpHeaders baseHeaders, HttpHeaders customHeaders) {
+        HttpHeaders merged = new HttpHeaders();
+        if (baseHeaders != null) {
+            baseHeaders.forEach(header -> merged.set(header.getName(), header.getValue()));
+        }
+        if (customHeaders != null) {
+            customHeaders.forEach(header -> merged.set(header.getName(), header.getValue()));
+        }
+        return merged;
+    }
+
     /**
      * Converts an HTTP endpoint to a WebSocket endpoint.
      *
@@ -124,26 +218,118 @@ private URI convertToWebSocketEndpoint(URI httpEndpoint, String model) {
                 path = path.replaceAll("/$", "") + "/voice-live/realtime";
             }
 
-            // Build query string
-            StringBuilder queryBuilder = new StringBuilder();
+            // Build query parameter map to avoid duplicates
+            Map<String, String> queryParams = new LinkedHashMap<>();
+
+            // Start with existing query parameters from the endpoint URL
             if (httpEndpoint.getQuery() != null && !httpEndpoint.getQuery().isEmpty()) {
-                queryBuilder.append(httpEndpoint.getQuery());
+                String[] pairs = httpEndpoint.getQuery().split("&");
+                for (String pair : pairs) {
+                    int idx = pair.indexOf("=");
+                    if (idx > 0) {
+                        String key = pair.substring(0, idx);
+                        String value = pair.substring(idx + 1);
+                        queryParams.put(key, value);
+                    }
+                }
+            }
+
+            // Ensure api-version is set (SDK's version takes precedence)
+            queryParams.put("api-version", apiVersion);
+
+            // Add model if provided (function parameter takes precedence)
+            if (model != null && !model.isEmpty()) {
+                queryParams.put("model", model);
             }
 
-            // Add api-version if not present
-            if (!queryBuilder.toString().contains("api-version=")) {
+            // Build final query string
+            StringBuilder queryBuilder = new StringBuilder();
+            for (Map.Entry<String, String> entry : queryParams.entrySet()) {
                 if (queryBuilder.length() > 0) {
                     queryBuilder.append("&");
                 }
-                queryBuilder.append("api-version=").append(apiVersion);
+                queryBuilder.append(entry.getKey()).append("=").append(entry.getValue());
             }
 
-            // Add model if not present
-            if (!queryBuilder.toString().contains("model=")) {
+            return new URI(scheme, httpEndpoint.getUserInfo(), httpEndpoint.getHost(), httpEndpoint.getPort(), path,
+                queryBuilder.length() > 0 ? queryBuilder.toString() : null, httpEndpoint.getFragment());
+        } catch (URISyntaxException e) {
+            throw LOGGER
+                .logExceptionAsError(new IllegalArgumentException("Failed to convert endpoint to WebSocket URI", e));
+        }
+    }
+
+    /**
+     * Converts an HTTP endpoint to a WebSocket endpoint with additional custom query parameters.
+     *
+     * @param httpEndpoint The HTTP endpoint to convert.
+     * @param model The model name to include in the query string.
+     * @param additionalQueryParams Additional custom query parameters to include.
+     * @return The WebSocket endpoint URI.
+     */
+    private URI convertToWebSocketEndpoint(URI httpEndpoint, String model, Map<String, String> additionalQueryParams) {
+        try {
+            String scheme;
+            switch (httpEndpoint.getScheme().toLowerCase()) {
+                case "wss":
+                case "ws":
+                    scheme = httpEndpoint.getScheme();
+                    break;
+
+                case "https":
+                    scheme = "wss";
+                    break;
+
+                case "http":
+                    scheme = "ws";
+                    break;
+
+                default:
+                    throw LOGGER.logExceptionAsError(
+                        new IllegalArgumentException("Scheme " + httpEndpoint.getScheme() + " is not supported"));
+            }
+
+            String path = httpEndpoint.getPath();
+            if (!path.endsWith("/realtime")) {
+                path = path.replaceAll("/$", "") + "/voice-live/realtime";
+            }
+
+            // Build query parameter map to avoid duplicates
+            Map<String, String> queryParams = new LinkedHashMap<>();
+
+            // Start with existing query parameters from the endpoint URL
+            if (httpEndpoint.getQuery() != null && !httpEndpoint.getQuery().isEmpty()) {
+                String[] pairs = httpEndpoint.getQuery().split("&");
+                for (String pair : pairs) {
+                    int idx = pair.indexOf("=");
+                    if (idx > 0) {
+                        String key = pair.substring(0, idx);
+                        String value = pair.substring(idx + 1);
+                        queryParams.put(key, value);
+                    }
+                }
+            }
+
+            // Add/override with custom query parameters from request options
+            if (additionalQueryParams != null && !additionalQueryParams.isEmpty()) {
+                queryParams.putAll(additionalQueryParams);
+            }
+
+            // Ensure api-version is set (SDK's version takes precedence)
+            queryParams.put("api-version", apiVersion);
+
+            // Add model if provided (function parameter takes precedence)
+            if (model != null && !model.isEmpty()) {
+                queryParams.put("model", model);
+            }
+
+            // Build final query string
+            StringBuilder queryBuilder = new StringBuilder();
+            for (Map.Entry<String, String> entry : queryParams.entrySet()) {
                 if (queryBuilder.length() > 0) {
                     queryBuilder.append("&");
                 }
-                queryBuilder.append("model=").append(model);
+                queryBuilder.append(entry.getKey()).append("=").append(entry.getValue());
             }
 
             return new URI(scheme, httpEndpoint.getUserInfo(), httpEndpoint.getHost(), httpEndpoint.getPort(), path,
diff --git a/sdk/ai/azure-ai-voicelive/src/main/java/com/azure/ai/voicelive/VoiceLiveClientBuilder.java b/sdk/ai/azure-ai-voicelive/src/main/java/com/azure/ai/voicelive/VoiceLiveClientBuilder.java
index c89987016753..e655bde45d91 100644
--- a/sdk/ai/azure-ai-voicelive/src/main/java/com/azure/ai/voicelive/VoiceLiveClientBuilder.java
+++ b/sdk/ai/azure-ai-voicelive/src/main/java/com/azure/ai/voicelive/VoiceLiveClientBuilder.java
@@ -3,6 +3,10 @@
 
 package com.azure.ai.voicelive;
 
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Objects;
+
 import com.azure.core.annotation.ServiceClientBuilder;
 import com.azure.core.client.traits.EndpointTrait;
 import com.azure.core.client.traits.KeyCredentialTrait;
@@ -14,10 +18,6 @@
 import com.azure.core.util.CoreUtils;
 import com.azure.core.util.logging.ClientLogger;
 
-import java.net.URI;
-import java.net.URISyntaxException;
-import java.util.Objects;
-
 /**
  * Builder for creating instances of {@link VoiceLiveAsyncClient}.
  */
diff --git a/sdk/ai/azure-ai-voicelive/src/main/java/com/azure/ai/voicelive/models/VoiceLiveRequestOptions.java b/sdk/ai/azure-ai-voicelive/src/main/java/com/azure/ai/voicelive/models/VoiceLiveRequestOptions.java
new file mode 100644
index 000000000000..aa43c315425d
--- /dev/null
+++ b/sdk/ai/azure-ai-voicelive/src/main/java/com/azure/ai/voicelive/models/VoiceLiveRequestOptions.java
@@ -0,0 +1,94 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.ai.voicelive.models;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import com.azure.core.annotation.Fluent;
+import com.azure.core.http.HttpHeaders;
+
+/**
+ * Options for customizing VoiceLive requests with additional query parameters and headers.
+ */
+@Fluent
+public final class VoiceLiveRequestOptions {
+
+    private final Map<String, String> customQueryParameters;
+    private HttpHeaders customHeaders;
+
+    /**
+     * Creates a new instance of VoiceLiveRequestOptions.
+     */
+    public VoiceLiveRequestOptions() {
+        this.customQueryParameters = new LinkedHashMap<>();
+        this.customHeaders = new HttpHeaders();
+    }
+
+    /**
+     * Gets the custom query parameters.
+     *
+     * @return The custom query parameters.
+     */
+    public Map<String, String> getCustomQueryParameters() {
+        return customQueryParameters;
+    }
+
+    /**
+     * Adds a custom query parameter.
+     *
+     * @param key The query parameter key.
+     * @param value The query parameter value; {@code null} removes any existing parameter with this key.
+     * @return The updated VoiceLiveRequestOptions object.
+     * @throws NullPointerException if {@code key} is null.
+     */
+    public VoiceLiveRequestOptions addCustomQueryParameter(String key, String value) {
+        Objects.requireNonNull(key, "'key' cannot be null");
+        if (value == null) {
+            // A stored null would be appended to the query string as the literal text "null".
+            this.customQueryParameters.remove(key);
+        } else {
+            this.customQueryParameters.put(key, value);
+        }
+        return this;
+    }
+
+    /**
+     * Gets the custom headers.
+     *
+     * @return The custom headers.
+     */
+    public HttpHeaders getCustomHeaders() {
+        return customHeaders;
+    }
+
+    /**
+     * Sets the custom headers.
+     *
+     * @param customHeaders The custom headers to set.
+     * @return The updated VoiceLiveRequestOptions object.
+     */
+    public VoiceLiveRequestOptions setCustomHeaders(HttpHeaders customHeaders) {
+        this.customHeaders = customHeaders != null ? customHeaders : new HttpHeaders();
+        return this;
+    }
+
+    /**
+     * Adds a custom header.
+     *
+     * @param name The header name.
+     * @param value The header value.
+     * @return The updated VoiceLiveRequestOptions object.
+     * @throws NullPointerException if {@code name} is null.
+     */
+    public VoiceLiveRequestOptions addCustomHeader(String name, String value) {
+        Objects.requireNonNull(name, "'name' cannot be null");
+        if (this.customHeaders == null) {
+            this.customHeaders = new HttpHeaders();
+        }
+        this.customHeaders.set(name, value);
+        return this;
+    }
+}
diff --git a/sdk/ai/azure-ai-voicelive/src/test/java/com/azure/ai/voicelive/VoiceLiveAsyncClientTest.java b/sdk/ai/azure-ai-voicelive/src/test/java/com/azure/ai/voicelive/VoiceLiveAsyncClientTest.java
index f71306338a6a..7562512c2592 100644
--- a/sdk/ai/azure-ai-voicelive/src/test/java/com/azure/ai/voicelive/VoiceLiveAsyncClientTest.java
+++ b/sdk/ai/azure-ai-voicelive/src/test/java/com/azure/ai/voicelive/VoiceLiveAsyncClientTest.java
@@ -167,4 +167,14 @@ void testReturnTypeOptimization() {
             // The returned Mono should contain a VoiceLiveSessionAsyncClient when subscribed
         });
     }
+
+    @Test
+    void testStartSessionWithoutModel() {
+        // Test that startSession() without parameters works
+        assertDoesNotThrow(() -> {
+            Mono<VoiceLiveSessionAsyncClient> sessionMono = client.startSession();
+            assertNotNull(sessionMono);
+        });
+    }
+
 }
diff --git a/sdk/ai/azure-ai-voicelive/src/test/java/com/azure/ai/voicelive/VoiceLiveRequestOptionsTest.java b/sdk/ai/azure-ai-voicelive/src/test/java/com/azure/ai/voicelive/VoiceLiveRequestOptionsTest.java
new file mode 100644
index 000000000000..9b87d3b6e2e6
--- /dev/null
+++ b/sdk/ai/azure-ai-voicelive/src/test/java/com/azure/ai/voicelive/VoiceLiveRequestOptionsTest.java
@@ -0,0 +1,197 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.ai.voicelive;
+
+import com.azure.ai.voicelive.models.VoiceLiveRequestOptions;
+import com.azure.core.credential.KeyCredential;
+import com.azure.core.http.HttpHeaders;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import java.lang.reflect.Method;
+import java.net.URI;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Unit tests for VoiceLiveRequestOptions feature in VoiceLive client.
+ */
+@ExtendWith(MockitoExtension.class)
+class VoiceLiveRequestOptionsTest {
+
+    @Mock
+    private KeyCredential mockKeyCredential;
+
+    @Mock
+    private HttpHeaders mockHeaders;
+
+    private URI testEndpoint;
+    private String apiVersion;
+
+    @BeforeEach
+    void setUp() throws Exception {
+        testEndpoint = new URI("https://test.cognitiveservices.azure.com");
+        apiVersion = "2024-10-01-preview";
+    }
+
+    @Test
+    void testRequestOptionsWithCustomQueryParameters() {
+        // Arrange & Act
+        VoiceLiveRequestOptions options
+            = new VoiceLiveRequestOptions().addCustomQueryParameter("deployment-id", "test-deployment")
+                .addCustomQueryParameter("region", "eastus");
+
+        // Assert
+        assertNotNull(options.getCustomQueryParameters());
+        assertEquals(2, options.getCustomQueryParameters().size());
+        assertEquals("test-deployment", options.getCustomQueryParameters().get("deployment-id"));
+        assertEquals("eastus", options.getCustomQueryParameters().get("region"));
+    }
+
+    @Test
+    void testAddCustomQueryParameterWithNullValueRemovesParameter() {
+        // Arrange & Act
+        VoiceLiveRequestOptions options
+            = new VoiceLiveRequestOptions().addCustomQueryParameter("region", "eastus")
+                .addCustomQueryParameter("region", null);
+
+        // Assert
+        assertTrue(options.getCustomQueryParameters().isEmpty());
+    }
+
+    @Test
+    void testRequestOptionsWithCustomHeaders() {
+        // Arrange & Act
+        VoiceLiveRequestOptions options
+            = new VoiceLiveRequestOptions().addCustomHeader("X-Custom-Header", "custom-value")
+                .addCustomHeader("X-Another-Header", "another-value");
+
+        // Assert
+        assertNotNull(options.getCustomHeaders());
+        assertEquals("custom-value", options.getCustomHeaders().getValue("X-Custom-Header"));
+        assertEquals("another-value", options.getCustomHeaders().getValue("X-Another-Header"));
+    }
+
+    @Test
+    void testRequestOptionsFluentApi() {
+        // Arrange & Act
+        VoiceLiveRequestOptions options = new VoiceLiveRequestOptions();
+
+        VoiceLiveRequestOptions result1 = options.addCustomQueryParameter("key1", "value1");
+        VoiceLiveRequestOptions result2 = options.addCustomHeader("Header1", "value1");
+
+        // Assert
+        assertSame(options, result1);
+        assertSame(options, result2);
+    }
+
+    @Test
+    void testRequestOptionsSetCustomHeaders() {
+        // Arrange
+        HttpHeaders headers = new HttpHeaders();
+        headers.set("X-Custom-Header", "value1");
+        headers.set("X-Another-Header", "value2");
+
+        // Act
+        VoiceLiveRequestOptions options = new VoiceLiveRequestOptions().setCustomHeaders(headers);
+
+        // Assert
+        assertNotNull(options.getCustomHeaders());
+        assertEquals("value1", options.getCustomHeaders().getValue("X-Custom-Header"));
+        assertEquals("value2", options.getCustomHeaders().getValue("X-Another-Header"));
+    }
+
+    @Test
+    void testConvertToWebSocketEndpointWithRequestOptions() throws Exception {
+        // Arrange
+        VoiceLiveRequestOptions requestOptions
+            = new VoiceLiveRequestOptions().addCustomQueryParameter("deployment-id", "test-deployment")
+                .addCustomQueryParameter("custom-param", "custom-value");
+
+        VoiceLiveAsyncClient client
+            = new VoiceLiveAsyncClient(testEndpoint, mockKeyCredential, apiVersion, mockHeaders);
+
+        // Use reflection to access the private method
+        Method method = VoiceLiveAsyncClient.class.getDeclaredMethod("convertToWebSocketEndpoint", URI.class,
+            String.class, Map.class);
+        method.setAccessible(true);
+
+        // Act
+        URI result = (URI) method.invoke(client, testEndpoint, "gpt-4o-realtime-preview",
+            requestOptions.getCustomQueryParameters());
+
+        // Assert
+        assertNotNull(result);
+        assertEquals("wss", result.getScheme());
+        assertNotNull(result.getQuery());
+        assertTrue(result.getQuery().contains("api-version=" + apiVersion));
+        assertTrue(result.getQuery().contains("model=gpt-4o-realtime-preview"));
+        assertTrue(result.getQuery().contains("deployment-id=test-deployment"));
+        assertTrue(result.getQuery().contains("custom-param=custom-value"));
+    }
+
+    @Test
+    void testConvertToWebSocketEndpointWithRequestOptionsWithoutModel() throws Exception {
+        // Arrange
+        VoiceLiveRequestOptions requestOptions
+            = new VoiceLiveRequestOptions().addCustomQueryParameter("deployment-id", "test-deployment")
+                .addCustomQueryParameter("model", "custom-model-from-params");
+
+        VoiceLiveAsyncClient client
+            = new VoiceLiveAsyncClient(testEndpoint, mockKeyCredential, apiVersion, mockHeaders);
+
+        // Use reflection to access the private method
+        Method method = VoiceLiveAsyncClient.class.getDeclaredMethod("convertToWebSocketEndpoint", URI.class,
+            String.class, Map.class);
+        method.setAccessible(true);
+
+        // Act - passing null as model parameter
+        URI result = (URI) method.invoke(client, testEndpoint, null, requestOptions.getCustomQueryParameters());
+
+        // Assert
+        assertNotNull(result);
+        assertEquals("wss", result.getScheme());
+        assertNotNull(result.getQuery());
+        assertTrue(result.getQuery().contains("api-version=" + apiVersion));
+        assertTrue(result.getQuery().contains("model=custom-model-from-params"));
+        assertTrue(result.getQuery().contains("deployment-id=test-deployment"));
+    }
+
+    @Test
+    void testRequestOptionsQueryParameterPrecedence() throws Exception {
+        // Arrange - endpoint has api-version, requestOptions overrides it (but SDK should win)
+        URI endpointWithQuery = new URI("https://test.cognitiveservices.azure.com?api-version=old-version");
+
+        VoiceLiveRequestOptions requestOptions
+            = new VoiceLiveRequestOptions().addCustomQueryParameter("deployment-id", "test-deployment")
+                .addCustomQueryParameter("api-version", "options-version");
+
+        VoiceLiveAsyncClient client
+            = new VoiceLiveAsyncClient(endpointWithQuery, mockKeyCredential, apiVersion, mockHeaders);
+
+        // Use reflection to access the private method
+        Method method = VoiceLiveAsyncClient.class.getDeclaredMethod("convertToWebSocketEndpoint", URI.class,
+            String.class, Map.class);
+        method.setAccessible(true);
+
+        // Act
+        URI result = (URI) method.invoke(client, endpointWithQuery, "gpt-4o-realtime-preview",
+            requestOptions.getCustomQueryParameters());
+
+        // Assert - SDK's api-version should take precedence
+        assertNotNull(result);
+        assertTrue(result.getQuery().contains("api-version=" + apiVersion));
+        assertFalse(result.getQuery().contains("old-version"));
+        assertFalse(result.getQuery().contains("options-version"));
+        assertTrue(result.getQuery().contains("deployment-id=test-deployment"));
+    }
+}