Skip to content

Commit 23b81d7

Browse files
Explicit Caching patch (#377)
* Squashed commit of the following: commit acb3806 Author: Mayuresh Agashe <magashe@google.com> Date: Wed Jun 5 00:51:30 2024 +0000 fix update method Change-Id: I433c25b2d80cdf6e483b59f61ff29bb8d2dc6595 commit fb9995c Merge: 4627fe1 7b9758f Author: Mark Daoust <markdaoust@google.com> Date: Tue Jun 4 09:55:38 2024 -0700 Merge branch 'main' into caching Change-Id: I2bade6b0099f12dd37a24fe26cfda1981c58fbc0 commit 4627fe1 Author: Mark Daoust <markdaoust@google.com> Date: Tue Jun 4 09:54:31 2024 -0700 use preview build Change-Id: Ic1cd4fc28f591794dc5fbff0647a00a77ea7f601 commit 8e86ef1 Author: Mayuresh Agashe <magashe@google.com> Date: Thu May 30 16:18:22 2024 +0000 Refactor for genai.protos module Change-Id: I2f02d2421d7303f0309ec86f05d33c07332c03c1 commit 82d3c5a Merge: bf6551a f08c789 Author: Mayuresh Agashe <magashe@google.com> Date: Thu May 30 15:57:27 2024 +0000 Merge branch 'main' of https://github.com/mayureshagashe2105/generative-ai-python into caching Change-Id: Id2b259fe4b2c91653bf5e4d5e883f556366d8676 commit bf6551a Author: Mayuresh Agashe <magashe@google.com> Date: Mon May 27 11:26:03 2024 +0000 Fix types Change-Id: Id3e7316562f4029e5b7409ae725bb66e2207f075 commit 67472d3 Author: Mayuresh Agashe <magashe@google.com> Date: Mon May 27 11:26:03 2024 +0000 Fix types Change-Id: Id3e7316562f4029e5b7409ae725bb66e2207f075 commit a1c8c72 Author: Mayuresh Agashe <magashe@google.com> Date: Mon May 27 11:15:15 2024 +0000 Fix docstrings Change-Id: I6020df4e862a4f1d58462a4cd70876a8448293cf commit f48cedc Author: Mayuresh Agashe <magashe@google.com> Date: Mon May 27 11:13:44 2024 +0000 Fix types Change-Id: Ia4bf6b936fab4c1992798c65cff91c15e51a92c0 commit 645ceab Author: Mayuresh Agashe <magashe@google.com> Date: Mon May 27 05:54:26 2024 +0000 blacken Change-Id: I4e073d821d29eea30801bdb7e2a8dc01bb7d6b9a commit 17372e3 Author: Mayuresh Agashe <magashe@google.com> Date: Mon May 27 05:54:06 2024 +0000 Add 'cached_content' to GenerativeModel's repr Change-Id: I06676fad23895e3e1a6393baa938fc1f2df57d80 commit d1fd749 Author: Mayuresh Agashe <magashe@google.com> Date: Mon May 27 05:04:43 2024 +0000 Add type-annotations to __new__ to fix pytype checks Change-Id: I6c69c036e54d56d18ea60368fa0a1dcda2d315fd commit f37df8c Author: Mayuresh Agashe <magashe@google.com> Date: Sun May 26 06:51:54 2024 +0000 mark name as OPTIONAL for CachedContent creation If not provided, the name will be randomly generated Change-Id: Ib95fbafd3dfe098b43164d7ee4d6c2a84b0aae2e commit 59663c8 Author: Mayuresh Agashe <magashe@google.com> Date: Fri May 24 10:22:08 2024 +0000 Add tests Change-Id: I249188fa585bd9b7193efa48b1cfca20b8a79821 commit e1d8c7a Author: Mayuresh Agashe <magashe@google.com> Date: Fri May 24 10:21:42 2024 +0000 Validate name checks for CachedContent creation Change-Id: Ie41602621d99ddff6404c6708c7278e0da790652 commit 2cde1a2 Author: Mayuresh Agashe <magashe@google.com> Date: Thu May 23 18:09:14 2024 +0000 fix tests Change-Id: I39f61012f850a82e09a7afb80b527a0f99ad0ec7 commit d862dae Author: Mayuresh Agashe <magashe@google.com> Date: Thu May 23 18:09:14 2024 +0000 fix tests Change-Id: I39f61012f850a82e09a7afb80b527a0f99ad0ec7 commit d35cc71 Author: Mayuresh Agashe <magashe@google.com> Date: Thu May 23 23:12:38 2024 +0530 Improve tests commit e65d16e Author: Mayuresh Agashe <magashe@google.com> Date: Thu May 23 23:12:05 2024 +0530 blacken commit cfc936e Author: Mayuresh Agashe <magashe@google.com> Date: Thu May 23 23:10:16 2024 +0530 Stroke out functional approach for CachedContent CURD ops commit afd066d Merge: 6fafe6b 0dca4ce Author: Mayuresh Agashe <magashe@google.com> Date: Wed May 22 23:10:20 2024 +0530 Merge branch 'main' into caching commit 6fafe6b Author: Mayuresh Agashe <magashe@google.com> Date: Wed May 22 10:49:35 2024 +0530 rename get_cached_content to get commit a4ac7a5 Merge: f13228d f987fde Author: Mayuresh Agashe <magashe@google.com> Date: Tue May 21 23:32:41 2024 +0530 Merge branch 'main' into caching commit f13228d Author: Mayuresh Agashe <magashe@google.com> Date: Fri Apr 26 16:54:09 2024 +0000 *Inital prototype for explicit caching *Add basic CURD support for caching *Remove INPUT_ONLY marked fields from CachedContent dataclass *Rename files 'cached_content*' -> 'caching*' *Update 'Create' method for explicit instantination of 'CachedContent' *Add a factory method to instatinate model with `CachedContent` as its context *blacken *Add tests Change-Id: I694545243efda467d6fd599beded0dc6679b727d Change-Id: I7b14d94f729953294780815f4c496888bb2ad46f * Remove auto cache deletion Change-Id: I4658e1c57f967faeb3945dffef0181a456d65370 * Rename _to_dict --> _get_update_fields Change-Id: I3c92c65e8e5b215e98c1ac0eea6db033166dec78 * Fix tests Change-Id: Id36d7606e13d15caf6870f29a108944c7f36eaeb * Set 'CachedContent' as a public property Remove __new__ construct Change-Id: Ie4f5527270be90730341b6c3b67de71b9b6e9c5c * blacken Change-Id: I12498213a7fc2b257827ab0df87c6913e04cad25 * set 'role=user' when content is passed as a str (#4) 'to_content' method assigns a default 'role=user' to all the contents passed as a string Change-Id: I748514a7839b7f1d36150b879c3d1464ca9e11ba * Handle ttl and expire_time separately Change-Id: If9c6f04fe8d419828e3efd2249f0698bca4d5bdc * Remove name param Change-Id: I40fe7c8fafdb014fb9c7e74956452aca9a666641 * Update caching_types.py * Update caching.py * Update docstrs and error messages Change-Id: I111a1218a7d9783d494b84f0a11cb3b76c7ad9da * Update model name to gemini-1.5-pro for caching tests Change-Id: Ibb1f75c409afaac124ef70232be71e3a882f6015 * Remove dafault ttl assignment Let the API set the dafault Change-Id: Id8d125a085ed27229ddb78d5812ed5b5ad39227b * blacken Change-Id: I1d7fe0ec422589e237502b0eda687cf81ef21a21 * Remove client arg Change-Id: I17f05a90a1514f404dd3527c0db1ce6147d2c47a * Add 'usage_metadata' param to CachedContent class Change-Id: Ic527c157bc2cd114948b73a8f1832c21dd61b52e * Add 'display_name' to CachedContent class Change-Id: Id0a9be9d1bfdb94dc9d5c4fc7af9dee89e5365a4 * update generativelanguage version, fix tests Change-Id: I0acc57853ab7dde863bbbe4b30ae3957e6ec3d11 * format Change-Id: Ib2e9a16aaa989021d3498f3e59f9983560919159 * fewer automatic 'role' insertions Change-Id: I0752741532a451f8720fa5e110e68f0b4e66cc4b * cleanup Change-Id: I151a809f6d079b8e4b0ed30d1153a638c98cacfd * Wrap the proto Change-Id: I14b4c54652fb51b867fb43d4b3e9091e6eaccd4e * Apply suggestions from code review Co-authored-by: Mayuresh Agashe <magashe@google.com> * fix Change-Id: I381029fc8fc13c39e432b39084fc8feba305514e * format Change-Id: I8e0b44aebc102d3b2afb27a422c4d70d6c99d5d2 * cleanup Change-Id: I024733b53cede5bfdf957ce7e56d6ad01fd4b2bf * update version Change-Id: Ic95dffb3e945e31adc0d98787942d27289512b8a * fix Change-Id: I6ffdabbddf0e803606b3638521ebfeb6796d2e4b * typing Change-Id: I629d4d111f0e640f4f4bf602ea33f70fdc9ca3e4 * Simplify update method Accept kwargs instead of dict of updates and construct protos using kwargs Change-Id: I7858d585b1aa6b965134e2fb90adff737172af92 * Add repr to CachedContent Change-Id: Id4ec78ebf9d6e96f22f6bf37fc4509268fa552f4 * cleanup Change-Id: I684b46f881735bceb3f9e09d8573721ddb29f98a * blacken Change-Id: I773e7a5b8a222c8b4435470cdc2b53be425d95e4 * Apply suggestions from code review Change-Id: I2a12b9689001bbc41c460db5a9f0e87c77d4caf6 --------- Co-authored-by: Mark Daoust <markdaoust@google.com>
1 parent dbd5498 commit 23b81d7

File tree

6 files changed

+335
-214
lines changed

6 files changed

+335
-214
lines changed

google/generativeai/caching.py

Lines changed: 160 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -14,90 +14,130 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
import dataclasses
1817
import datetime
19-
from typing import Any, Iterable, Optional
18+
import textwrap
19+
from typing import Iterable, Optional
2020

2121
from google.generativeai import protos
22-
from google.generativeai.types.model_types import idecode_time
2322
from google.generativeai.types import caching_types
2423
from google.generativeai.types import content_types
25-
from google.generativeai.utils import flatten_update_paths
2624
from google.generativeai.client import get_default_cache_client
2725

2826
from google.protobuf import field_mask_pb2
29-
import google.ai.generativelanguage as glm
27+
28+
_USER_ROLE = "user"
29+
_MODEL_ROLE = "model"
3030

3131

32-
@dataclasses.dataclass
3332
class CachedContent:
3433
"""Cached content resource."""
3534

36-
name: str
37-
model: str
38-
create_time: datetime.datetime
39-
update_time: datetime.datetime
40-
expire_time: datetime.datetime
35+
def __init__(self, name):
36+
"""Fetches a `CachedContent` resource.
4137
42-
# NOTE: Automatic CachedContent deletion using contextmanager is not P0(P1+).
43-
# Adding basic support for now.
44-
def __enter__(self):
45-
return self
38+
Identical to `CachedContent.get`.
4639
47-
def __exit__(self, exc_type, exc_value, exc_tb):
48-
self.delete()
49-
50-
def _to_dict(self) -> protos.CachedContent:
51-
proto_paths = {
52-
"name": self.name,
53-
"model": self.model,
54-
}
55-
return protos.CachedContent(**proto_paths)
56-
57-
def _apply_update(self, path, value):
58-
parts = path.split(".")
59-
for part in parts[:-1]:
60-
self = getattr(self, part)
61-
if parts[-1] == "ttl":
62-
value = self.expire_time + datetime.timedelta(seconds=value["seconds"])
63-
parts[-1] = "expire_time"
64-
setattr(self, parts[-1], value)
40+
Args:
41+
name: The resource name referring to the cached content.
42+
"""
43+
client = get_default_cache_client()
6544

66-
@classmethod
67-
def _decode_cached_content(cls, cached_content: protos.CachedContent) -> CachedContent:
68-
# not supposed to get INPUT_ONLY repeated fields, but local gapic lib build
69-
# is returning these, hence setting including_default_value_fields to False
70-
cached_content = type(cached_content).to_dict(
71-
cached_content, including_default_value_fields=False
45+
if "cachedContents/" not in name:
46+
name = "cachedContents/" + name
47+
48+
request = protos.GetCachedContentRequest(name=name)
49+
response = client.get_cached_content(request)
50+
self._proto = response
51+
52+
@property
53+
def name(self) -> str:
54+
return self._proto.name
55+
56+
@property
57+
def model(self) -> str:
58+
return self._proto.model
59+
60+
@property
61+
def display_name(self) -> str:
62+
return self._proto.display_name
63+
64+
@property
65+
def usage_metadata(self) -> protos.CachedContent.UsageMetadata:
66+
return self._proto.usage_metadata
67+
68+
@property
69+
def create_time(self) -> datetime.datetime:
70+
return self._proto.create_time
71+
72+
@property
73+
def update_time(self) -> datetime.datetime:
74+
return self._proto.update_time
75+
76+
@property
77+
def expire_time(self) -> datetime.datetime:
78+
return self._proto.expire_time
79+
80+
def __str__(self):
81+
return textwrap.dedent(
82+
f"""\
83+
CachedContent(
84+
name='{self.name}',
85+
model='{self.model}',
86+
display_name='{self.display_name}',
87+
usage_metadata={'{'}
88+
'total_token_count': {self.usage_metadata.total_token_count},
89+
{'}'},
90+
create_time={self.create_time},
91+
update_time={self.update_time},
92+
expire_time={self.expire_time}
93+
)"""
7294
)
7395

74-
idecode_time(cached_content, "create_time")
75-
idecode_time(cached_content, "update_time")
76-
# always decode `expire_time` as Timestamp is returned
77-
# regardless of what was sent on input
78-
idecode_time(cached_content, "expire_time")
79-
return cls(**cached_content)
96+
__repr__ = __str__
97+
98+
@classmethod
99+
def _from_obj(cls, obj: CachedContent | protos.CachedContent | dict) -> CachedContent:
100+
"""Creates an instance of CachedContent form an object, without calling `get`."""
101+
self = cls.__new__(cls)
102+
self._proto = protos.CachedContent()
103+
self._update(obj)
104+
return self
105+
106+
def _update(self, updates):
107+
"""Updates this instance inplace, does not call the API's `update` method"""
108+
if isinstance(updates, CachedContent):
109+
updates = updates._proto
110+
111+
if not isinstance(updates, dict):
112+
updates = type(updates).to_dict(updates, including_default_value_fields=False)
113+
114+
for key, value in updates.items():
115+
setattr(self._proto, key, value)
80116

81117
@staticmethod
82118
def _prepare_create_request(
83119
model: str,
84-
name: str | None = None,
120+
*,
121+
display_name: str | None = None,
85122
system_instruction: Optional[content_types.ContentType] = None,
86123
contents: Optional[content_types.ContentsType] = None,
87124
tools: Optional[content_types.FunctionLibraryType] = None,
88125
tool_config: Optional[content_types.ToolConfigType] = None,
89-
ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
126+
ttl: Optional[caching_types.TTLTypes] = None,
127+
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
90128
) -> protos.CreateCachedContentRequest:
91129
"""Prepares a CreateCachedContentRequest."""
92-
if name is not None:
93-
if not caching_types.valid_cached_content_name(name):
94-
raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name))
95-
96-
name = "cachedContents/" + name
130+
if ttl and expire_time:
131+
raise ValueError(
132+
"Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
133+
)
97134

