Skip to content

Explicit Caching #355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f13228d
*Initial prototype for explicit caching
mayureshagashe2105 Apr 26, 2024
a4ac7a5
Merge branch 'main' into caching
mayureshagashe2105 May 21, 2024
6fafe6b
rename get_cached_content to get
mayureshagashe2105 May 22, 2024
afd066d
Merge branch 'main' into caching
mayureshagashe2105 May 22, 2024
cfc936e
Strike out functional approach for CachedContent CRUD ops
mayureshagashe2105 May 23, 2024
e65d16e
blacken
mayureshagashe2105 May 23, 2024
d35cc71
Improve tests
mayureshagashe2105 May 23, 2024
d862dae
fix tests
mayureshagashe2105 May 23, 2024
2cde1a2
fix tests
mayureshagashe2105 May 23, 2024
e1d8c7a
Validate name checks for CachedContent creation
mayureshagashe2105 May 24, 2024
59663c8
Add tests
mayureshagashe2105 May 24, 2024
f37df8c
mark name as OPTIONAL for CachedContent creation
mayureshagashe2105 May 26, 2024
d1fd749
Add type-annotations to __new__ to fix pytype checks
mayureshagashe2105 May 27, 2024
17372e3
Add 'cached_content' to GenerativeModel's repr
mayureshagashe2105 May 27, 2024
645ceab
blacken
mayureshagashe2105 May 27, 2024
f48cedc
Fix types
mayureshagashe2105 May 27, 2024
a1c8c72
Fix docstrings
mayureshagashe2105 May 27, 2024
67472d3
Fix types
mayureshagashe2105 May 27, 2024
bf6551a
Fix types
mayureshagashe2105 May 27, 2024
82d3c5a
Merge branch 'main' of https://github.com/mayureshagashe2105/generati…
mayureshagashe2105 May 30, 2024
8e86ef1
Refactor for genai.protos module
mayureshagashe2105 May 30, 2024
4627fe1
use preview build
MarkDaoust Jun 4, 2024
fb9995c
Merge branch 'main' into caching
MarkDaoust Jun 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions google/generativeai/caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Optional, Iterable

import google.ai.generativelanguage as glm

from google.generativeai.types import caching_types
from google.generativeai.types import content_types
from google.generativeai.client import get_default_cache_client


# alias for `caching_types.CachedContent`.
CachedContent = caching_types.CachedContent


def get_cached_content(name: str, client: glm.CacheServiceClient | None = None) -> CachedContent:
    """Fetches the requested `CachedContent` resource.

    Args:
        name: The resource name referring to the cached content.
            Format: `cachedContents/{id}`.
        client: The client to use. Defaults to the shared cache-service
            client when `None` (resolution is delegated to
            `CachedContent.get_cached_content`).

    Returns:
        `CachedContent` resource with the specified name.
    """
    return CachedContent.get_cached_content(name=name, client=client)


def delete_cached_content(name: str, client: glm.CacheServiceClient | None = None) -> None:
    """Deletes a `CachedContent` resource.

    Args:
        name: The resource name referring to the cached content.
            Format: `cachedContents/{id}`. A bare `{id}` is also accepted
            and is expanded to the full resource name.
        client: The client to use. Defaults to the shared cache-service
            client when `None`.
    """
    if client is None:
        client = get_default_cache_client()

    # Expand a bare id to the full resource name. A prefix check (rather
    # than a substring check) avoids corrupting names that merely *contain*
    # "cachedContents/" somewhere after the start.
    if not name.startswith("cachedContents/"):
        name = "cachedContents/" + name

    request = glm.DeleteCachedContentRequest(name=name)
    client.delete_cached_content(request)


def list_cached_contents(
    page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None
) -> Iterable[CachedContent]:
    """Lists `CachedContent` objects associated with the project.

    Args:
        page_size: The maximum number of `CachedContent` objects to return
            per page. The service may return fewer.
        client: The client to use. Defaults to the shared cache-service
            client when `None`.

    Yields:
        `CachedContent` objects, decoded from the paginated service
        response one at a time.
    """
    if client is None:
        client = get_default_cache_client()

    request = glm.ListCachedContentsRequest(page_size=page_size)
    for cached_content in client.list_cached_contents(request):
        yield caching_types.decode_cached_content(cached_content)
