Skip to content

Commit

Permalink
feat: GA cache python SDK
Browse files Browse the repository at this point in the history
FUTURE_COPYBARA_INTEGRATE_REVIEW=#4861 from googleapis:release-please--branches--main 039f2cb
PiperOrigin-RevId: 718130866
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Jan 22, 2025
1 parent c2e7ce4 commit 6b6acdc
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 48 deletions.
1 change: 1 addition & 0 deletions google/cloud/aiplatform/compat/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
annotation_spec as annotation_spec_v1,
artifact as artifact_v1,
batch_prediction_job as batch_prediction_job_v1,
cached_content as cached_content_v1,
completion_stats as completion_stats_v1,
context as context_v1,
custom_job as custom_job_v1,
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/vertexai/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import json
import mock
import pytest
from vertexai.preview import caching
from vertexai.caching import _caching
from google.cloud.aiplatform import initializer
import vertexai
from google.cloud.aiplatform_v1beta1.types.cached_content import (
Expand Down Expand Up @@ -141,7 +141,7 @@ def list_cached_contents(self, request):

@pytest.mark.usefixtures("google_auth_mock")
class TestCaching:
"""Unit tests for caching.CachedContent."""
"""Unit tests for _caching.CachedContent."""

def setup_method(self):
vertexai.init(
Expand All @@ -156,7 +156,7 @@ def test_constructor_with_full_resource_name(self, mock_get_cached_content):
full_resource_name = (
"projects/123/locations/europe-west1/cachedContents/contents-id"
)
cache = caching.CachedContent(
cache = _caching.CachedContent(
cached_content_name=full_resource_name,
)

Expand All @@ -166,7 +166,7 @@ def test_constructor_with_full_resource_name(self, mock_get_cached_content):
def test_constructor_with_only_content_id(self, mock_get_cached_content):
partial_resource_name = "contents-id"

cache = caching.CachedContent(
cache = _caching.CachedContent(
cached_content_name=partial_resource_name,
)

Expand All @@ -179,7 +179,7 @@ def test_constructor_with_only_content_id(self, mock_get_cached_content):
def test_get_with_content_id(self, mock_get_cached_content):
partial_resource_name = "contents-id"

cache = caching.CachedContent.get(
cache = _caching.CachedContent.get(
cached_content_name=partial_resource_name,
)

Expand All @@ -192,7 +192,7 @@ def test_get_with_content_id(self, mock_get_cached_content):
def test_create_with_real_payload(
self, mock_create_cached_content, mock_get_cached_content
):
cache = caching.CachedContent.create(
cache = _caching.CachedContent.create(
model_name="model-name",
system_instruction=GapicContent(
role="system", parts=[GapicPart(text="system instruction")]
Expand All @@ -219,7 +219,7 @@ def test_create_with_real_payload(
def test_create_with_real_payload_and_wrapped_type(
self, mock_create_cached_content, mock_get_cached_content
):
cache = caching.CachedContent.create(
cache = _caching.CachedContent.create(
model_name="model-name",
system_instruction="Please answer my questions with cool",
tools=[],
Expand All @@ -239,15 +239,15 @@ def test_create_with_real_payload_and_wrapped_type(
assert cache.display_name == _TEST_DISPLAY_NAME

def test_list(self, mock_list_cached_contents):
cached_contents = caching.CachedContent.list()
cached_contents = _caching.CachedContent.list()
for i, cached_content in enumerate(cached_contents):
assert cached_content.name == f"cached_content{i + 1}_from_list_request"
assert cached_content.model_name == f"model-name{i + 1}"

def test_print_a_cached_content(
self, mock_create_cached_content, mock_get_cached_content
):
cached_content = caching.CachedContent.create(
cached_content = _caching.CachedContent.create(
model_name="model-name",
system_instruction="Please answer my questions with cool",
tools=[],
Expand Down
25 changes: 25 additions & 0 deletions vertexai/caching/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Public API for content caching: classes for creating, retrieving,
listing, and managing ``CachedContent`` resources."""

# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.caching._caching import (
    CachedContent,
)

__all__ = [
    "CachedContent",
]
78 changes: 39 additions & 39 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,45 @@ def start_chat(
response_validation=response_validation,
)

@classmethod
def from_cached_content(
cls,
cached_content: Union[str, "caching.CachedContent"],
*,
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
) -> "_GenerativeModel":
"""Creates a model from cached content.
Creates a model instance with an existing cached content. The cached
content becomes the prefix of the requesting contents.
Args:
cached_content: The cached content resource name or object.
generation_config: The generation config to use for this model.
safety_settings: The safety settings to use for this model.
Returns:
A model instance with the cached content wtih cached content as
prefix of all its requests.
"""
if isinstance(cached_content, str):
from vertexai.caching import _caching

cached_content = _caching.CachedContent.get(cached_content)
model_name = cached_content.model_name
model = cls(
model_name=model_name,
generation_config=generation_config,
safety_settings=safety_settings,
tools=None,
tool_config=None,
system_instruction=None,
)
model._cached_content = cached_content

return model


_SUCCESSFUL_FINISH_REASONS = [
gapic_content_types.Candidate.FinishReason.STOP,
Expand Down Expand Up @@ -3515,42 +3554,3 @@ def start_chat(
response_validation=response_validation,
responder=responder,
)

@classmethod
def from_cached_content(
cls,
cached_content: Union[str, "caching.CachedContent"],
*,
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
) -> "_GenerativeModel":
"""Creates a model from cached content.
Creates a model instance with an existing cached content. The cached
content becomes the prefix of the requesting contents.
Args:
cached_content: The cached content resource name or object.
generation_config: The generation config to use for this model.
safety_settings: The safety settings to use for this model.
Returns:
A model instance with the cached content wtih cached content as
prefix of all its requests.
"""
if isinstance(cached_content, str):
from vertexai.preview import caching

cached_content = caching.CachedContent.get(cached_content)
model_name = cached_content.model_name
model = cls(
model_name=model_name,
generation_config=generation_config,
safety_settings=safety_settings,
tools=None,
tool_config=None,
system_instruction=None,
)
model._cached_content = cached_content

return model

0 comments on commit 6b6acdc

Please sign in to comment.