@@ -1,6 +1,6 @@
 import base64
 import json
-from typing import List, Optional
+from typing import List, Optional, cast

 import httpx
 import pytest
@@ -209,10 +209,21 @@ def test_usage_metadata(self, model: BaseChatModel) -> None:
     def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
         if not self.returns_usage_metadata:
             pytest.skip("Not implemented.")
-        full: Optional[BaseMessageChunk] = None
-        for chunk in model.stream("Hello"):
+        full: Optional[AIMessageChunk] = None
+        for chunk in model.stream("Write me 2 haikus. Only include the haikus."):
             assert isinstance(chunk, AIMessageChunk)
-            full = chunk if full is None else full + chunk
+            # only one chunk is allowed to set usage_metadata.input_tokens
+            # if multiple do, it's likely a bug that will result in overcounting
+            # input tokens
+            if full and full.usage_metadata and full.usage_metadata["input_tokens"]:
+                assert (
+                    not chunk.usage_metadata or not chunk.usage_metadata["input_tokens"]
+                ), (
+                    "Only one chunk should set input_tokens,"
+                    " the rest should be 0 or None"
+                )
+            full = chunk if full is None else cast(AIMessageChunk, full + chunk)
+
         assert isinstance(full, AIMessageChunk)
         assert full.usage_metadata is not None
         assert isinstance(full.usage_metadata["input_tokens"], int)
|
|
0 commit comments