98135
if "/" not in model:
99136
model = "models/" + model
100137

138+
if display_name and len(display_name) > 128:
139+
raise ValueError("`display_name` must be no more than 128 unicode characters.")
140+
101141
if system_instruction:
102142
system_instruction = content_types.to_content(system_instruction)
103143

@@ -110,18 +150,21 @@ def _prepare_create_request(
110150

111151
if contents:
112152
contents = content_types.to_contents(contents)
153+
if not contents[-1].role:
154+
contents[-1].role = _USER_ROLE
113155

114-
if ttl:
115-
ttl = caching_types.to_ttl(ttl)
156+
ttl = caching_types.to_optional_ttl(ttl)
157+
expire_time = caching_types.to_optional_expire_time(expire_time)
116158

117159
cached_content = protos.CachedContent(
118-
name=name,
119160
model=model,
161+
display_name=display_name,
120162
system_instruction=system_instruction,
121163
contents=contents,
122164
tools=tools_lib,
123165
tool_config=tool_config,
124166
ttl=ttl,
167+
expire_time=expire_time,
125168
)
126169

127170
return protos.CreateCachedContentRequest(cached_content=cached_content)
@@ -130,48 +173,55 @@ def _prepare_create_request(
130173
def create(
131174
cls,
132175
model: str,
133-
name: str | None = None,
176+
*,
177+
display_name: str | None = None,
134178
system_instruction: Optional[content_types.ContentType] = None,
135179
contents: Optional[content_types.ContentsType] = None,
136180
tools: Optional[content_types.FunctionLibraryType] = None,
137181
tool_config: Optional[content_types.ToolConfigType] = None,
138-
ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
139-
client: glm.CacheServiceClient | None = None,
182+
ttl: Optional[caching_types.TTLTypes] = None,
183+
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
140184
) -> CachedContent:
141185
"""Creates `CachedContent` resource.
142186
143187
Args:
144188
model: The name of the `model` to use for cached content creation.
145189
Any `CachedContent` resource can be only used with the
146190
`model` it was created for.
147-
name: The resource name referring to the cached content.
191+
display_name: The user-generated meaningful display name
192+
of the cached content. `display_name` must be no
193+
more than 128 unicode characters.
148194
system_instruction: Developer set system instruction.
149195
contents: Contents to cache.
150196
tools: A list of `Tools` the model may use to generate response.
151197
tool_config: Config to apply to all tools.
152198
ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
199+
`ttl` and `expire_time` are exclusive arguments.
200+
expire_time: Expiration time for cached resource.
201+
`ttl` and `expire_time` are exclusive arguments.
153202
154203
Returns:
155204
`CachedContent` resource with specified name.
156205
"""
157-
if client is None:
158-
client = get_default_cache_client()
206+
client = get_default_cache_client()
159207

160208
request = cls._prepare_create_request(
161209
model=model,
162-
name=name,
210+
display_name=display_name,
163211
system_instruction=system_instruction,
164212
contents=contents,
165213
tools=tools,
166214
tool_config=tool_config,
167215
ttl=ttl,
216+
expire_time=expire_time,
168217
)
169218

170219
response = client.create_cached_content(request)
171-
return cls._decode_cached_content(response)
220+
result = CachedContent._from_obj(response)
221+
return result
172222

173223
@classmethod
174-
def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent:
224+
def get(cls, name: str) -> CachedContent:
175225
"""Fetches required `CachedContent` resource.
176226
177227
Args:
@@ -180,20 +230,18 @@ def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedC
180230
Returns:
181231
`CachedContent` resource with specified `name`.
182232
"""
183-
if client is None:
184-
client = get_default_cache_client()
233+
client = get_default_cache_client()
185234

186235
if "cachedContents/" not in name:
187236
name = "cachedContents/" + name
188237

189238
request = protos.GetCachedContentRequest(name=name)
190239
response = client.get_cached_content(request)
191-
return cls._decode_cached_content(response)
240+
result = CachedContent._from_obj(response)
241+
return result
192242

193243
@classmethod
194-
def list(
195-
cls, page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None
196-
) -> Iterable[CachedContent]:
244+
def list(cls, page_size: Optional[int] = 1) -> Iterable[CachedContent]:
197245
"""Lists `CachedContent` objects associated with the project.
198246
199247
Args:
@@ -203,58 +251,64 @@ def list(
203251
Returns:
204252
A paginated list of `CachedContent` objects.
205253
"""
206-
if client is None:
207-
client = get_default_cache_client()
254+
client = get_default_cache_client()
208255

209256
request = protos.ListCachedContentsRequest(page_size=page_size)
210257
for cached_content in client.list_cached_contents(request):
211-
yield cls._decode_cached_content(cached_content)
258+
cached_content = CachedContent._from_obj(cached_content)
259+
yield cached_content
212260

213-
def delete(self, client: glm.CachedServiceClient | None = None) -> None:
261+
def delete(self) -> None:
214262
"""Deletes `CachedContent` resource."""
215-
if client is None:
216-
client = get_default_cache_client()
263+
client = get_default_cache_client()
217264

218265
request = protos.DeleteCachedContentRequest(name=self.name)
219266
client.delete_cached_content(request)
220267
return
221268

222269
def update(
223270
self,
224-
updates: dict[str, Any],
225-
client: glm.CacheServiceClient | None = None,
226-
) -> CachedContent:
271+
*,
272+
ttl: Optional[caching_types.TTLTypes] = None,
273+
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
274+
) -> None:
227275
"""Updates requested `CachedContent` resource.
228276
229277
Args:
230-
updates: The list of fields to update. Currently only
231-
`ttl/expire_time` is supported as an update path.
232-
233-
Returns:
234-
`CachedContent` object with specified updates.
278+
ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
279+
`ttl` and `expire_time` are exclusive arguments.
280+
expire_time: Expiration time for cached resource.
281+
`ttl` and `expire_time` are exclusive arguments.
235282
"""
236-
if client is None:
237-
client = get_default_cache_client()
238-
239-
updates = flatten_update_paths(updates)
240-
for update_path in updates:
241-
if update_path == "ttl":
242-
updates = updates.copy()
243-
update_path_val = updates.get(update_path)
244-
updates[update_path] = caching_types.to_ttl(update_path_val)
245-
else:
246-
raise ValueError(
247-
f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead."
248-
)
249-
field_mask = field_mask_pb2.FieldMask()
283+
client = get_default_cache_client()
250284

251-
for path in updates.keys():
252-
field_mask.paths.append(path)
253-
for path, value in updates.items():
254-
self._apply_update(path, value)
285+
if ttl and expire_time:
286+
raise ValueError(
287+
"Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
288+
)
255289

256-
request = protos.UpdateCachedContentRequest(
257-
cached_content=self._to_dict(), update_mask=field_mask
290+
ttl = caching_types.to_optional_ttl(ttl)
291+
expire_time = caching_types.to_optional_expire_time(expire_time)
292+
293+
updates = protos.CachedContent(
294+
name=self.name,
295+
ttl=ttl,
296+
expire_time=expire_time,
258297
)
259-
client.update_cached_content(request)
260-
return self
298+
299+
field_mask = field_mask_pb2.FieldMask()
300+
301+
if ttl:
302+
field_mask.paths.append("ttl")
303+
elif expire_time:
304+
field_mask.paths.append("expire_time")
305+
else:
306+
raise ValueError(
307+
f"Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`."
308+
)
309+
310+
request = protos.UpdateCachedContentRequest(cached_content=updates, update_mask=field_mask)
311+
updated_cc = client.update_cached_content(request)
312+
self._update(updated_cc)
313+
314+
return

0 commit comments

Comments
 (0)