From a649467fd7c8f5b8527ed6d990e969b549b2bde4 Mon Sep 17 00:00:00 2001 From: graham Date: Wed, 13 Mar 2024 15:50:27 -0600 Subject: [PATCH 1/2] Union handles lists --- simple_parsing/wrappers/field_wrapper.py | 9 +++++++++ test/test_union.py | 23 +++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/simple_parsing/wrappers/field_wrapper.py b/simple_parsing/wrappers/field_wrapper.py index 3a4d1860..c9d71e21 100644 --- a/simple_parsing/wrappers/field_wrapper.py +++ b/simple_parsing/wrappers/field_wrapper.py @@ -330,6 +330,8 @@ def get_arg_options(self) -> dict[str, Any]: elif self.is_union: logger.debug("Parsing a Union type!") _arg_options["type"] = get_parsing_fn(self.type) + if any(utils.is_list(o) for o in utils.get_args(self.type)): + _arg_options["nargs"] = "*" elif self.is_enum: logger.debug(f"Adding an Enum attribute '{self.name}'") @@ -501,6 +503,13 @@ def postprocess(self, raw_parsed_value: Any) -> Any: else: return raw_parsed_value + elif self.is_union: + list_in = [utils.is_list(o) for o in utils.get_args(self.type)] + # if type is like Union[str, list[str]] and only a single value was passed, + if any(list_in) and (not all(list_in)) and (len(raw_parsed_value) == 1): + raw_parsed_value = raw_parsed_value[0] + return raw_parsed_value + elif self.is_subparser: return raw_parsed_value diff --git a/test/test_union.py b/test/test_union.py index 4c0027dd..53bfe476 100644 --- a/test/test_union.py +++ b/test/test_union.py @@ -32,3 +32,26 @@ class Foo2(TestSetup): foo = Foo2.setup("--x 2") assert foo.x == 2 and type(foo.x) is int + + +def test_union_type_with_list(): + + @dataclass + class Foo(TestSetup): + x: Union[str, list[str]] + + foo = Foo.setup("--x bob") + assert foo.x == "bob" + + foo = Foo.setup("--x bob alice") + assert foo.x == ["bob", "alice"] + + @dataclass + class Foo(TestSetup): + x: Union[list[int], list[str]] + + foo = Foo.setup("--x bob alice") + assert foo.x == ["bob", "alice"] + + foo = Foo.setup("--x 1 2") + assert foo.x == [1, 2] From ed523986f777814f2683f8c059a625bbcc7434b2 Mon Sep 17 00:00:00 2001 From: graham Date: Tue, 19 Mar 2024 14:13:39 -0600 Subject: [PATCH 2/2] ran precommit --- test/test_union.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_union.py b/test/test_union.py index 53bfe476..e90a39ac 100644 --- a/test/test_union.py +++ b/test/test_union.py @@ -35,7 +35,6 @@ class Foo2(TestSetup): def test_union_type_with_list(): - @dataclass class Foo(TestSetup): x: Union[str, list[str]]