14
14
# limitations under the License.
15
15
from __future__ import annotations
16
16
17
- import dataclasses
18
17
import datetime
19
- from typing import Any , Iterable , Optional
18
+ import textwrap
19
+ from typing import Iterable , Optional
20
20
21
21
from google .generativeai import protos
22
- from google .generativeai .types .model_types import idecode_time
23
22
from google .generativeai .types import caching_types
24
23
from google .generativeai .types import content_types
25
- from google .generativeai .utils import flatten_update_paths
26
24
from google .generativeai .client import get_default_cache_client
27
25
28
26
from google .protobuf import field_mask_pb2
29
- import google .ai .generativelanguage as glm
27
+
28
+ _USER_ROLE = "user"
29
+ _MODEL_ROLE = "model"
30
30
31
31
32
- @dataclasses .dataclass
33
32
class CachedContent :
34
33
"""Cached content resource."""
35
34
36
- name : str
37
- model : str
38
- create_time : datetime .datetime
39
- update_time : datetime .datetime
40
- expire_time : datetime .datetime
35
+ def __init__ (self , name ):
36
+ """Fetches a `CachedContent` resource.
41
37
42
- # NOTE: Automatic CachedContent deletion using contextmanager is not P0(P1+).
43
- # Adding basic support for now.
44
- def __enter__ (self ):
45
- return self
38
+ Identical to `CachedContent.get`.
46
39
47
- def __exit__ (self , exc_type , exc_value , exc_tb ):
48
- self .delete ()
49
-
50
- def _to_dict (self ) -> protos .CachedContent :
51
- proto_paths = {
52
- "name" : self .name ,
53
- "model" : self .model ,
54
- }
55
- return protos .CachedContent (** proto_paths )
56
-
57
- def _apply_update (self , path , value ):
58
- parts = path .split ("." )
59
- for part in parts [:- 1 ]:
60
- self = getattr (self , part )
61
- if parts [- 1 ] == "ttl" :
62
- value = self .expire_time + datetime .timedelta (seconds = value ["seconds" ])
63
- parts [- 1 ] = "expire_time"
64
- setattr (self , parts [- 1 ], value )
40
+ Args:
41
+ name: The resource name referring to the cached content.
42
+ """
43
+ client = get_default_cache_client ()
65
44
66
- @classmethod
67
- def _decode_cached_content (cls , cached_content : protos .CachedContent ) -> CachedContent :
68
- # not supposed to get INPUT_ONLY repeated fields, but local gapic lib build
69
- # is returning these, hence setting including_default_value_fields to False
70
- cached_content = type (cached_content ).to_dict (
71
- cached_content , including_default_value_fields = False
45
+ if "cachedContents/" not in name :
46
+ name = "cachedContents/" + name
47
+
48
+ request = protos .GetCachedContentRequest (name = name )
49
+ response = client .get_cached_content (request )
50
+ self ._proto = response
51
+
52
+ @property
53
+ def name (self ) -> str :
54
+ return self ._proto .name
55
+
56
+ @property
57
+ def model (self ) -> str :
58
+ return self ._proto .model
59
+
60
+ @property
61
+ def display_name (self ) -> str :
62
+ return self ._proto .display_name
63
+
64
+ @property
65
+ def usage_metadata (self ) -> protos .CachedContent .UsageMetadata :
66
+ return self ._proto .usage_metadata
67
+
68
+ @property
69
+ def create_time (self ) -> datetime .datetime :
70
+ return self ._proto .create_time
71
+
72
+ @property
73
+ def update_time (self ) -> datetime .datetime :
74
+ return self ._proto .update_time
75
+
76
+ @property
77
+ def expire_time (self ) -> datetime .datetime :
78
+ return self ._proto .expire_time
79
+
80
+ def __str__ (self ):
81
+ return textwrap .dedent (
82
+ f"""\
83
+ CachedContent(
84
+ name='{ self .name } ',
85
+ model='{ self .model } ',
86
+ display_name='{ self .display_name } ',
87
+ usage_metadata={ '{' }
88
+ 'total_token_count': { self .usage_metadata .total_token_count } ,
89
+ { '}' } ,
90
+ create_time={ self .create_time } ,
91
+ update_time={ self .update_time } ,
92
+ expire_time={ self .expire_time }
93
+ )"""
72
94
)
73
95
74
- idecode_time (cached_content , "create_time" )
75
- idecode_time (cached_content , "update_time" )
76
- # always decode `expire_time` as Timestamp is returned
77
- # regardless of what was sent on input
78
- idecode_time (cached_content , "expire_time" )
79
- return cls (** cached_content )
96
+ __repr__ = __str__
97
+
98
+ @classmethod
99
+ def _from_obj (cls , obj : CachedContent | protos .CachedContent | dict ) -> CachedContent :
100
+ """Creates an instance of CachedContent form an object, without calling `get`."""
101
+ self = cls .__new__ (cls )
102
+ self ._proto = protos .CachedContent ()
103
+ self ._update (obj )
104
+ return self
105
+
106
+ def _update (self , updates ):
107
+ """Updates this instance inplace, does not call the API's `update` method"""
108
+ if isinstance (updates , CachedContent ):
109
+ updates = updates ._proto
110
+
111
+ if not isinstance (updates , dict ):
112
+ updates = type (updates ).to_dict (updates , including_default_value_fields = False )
113
+
114
+ for key , value in updates .items ():
115
+ setattr (self ._proto , key , value )
80
116
81
117
@staticmethod
82
118
def _prepare_create_request (
83
119
model : str ,
84
- name : str | None = None ,
120
+ * ,
121
+ display_name : str | None = None ,
85
122
system_instruction : Optional [content_types .ContentType ] = None ,
86
123
contents : Optional [content_types .ContentsType ] = None ,
87
124
tools : Optional [content_types .FunctionLibraryType ] = None ,
88
125
tool_config : Optional [content_types .ToolConfigType ] = None ,
89
- ttl : Optional [caching_types .ExpirationTypes ] = datetime .timedelta (hours = 1 ),
126
+ ttl : Optional [caching_types .TTLTypes ] = None ,
127
+ expire_time : Optional [caching_types .ExpireTimeTypes ] = None ,
90
128
) -> protos .CreateCachedContentRequest :
91
129
"""Prepares a CreateCachedContentRequest."""
92
- if name is not None :
93
- if not caching_types .valid_cached_content_name (name ):
94
- raise ValueError (caching_types .NAME_ERROR_MESSAGE .format (name = name ))
95
-
96
- name = "cachedContents/" + name
130
+ if ttl and expire_time :
131
+ raise ValueError (
132
+ "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
133
+ )
97
134
98
135
if "/" not in model :
99
136
model = "models/" + model
100
137
138
+ if display_name and len (display_name ) > 128 :
139
+ raise ValueError ("`display_name` must be no more than 128 unicode characters." )
140
+
101
141
if system_instruction :
102
142
system_instruction = content_types .to_content (system_instruction )
103
143
@@ -110,18 +150,21 @@ def _prepare_create_request(
110
150
111
151
if contents :
112
152
contents = content_types .to_contents (contents )
153
+ if not contents [- 1 ].role :
154
+ contents [- 1 ].role = _USER_ROLE
113
155
114
- if ttl :
115
- ttl = caching_types .to_ttl ( ttl )
156
+ ttl = caching_types . to_optional_ttl ( ttl )
157
+ expire_time = caching_types .to_optional_expire_time ( expire_time )
116
158
117
159
cached_content = protos .CachedContent (
118
- name = name ,
119
160
model = model ,
161
+ display_name = display_name ,
120
162
system_instruction = system_instruction ,
121
163
contents = contents ,
122
164
tools = tools_lib ,
123
165
tool_config = tool_config ,
124
166
ttl = ttl ,
167
+ expire_time = expire_time ,
125
168
)
126
169
127
170
return protos .CreateCachedContentRequest (cached_content = cached_content )
@@ -130,48 +173,55 @@ def _prepare_create_request(
130
173
def create (
131
174
cls ,
132
175
model : str ,
133
- name : str | None = None ,
176
+ * ,
177
+ display_name : str | None = None ,
134
178
system_instruction : Optional [content_types .ContentType ] = None ,
135
179
contents : Optional [content_types .ContentsType ] = None ,
136
180
tools : Optional [content_types .FunctionLibraryType ] = None ,
137
181
tool_config : Optional [content_types .ToolConfigType ] = None ,
138
- ttl : Optional [caching_types .ExpirationTypes ] = datetime . timedelta ( hours = 1 ) ,
139
- client : glm . CacheServiceClient | None = None ,
182
+ ttl : Optional [caching_types .TTLTypes ] = None ,
183
+ expire_time : Optional [ caching_types . ExpireTimeTypes ] = None ,
140
184
) -> CachedContent :
141
185
"""Creates `CachedContent` resource.
142
186
143
187
Args:
144
188
model: The name of the `model` to use for cached content creation.
145
189
Any `CachedContent` resource can be only used with the
146
190
`model` it was created for.
147
- name: The resource name referring to the cached content.
191
+ display_name: The user-generated meaningful display name
192
+ of the cached content. `display_name` must be no
193
+ more than 128 unicode characters.
148
194
system_instruction: Developer set system instruction.
149
195
contents: Contents to cache.
150
196
tools: A list of `Tools` the model may use to generate response.
151
197
tool_config: Config to apply to all tools.
152
198
ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
199
+ `ttl` and `expire_time` are exclusive arguments.
200
+ expire_time: Expiration time for cached resource.
201
+ `ttl` and `expire_time` are exclusive arguments.
153
202
154
203
Returns:
155
204
`CachedContent` resource with specified name.
156
205
"""
157
- if client is None :
158
- client = get_default_cache_client ()
206
+ client = get_default_cache_client ()
159
207
160
208
request = cls ._prepare_create_request (
161
209
model = model ,
162
- name = name ,
210
+ display_name = display_name ,
163
211
system_instruction = system_instruction ,
164
212
contents = contents ,
165
213
tools = tools ,
166
214
tool_config = tool_config ,
167
215
ttl = ttl ,
216
+ expire_time = expire_time ,
168
217
)
169
218
170
219
response = client .create_cached_content (request )
171
- return cls ._decode_cached_content (response )
220
+ result = CachedContent ._from_obj (response )
221
+ return result
172
222
173
223
@classmethod
174
- def get (cls , name : str , client : glm . CacheServiceClient | None = None ) -> CachedContent :
224
+ def get (cls , name : str ) -> CachedContent :
175
225
"""Fetches required `CachedContent` resource.
176
226
177
227
Args:
@@ -180,20 +230,18 @@ def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedC
180
230
Returns:
181
231
`CachedContent` resource with specified `name`.
182
232
"""
183
- if client is None :
184
- client = get_default_cache_client ()
233
+ client = get_default_cache_client ()
185
234
186
235
if "cachedContents/" not in name :
187
236
name = "cachedContents/" + name
188
237
189
238
request = protos .GetCachedContentRequest (name = name )
190
239
response = client .get_cached_content (request )
191
- return cls ._decode_cached_content (response )
240
+ result = CachedContent ._from_obj (response )
241
+ return result
192
242
193
243
@classmethod
194
- def list (
195
- cls , page_size : Optional [int ] = 1 , client : glm .CacheServiceClient | None = None
196
- ) -> Iterable [CachedContent ]:
244
+ def list (cls , page_size : Optional [int ] = 1 ) -> Iterable [CachedContent ]:
197
245
"""Lists `CachedContent` objects associated with the project.
198
246
199
247
Args:
@@ -203,58 +251,64 @@ def list(
203
251
Returns:
204
252
A paginated list of `CachedContent` objects.
205
253
"""
206
- if client is None :
207
- client = get_default_cache_client ()
254
+ client = get_default_cache_client ()
208
255
209
256
request = protos .ListCachedContentsRequest (page_size = page_size )
210
257
for cached_content in client .list_cached_contents (request ):
211
- yield cls ._decode_cached_content (cached_content )
258
+ cached_content = CachedContent ._from_obj (cached_content )
259
+ yield cached_content
212
260
213
- def delete (self , client : glm . CachedServiceClient | None = None ) -> None :
261
+ def delete (self ) -> None :
214
262
"""Deletes `CachedContent` resource."""
215
- if client is None :
216
- client = get_default_cache_client ()
263
+ client = get_default_cache_client ()
217
264
218
265
request = protos .DeleteCachedContentRequest (name = self .name )
219
266
client .delete_cached_content (request )
220
267
return
221
268
222
269
def update (
223
270
self ,
224
- updates : dict [str , Any ],
225
- client : glm .CacheServiceClient | None = None ,
226
- ) -> CachedContent :
271
+ * ,
272
+ ttl : Optional [caching_types .TTLTypes ] = None ,
273
+ expire_time : Optional [caching_types .ExpireTimeTypes ] = None ,
274
+ ) -> None :
227
275
"""Updates requested `CachedContent` resource.
228
276
229
277
Args:
230
- updates: The list of fields to update. Currently only
231
- `ttl/expire_time` is supported as an update path.
232
-
233
- Returns:
234
- `CachedContent` object with specified updates.
278
+ ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
279
+ `ttl` and `expire_time` are exclusive arguments.
280
+ expire_time: Expiration time for cached resource.
281
+ `ttl` and `expire_time` are exclusive arguments.
235
282
"""
236
- if client is None :
237
- client = get_default_cache_client ()
238
-
239
- updates = flatten_update_paths (updates )
240
- for update_path in updates :
241
- if update_path == "ttl" :
242
- updates = updates .copy ()
243
- update_path_val = updates .get (update_path )
244
- updates [update_path ] = caching_types .to_ttl (update_path_val )
245
- else :
246
- raise ValueError (
247
- f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{ update_path } ` instead."
248
- )
249
- field_mask = field_mask_pb2 .FieldMask ()
283
+ client = get_default_cache_client ()
250
284
251
- for path in updates . keys () :
252
- field_mask . paths . append ( path )
253
- for path , value in updates . items ():
254
- self . _apply_update ( path , value )
285
+ if ttl and expire_time :
286
+ raise ValueError (
287
+ "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
288
+ )
255
289
256
- request = protos .UpdateCachedContentRequest (
257
- cached_content = self ._to_dict (), update_mask = field_mask
290
+ ttl = caching_types .to_optional_ttl (ttl )
291
+ expire_time = caching_types .to_optional_expire_time (expire_time )
292
+
293
+ updates = protos .CachedContent (
294
+ name = self .name ,
295
+ ttl = ttl ,
296
+ expire_time = expire_time ,
258
297
)
259
- client .update_cached_content (request )
260
- return self
298
+
299
+ field_mask = field_mask_pb2 .FieldMask ()
300
+
301
+ if ttl :
302
+ field_mask .paths .append ("ttl" )
303
+ elif expire_time :
304
+ field_mask .paths .append ("expire_time" )
305
+ else :
306
+ raise ValueError (
307
+ f"Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`."
308
+ )
309
+
310
+ request = protos .UpdateCachedContentRequest (cached_content = updates , update_mask = field_mask )
311
+ updated_cc = client .update_cached_content (request )
312
+ self ._update (updated_cc )
313
+
314
+ return
0 commit comments