4 changes: 4 additions & 0 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ def configure(
_client_manager.configure()


def get_default_cache_client() -> glm.CacheServiceClient:
    """Returns the shared `CacheServiceClient` from the module's client manager."""
    cache_client = _client_manager.get_default_client("cache")
    return cache_client


def get_default_discuss_client() -> glm.DiscussServiceClient:
    """Returns the shared `DiscussServiceClient` from the module's client manager."""
    discuss_client = _client_manager.get_default_client("discuss")
    return discuss_client

Expand Down
72 changes: 70 additions & 2 deletions google/generativeai/generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from collections.abc import Iterable
import dataclasses
import textwrap
from typing import Any
from typing import Union
from typing import Any, Union, overload
import reprlib

# pylint: disable=bad-continuation, line-too-long
Expand All @@ -15,10 +14,13 @@
import google.api_core.exceptions
from google.ai import generativelanguage as glm
from google.generativeai import client

from google.generativeai import caching
from google.generativeai.types import content_types
from google.generativeai.types import generation_types
from google.generativeai.types import helper_types
from google.generativeai.types import safety_types
from google.generativeai.types import caching_types


class GenerativeModel:
Expand Down Expand Up @@ -96,6 +98,15 @@ def __init__(
self._client = None
self._async_client = None

def __new__(cls, *args, **kwargs):
    """Intercepts construction to capture an optional `cached_content` kwarg.

    `from_cached_content` passes `cached_content` here (rather than through
    `__init__`) so it is never part of the public constructor signature.
    The kwarg is popped so a subsequent `__init__` call never sees it.
    """
    self = super().__new__(cls)

    if cached_instance := kwargs.pop("cached_content", None):
        # Store only the resource *name* on the instance, not the object.
        setattr(self, "_cached_content", cached_instance.name)
        # NOTE(review): this installs the property on the *class*, so once any
        # cached model is created, every instance of `cls` exposes a
        # `cached_content` attribute. On instances lacking `_cached_content`
        # the property raises AttributeError, which in turn makes
        # `hasattr(model, "cached_content")` evaluate False for them —
        # presumably the behavior other code relies on; confirm before changing.
        setattr(cls, "cached_content", property(fget=lambda self: self._cached_content))

    return self

@property
def model_name(self):
return self._model_name
Expand Down Expand Up @@ -129,6 +140,13 @@ def _prepare_request(
tool_config: content_types.ToolConfigType | None,
) -> glm.GenerateContentRequest:
"""Creates a `glm.GenerateContentRequest` from raw inputs."""
if hasattr(self, "cached_content") and any([self._system_instruction, tools, tool_config]):
raise ValueError(
"`tools`, `tool_config`, `system_instruction` cannot be set on a model instantinated with `cached_content` as its context."
)

cached_content = getattr(self, "cached_content", None)

tools_lib = self._get_tools_lib(tools)
if tools_lib is not None:
tools_lib = tools_lib.to_proto()
Expand Down Expand Up @@ -157,6 +175,7 @@ def _prepare_request(
tools=tools_lib,
tool_config=tool_config,
system_instruction=self._system_instruction,
cached_content=cached_content,
)

def _get_tools_lib(
Expand All @@ -167,6 +186,55 @@ def _get_tools_lib(
else:
return content_types.to_function_library(tools)

@overload
@classmethod
def from_cached_content(
    cls,
    cached_content: str,
    generation_config: generation_types.GenerationConfigType | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
) -> GenerativeModel: ...

@overload
@classmethod
def from_cached_content(
    cls,
    cached_content: caching_types.CachedContent,
    generation_config: generation_types.GenerationConfigType | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
) -> GenerativeModel: ...

@classmethod
def from_cached_content(
    cls,
    cached_content: str | caching_types.CachedContent,
    generation_config: generation_types.GenerationConfigType | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
) -> GenerativeModel:
    """Creates a model with `cached_content` as the model's context.

    Args:
        cached_content: Context for the model — either a `CachedContent`
            resource or its name. A string name is resolved by fetching the
            resource from the service.
        generation_config: Generation config for the new model instance.
        safety_settings: Safety settings for the new model instance.

    Returns:
        `GenerativeModel` object with `cached_content` as its context.
    """
    if isinstance(cached_content, str):
        cached_content = caching.get_cached_content(name=cached_content)

    # Pass `cached_content` through `__new__` so the model's context is
    # attached without exposing `cached_content` as a public constructor
    # argument (`__new__` pops it before `__init__` runs).
    self = cls.__new__(cls, cached_content=cached_content)

    # Call `__init__` explicitly to set `generation_config` and
    # `safety_settings`. `model_name` is taken from the cached content:
    # the model it was created for.
    self.__init__(
        model_name=cached_content.model,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    return self

def generate_content(
self,
contents: content_types.ContentsType,
Expand Down
Loading
Loading