diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/vertex_model.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/vertex_model.dart index 5bf5812a57e5..5b2b21517bef 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/vertex_model.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/vertex_model.dart @@ -59,8 +59,7 @@ final class GenerativeModel { List? tools, Content? systemInstruction, ToolConfig? toolConfig, - }) : _firebaseApp = app, - _googleAIModel = createModelWithBaseUri( + }) : _googleAIModel = createModelWithBaseUri( model: _normalizeModelName(model), apiKey: app.options.apiKey, baseUri: _vertexUri(app, location), @@ -75,7 +74,6 @@ final class GenerativeModel { : [], toolConfig: toolConfig?.toGoogleAI(), ); - final FirebaseApp _firebaseApp; final google_ai.GenerativeModel _googleAIModel; static const _modelsPrefix = 'models/'; @@ -92,15 +90,6 @@ final class GenerativeModel { ); } - static google_ai.GenerationConfig _convertGenerationConfig( - GenerationConfig? config, FirebaseApp app) { - if (config == null) { - return google_ai.GenerationConfig(); - } else { - return config.toGoogleAI(); - } - } - static FutureOr> Function() _firebaseTokens( FirebaseAppCheck? appCheck, FirebaseAuth? auth) { return () async { @@ -135,7 +124,9 @@ final class GenerativeModel { /// ``` Future generateContent(Iterable prompt, {List? safetySettings, - GenerationConfig? generationConfig}) async { + GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig}) async { Iterable googlePrompt = prompt.map((content) => content.toGoogleAI()); List googleSafetySettings = safetySettings != null @@ -143,8 +134,11 @@ final class GenerativeModel { : []; final response = await _googleAIModel.generateContent(googlePrompt, safetySettings: googleSafetySettings, - generationConfig: - _convertGenerationConfig(generationConfig, _firebaseApp)); + generationConfig: generationConfig?.toGoogleAI(), + tools: tools != null + ? tools.map((tool) => tool.toGoogleAI()).toList() + : [], + toolConfig: toolConfig?.toGoogleAI()); return response.toVertex(); } @@ -163,13 +157,19 @@ final class GenerativeModel { Stream generateContentStream( Iterable prompt, {List? safetySettings, - GenerationConfig? generationConfig}) { + GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig}) { return _googleAIModel .generateContentStream(prompt.map((content) => content.toGoogleAI()), safetySettings: safetySettings != null ? safetySettings.map((setting) => setting.toGoogleAI()).toList() : [], - generationConfig: generationConfig?.toGoogleAI()) + generationConfig: generationConfig?.toGoogleAI(), + tools: tools != null + ? tools.map((tool) => tool.toGoogleAI()).toList() + : [], + toolConfig: toolConfig?.toGoogleAI()) .map((r) => r.toVertex()); } @@ -190,9 +190,23 @@ final class GenerativeModel { /// print(response.text); /// } /// ``` - Future countTokens(Iterable contents) async { + Future countTokens( + Iterable contents, { + List? safetySettings, + GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig, + }) async { return _googleAIModel - .countTokens(contents.map((e) => e.toGoogleAI())) + .countTokens(contents.map((e) => e.toGoogleAI()), + safetySettings: safetySettings != null + ? safetySettings.map((setting) => setting.toGoogleAI()).toList() + : [], + generationConfig: generationConfig?.toGoogleAI(), + tools: tools != null + ? tools.map((tool) => tool.toGoogleAI()).toList() + : [], + toolConfig: toolConfig?.toGoogleAI()) .then((r) => r.toVertex()); } @@ -207,10 +221,12 @@ final class GenerativeModel { /// (await model.embedContent([Content.text(prompt)])).embedding.values; /// ``` Future embedContent(Content content, - {TaskType? taskType, String? title}) async { + {TaskType? taskType, String? title, int? outputDimensionality}) async { return _googleAIModel .embedContent(content.toGoogleAI(), - taskType: taskType?.toGoogleAI(), title: title) + taskType: taskType?.toGoogleAI(), + title: title, + outputDimensionality: outputDimensionality) .then((r) => r.toVertex()); } diff --git a/packages/firebase_vertexai/firebase_vertexai/pubspec.yaml b/packages/firebase_vertexai/firebase_vertexai/pubspec.yaml index 34c427d40576..0859b7b5f07e 100644 --- a/packages/firebase_vertexai/firebase_vertexai/pubspec.yaml +++ b/packages/firebase_vertexai/firebase_vertexai/pubspec.yaml @@ -14,7 +14,7 @@ dependencies: firebase_core_platform_interface: ^5.0.0 flutter: sdk: flutter - google_generative_ai: ^0.4.0 + google_generative_ai: ^0.4.1 dev_dependencies: flutter_lints: ^3.0.0