Skip to content

Commit b1f454e

Browse files
committed
Merge branch 'main' into OllamaRAGExample
2 parents b7eb334 + f8acb3a commit b1f454e

File tree

3 files changed

+55
-3
lines changed

3 files changed

+55
-3
lines changed

azureChat.m

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,11 @@
103103
function this = azureChat(systemPrompt, nvp)
104104
arguments
105105
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
106-
nvp.Endpoint {mustBeNonzeroLengthTextScalar}
107-
nvp.Deployment {mustBeNonzeroLengthTextScalar}
106+
nvp.Endpoint (1,1) string {mustBeNonzeroLengthTextScalar}
107+
nvp.Deployment (1,1) string {mustBeNonzeroLengthTextScalar}
108108
nvp.APIKey {mustBeNonzeroLengthTextScalar}
109109
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
110-
nvp.APIVersion (1,1) {mustBeAPIVersion} = "2024-02-01"
110+
nvp.APIVersion (1,1) string {mustBeAPIVersion} = "2024-02-01"
111111
nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
112112
nvp.TopP {llms.utils.mustBeValidTopP} = 1
113113
nvp.StopSequences {llms.utils.mustBeValidStop} = {}

tests/tazureChat.m

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
InvalidGenerateInput = iGetInvalidGenerateInput;
99
InvalidValuesSetters = iGetInvalidValuesSetters;
1010
StringInputs = struct('string',{"hi"},'char',{'hi'},'cellstr',{{'hi'}});
11+
APIVersions = iGetAPIVersions();
1112
end
1213

1314
methods(Test)
@@ -40,6 +41,14 @@ function doGenerate(testCase,StringInputs)
4041
testCase.verifyGreaterThan(strlength(response),0);
4142
end
4243

44+
function doGenerateUsingSystemPrompt(testCase)
45+
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
46+
chat = azureChat("You are a helpful assistant");
47+
response = testCase.verifyWarningFree(@() generate(chat,"Hi"));
48+
testCase.verifyClass(response,'string');
49+
testCase.verifyGreaterThan(strlength(response),0);
50+
end
51+
4352
function generateMultipleResponses(testCase)
4453
chat = azureChat;
4554
[~,~,response] = generate(chat,"What is a cat?",NumCompletions=3);
@@ -150,6 +159,38 @@ function keyNotFound(testCase)
150159
unsetenv("AZURE_OPENAI_API_KEY");
151160
testCase.verifyError(@()azureChat, "llms:keyMustBeSpecified");
152161
end
162+
163+
function canUseAPIVersions(testCase, APIVersions)
164+
% Test that we can use different APIVersion value to call
165+
% azureChat.generate
166+
167+
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
168+
chat = azureChat("APIVersion", APIVersions);
169+
170+
response = testCase.verifyWarningFree(@() generate(chat,"How similar is the DNA of a cat and a tiger?"));
171+
testCase.verifyClass(response,'string');
172+
testCase.verifyGreaterThan(strlength(response),0);
173+
end
174+
175+
function endpointNotFound(testCase)
176+
% to verify the error, we need to unset the environment variable
177+
% AZURE_OPENAI_ENDPOINT, if given. Use a fixture to restore the
178+
% value on leaving the test point
179+
import matlab.unittest.fixtures.EnvironmentVariableFixture
180+
testCase.applyFixture(EnvironmentVariableFixture("AZURE_OPENAI_ENDPOINT","dummy"));
181+
unsetenv("AZURE_OPENAI_ENDPOINT");
182+
testCase.verifyError(@()azureChat, "llms:endpointMustBeSpecified");
183+
end
184+
185+
function deploymentNotFound(testCase)
186+
% to verify the error, we need to unset the environment variable
187+
% AZURE_OPENAI_DEPLOYMENT, if given. Use a fixture to restore the
188+
% value on leaving the test point
189+
import matlab.unittest.fixtures.EnvironmentVariableFixture
190+
testCase.applyFixture(EnvironmentVariableFixture("AZURE_OPENAI_DEPLOYMENT","dummy"));
191+
unsetenv("AZURE_OPENAI_DEPLOYMENT");
192+
testCase.verifyError(@()azureChat, "llms:deploymentMustBeSpecified");
193+
end
153194
end
154195
end
155196

@@ -446,3 +487,7 @@ function keyNotFound(testCase)
446487
"Input",{{ validMessages "ToolChoice" ["validfunction", "validfunction"] }},...
447488
"Error","MATLAB:validators:mustBeTextScalar"));
448489
end
490+
491+
function apiVersions = iGetAPIVersions()
492+
apiVersions = cellstr(llms.azure.apiVersions);
493+
end

tests/tollamaChat.m

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ function doGenerate(testCase,StringInputs)
3939
testCase.verifyGreaterThan(strlength(response),0);
4040
end
4141

42+
function doGenerateUsingSystemPrompt(testCase)
43+
chat = ollamaChat("mistral","You are a helpful assistant");
44+
response = testCase.verifyWarningFree(@() generate(chat,"Hi"));
45+
testCase.verifyClass(response,'string');
46+
testCase.verifyGreaterThan(strlength(response),0);
47+
end
48+
4249
function extremeTopK(testCase)
4350
% setting top-k to k=1 leaves no random choice,
4451
% so we expect to get a fixed response.

0 commit comments

Comments
 (0)