Skip to content
This repository was archived by the owner on Nov 21, 2024. It is now read-only.

Commit 20b56a0

Browse files
core[patch]: fix repr and str for Serializable (langchain-ai#26786)
Fixes langchain-ai#26499 --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
1 parent 2d58a8a commit 20b56a0

File tree

2 files changed

+61
-24
lines changed

2 files changed

+61
-24
lines changed

libs/core/langchain_core/load/serializable.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111

1212
from pydantic import BaseModel, ConfigDict
13+
from pydantic.fields import FieldInfo
1314
from typing_extensions import NotRequired
1415

1516

@@ -77,10 +78,23 @@ def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
7778
Raises:
7879
Exception: If the key is not in the model.
7980
"""
81+
field = model.model_fields[key]
82+
return _try_neq_default(value, field)
83+
84+
85+
def _try_neq_default(value: Any, field: FieldInfo) -> bool:
86+
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
87+
# Pandas DataFrames).
8088
try:
81-
return model.model_fields[key].get_default() != value
82-
except Exception:
83-
return True
89+
return bool(field.get_default() != value)
90+
except Exception as _:
91+
try:
92+
return all(field.get_default() != value)
93+
except Exception as _:
94+
try:
95+
return value is not field.default
96+
except Exception as _:
97+
return False
8498

8599

86100
class Serializable(BaseModel, ABC):
@@ -297,18 +311,7 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
297311
if field.default_factory is list and isinstance(value, list):
298312
return False
299313

300-
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
301-
# Pandas DataFrames).
302-
try:
303-
value_neq_default = bool(field.get_default() != value)
304-
except Exception as _:
305-
try:
306-
value_neq_default = all(field.get_default() != value)
307-
except Exception as _:
308-
try:
309-
value_neq_default = value is not field.default
310-
except Exception as _:
311-
value_neq_default = False
314+
value_neq_default = _try_neq_default(value, field)
312315

313316
# If value is falsy and does not match the default
314317
return value_is_truthy or value_neq_default

libs/core/tests/unit_tests/load/test_serializable.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,22 @@
44
from langchain_core.load.serializable import _is_field_useful
55

66

7+
class NonBoolObj:
8+
def __bool__(self) -> bool:
9+
msg = "Truthiness can't be determined"
10+
raise ValueError(msg)
11+
12+
def __eq__(self, other: object) -> bool:
13+
msg = "Equality can't be determined"
14+
raise ValueError(msg)
15+
16+
def __str__(self) -> str:
17+
return self.__class__.__name__
18+
19+
def __repr__(self) -> str:
20+
return self.__class__.__name__
21+
22+
723
def test_simple_serialization() -> None:
824
class Foo(Serializable):
925
bar: int
@@ -82,15 +98,6 @@ def __bool__(self) -> bool:
8298
def __eq__(self, other: object) -> bool:
8399
return self # type: ignore[return-value]
84100

85-
class NonBoolObj:
86-
def __bool__(self) -> bool:
87-
msg = "Truthiness can't be determined"
88-
raise ValueError(msg)
89-
90-
def __eq__(self, other: object) -> bool:
91-
msg = "Equality can't be determined"
92-
raise ValueError(msg)
93-
94101
default_x = ArrayObj()
95102
default_y = NonBoolObj()
96103

@@ -169,3 +176,30 @@ def test_simple_deserialization_with_additional_imports() -> None:
169176
},
170177
)
171178
assert isinstance(new_foo, Foo2)
179+
180+
181+
class Foo3(Serializable):
182+
model_config = ConfigDict(arbitrary_types_allowed=True)
183+
184+
content: str
185+
non_bool: NonBoolObj
186+
187+
@classmethod
188+
def is_lc_serializable(cls) -> bool:
189+
return True
190+
191+
192+
def test_repr() -> None:
193+
foo = Foo3(
194+
content="repr",
195+
non_bool=NonBoolObj(),
196+
)
197+
assert repr(foo) == "Foo3(content='repr', non_bool=NonBoolObj)"
198+
199+
200+
def test_str() -> None:
201+
foo = Foo3(
202+
content="str",
203+
non_bool=NonBoolObj(),
204+
)
205+
assert str(foo) == "content='str' non_bool=NonBoolObj"

0 commit comments

Comments
 (0)