Skip to content

Commit 4badee7

Browse files
committed
Add support for List and Dict providers to _locate_dependent_closing_args
1 parent 72a316c commit 4badee7

File tree

3 files changed

+87
-99
lines changed

3 files changed

+87
-99
lines changed

src/dependency_injector/wiring.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
"""Wiring module."""
22

33
import functools
4-
import inspect
54
import importlib
65
import importlib.machinery
6+
import inspect
77
import pkgutil
8-
import warnings
98
import sys
9+
import warnings
1010
from types import ModuleType
1111
from typing import (
12-
Optional,
13-
Iterable,
14-
Iterator,
15-
Callable,
1612
Any,
17-
Tuple,
13+
Callable,
1814
Dict,
1915
Generic,
20-
TypeVar,
16+
Iterable,
17+
Iterator,
18+
Optional,
19+
Set,
20+
Tuple,
2121
Type,
22+
TypeVar,
2223
Union,
23-
Set,
2424
cast,
2525
)
2626

@@ -645,17 +645,17 @@ def _fetch_reference_injections( # noqa: C901
645645
def _locate_dependent_closing_args(
646646
provider: providers.Provider,
647647
) -> Dict[str, providers.Provider]:
648-
if not hasattr(provider, "args"):
649-
return {}
648+
closing_deps: Dict[str, providers.Provider] = {}
650649

651-
closing_deps = {}
652-
for arg in [*provider.args, *provider.kwargs.values()]:
653-
if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"):
650+
for arg in [
651+
*getattr(provider, "args", []),
652+
*getattr(provider, "kwargs", {}).values(),
653+
]:
654+
if not isinstance(arg, providers.Provider):
654655
continue
655656
if isinstance(arg, providers.Resource):
656-
return {str(id(arg)): arg}
657-
if arg.args or arg.kwargs:
658-
closing_deps |= _locate_dependent_closing_args(arg)
657+
closing_deps[str(id(arg))] = arg
658+
closing_deps |= _locate_dependent_closing_args(arg)
659659

660660
return closing_deps
661661

@@ -1030,8 +1030,8 @@ def is_loader_installed() -> bool:
10301030
_loader = AutoLoader()
10311031

10321032
# Optimizations
1033-
from ._cwiring import _get_sync_patched # noqa
10341033
from ._cwiring import _async_inject # noqa
1034+
from ._cwiring import _get_sync_patched # noqa
10351035

10361036

10371037
# Wiring uses the following Python wrapper because there is
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,78 @@
1+
from typing import Any, Dict, List, Optional
2+
13
from dependency_injector import containers, providers
2-
from dependency_injector.wiring import inject, Provide, Closing
4+
from dependency_injector.wiring import Closing, Provide, inject
5+
6+
7+
class Counter:
8+
def __init__(self) -> None:
9+
self._init = 0
10+
self._shutdown = 0
11+
12+
def init(self) -> None:
13+
self._init += 1
314

15+
def shutdown(self) -> None:
16+
self._shutdown += 1
417

5-
class Singleton:
6-
pass
18+
def reset(self) -> None:
19+
self._init = 0
20+
self._shutdown = 0
721

822

923
class Service:
10-
init_counter: int = 0
11-
shutdown_counter: int = 0
12-
dependency: Singleton = None
24+
def __init__(self, counter: Optional[Counter] = None, **dependencies: Any) -> None:
25+
self.counter = counter or Counter()
26+
self.dependencies = dependencies
1327

14-
@classmethod
15-
def reset_counter(cls):
16-
cls.init_counter = 0
17-
cls.shutdown_counter = 0
28+
def init(self) -> None:
29+
self.counter.init()
1830

19-
@classmethod
20-
def init(cls, dependency: Singleton = None):
21-
if dependency:
22-
cls.dependency = dependency
23-
cls.init_counter += 1
31+
def shutdown(self) -> None:
32+
self.counter.shutdown()
2433

25-
@classmethod
26-
def shutdown(cls):
27-
cls.shutdown_counter += 1
34+
@property
35+
def init_counter(self) -> int:
36+
return self.counter._init
37+
38+
@property
39+
def shutdown_counter(self) -> int:
40+
return self.counter._shutdown
2841

2942

3043
class FactoryService:
31-
def __init__(self, service: Service):
44+
def __init__(self, service: Service, service2: Service):
3245
self.service = service
46+
self.service2 = service2
3347

3448

3549
class NestedService:
3650
def __init__(self, factory_service: FactoryService):
3751
self.factory_service = factory_service
3852

3953

40-
def init_service():
41-
service = Service()
54+
def init_service(counter: Counter, _list: List[int], _dict: Dict[str, int]):
55+
service = Service(counter, _list=_list, _dict=_dict)
4256
service.init()
4357
yield service
4458
service.shutdown()
4559

4660

47-
def init_service_with_singleton(singleton: Singleton):
48-
service = Service()
49-
service.init(singleton)
50-
yield service
51-
service.shutdown()
52-
53-
5461
class Container(containers.DeclarativeContainer):
55-
56-
service = providers.Resource(init_service)
57-
factory_service = providers.Factory(FactoryService, service)
58-
factory_service_kwargs = providers.Factory(
59-
FactoryService,
60-
service=service
62+
counter = providers.Singleton(Counter)
63+
_list = providers.List(
64+
providers.Callable(lambda a: a, a=1), providers.Callable(lambda b: b, 2)
6165
)
62-
nested_service = providers.Factory(NestedService, factory_service)
63-
64-
65-
class ContainerSingleton(containers.DeclarativeContainer):
66-
67-
singleton = providers.Singleton(Singleton)
68-
service = providers.Resource(
69-
init_service_with_singleton,
70-
singleton
66+
_dict = providers.Dict(
67+
a=providers.Callable(lambda a: a, a=1), b=providers.Callable(lambda b: b, 2)
7168
)
72-
factory_service = providers.Factory(FactoryService, service)
69+
service = providers.Resource(init_service, counter, _list, _dict=_dict)
70+
service2 = providers.Resource(init_service, counter, _list, _dict=_dict)
71+
factory_service = providers.Factory(FactoryService, service, service2)
7372
factory_service_kwargs = providers.Factory(
7473
FactoryService,
75-
service=service
74+
service=service,
75+
service2=service2,
7676
)
7777
nested_service = providers.Factory(NestedService, factory_service)
7878

@@ -84,20 +84,20 @@ def test_function(service: Service = Closing[Provide["service"]]):
8484

8585
@inject
8686
def test_function_dependency(
87-
factory: FactoryService = Closing[Provide["factory_service"]]
87+
factory: FactoryService = Closing[Provide["factory_service"]],
8888
):
8989
return factory
9090

9191

9292
@inject
9393
def test_function_dependency_kwargs(
94-
factory: FactoryService = Closing[Provide["factory_service_kwargs"]]
94+
factory: FactoryService = Closing[Provide["factory_service_kwargs"]],
9595
):
9696
return factory
9797

9898

9999
@inject
100100
def test_function_nested_dependency(
101-
nested: NestedService = Closing[Provide["nested_service"]]
101+
nested: NestedService = Closing[Provide["nested_service"]],
102102
):
103103
return nested

tests/unit/wiring/string_ids/test_main_py36.py

+21-33
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
from decimal import Decimal
44

5-
from dependency_injector import errors
6-
from dependency_injector.wiring import Closing, Provide, Provider, wire
75
from pytest import fixture, mark, raises
8-
96
from samples.wiringstringids import module, package, resourceclosing
10-
from samples.wiringstringids.service import Service
117
from samples.wiringstringids.container import Container, SubContainer
8+
from samples.wiringstringids.service import Service
9+
10+
from dependency_injector import errors
11+
from dependency_injector.wiring import Closing, Provide, Provider, wire
1212

1313

1414
@fixture(autouse=True)
@@ -33,14 +33,12 @@ def subcontainer():
3333
container.unwire()
3434

3535

36-
@fixture(params=[
37-
resourceclosing.Container,
38-
resourceclosing.ContainerSingleton,
39-
])
36+
@fixture
4037
def resourceclosing_container(request):
41-
container = request.param()
38+
container = resourceclosing.Container()
4239
container.wire(modules=[resourceclosing])
43-
yield container
40+
with container.reset_singletons():
41+
yield container
4442
container.unwire()
4543

4644

@@ -277,8 +275,6 @@ def test_wire_multiple_containers():
277275

278276
@mark.usefixtures("resourceclosing_container")
279277
def test_closing_resource():
280-
resourceclosing.Service.reset_counter()
281-
282278
result_1 = resourceclosing.test_function()
283279
assert isinstance(result_1, resourceclosing.Service)
284280
assert result_1.init_counter == 1
@@ -294,55 +290,48 @@ def test_closing_resource():
294290

295291
@mark.usefixtures("resourceclosing_container")
296292
def test_closing_dependency_resource():
297-
resourceclosing.Service.reset_counter()
298-
299293
result_1 = resourceclosing.test_function_dependency()
300294
assert isinstance(result_1, resourceclosing.FactoryService)
301-
assert result_1.service.init_counter == 1
302-
assert result_1.service.shutdown_counter == 1
295+
assert result_1.service.init_counter == 2
296+
assert result_1.service.shutdown_counter == 2
303297

304298
result_2 = resourceclosing.test_function_dependency()
299+
305300
assert isinstance(result_2, resourceclosing.FactoryService)
306-
assert result_2.service.init_counter == 2
307-
assert result_2.service.shutdown_counter == 2
301+
assert result_2.service.init_counter == 4
302+
assert result_2.service.shutdown_counter == 4
308303

309304

310305
@mark.usefixtures("resourceclosing_container")
311306
def test_closing_dependency_resource_kwargs():
312-
resourceclosing.Service.reset_counter()
313-
314307
result_1 = resourceclosing.test_function_dependency_kwargs()
315308
assert isinstance(result_1, resourceclosing.FactoryService)
316-
assert result_1.service.init_counter == 1
317-
assert result_1.service.shutdown_counter == 1
309+
assert result_1.service.init_counter == 2
310+
assert result_1.service.shutdown_counter == 2
318311

319312
result_2 = resourceclosing.test_function_dependency_kwargs()
320313
assert isinstance(result_2, resourceclosing.FactoryService)
321-
assert result_2.service.init_counter == 2
322-
assert result_2.service.shutdown_counter == 2
314+
assert result_2.service.init_counter == 4
315+
assert result_2.service.shutdown_counter == 4
323316

324317

325318
@mark.usefixtures("resourceclosing_container")
326319
def test_closing_nested_dependency_resource():
327-
resourceclosing.Service.reset_counter()
328-
329320
result_1 = resourceclosing.test_function_nested_dependency()
330321
assert isinstance(result_1, resourceclosing.NestedService)
331-
assert result_1.factory_service.service.init_counter == 1
332-
assert result_1.factory_service.service.shutdown_counter == 1
322+
assert result_1.factory_service.service.init_counter == 2
323+
assert result_1.factory_service.service.shutdown_counter == 2
333324

334325
result_2 = resourceclosing.test_function_nested_dependency()
335326
assert isinstance(result_2, resourceclosing.NestedService)
336-
assert result_2.factory_service.service.init_counter == 2
337-
assert result_2.factory_service.service.shutdown_counter == 2
327+
assert result_2.factory_service.service.init_counter == 4
328+
assert result_2.factory_service.service.shutdown_counter == 4
338329

339330
assert result_1 is not result_2
340331

341332

342333
@mark.usefixtures("resourceclosing_container")
343334
def test_closing_resource_bypass_marker_injection():
344-
resourceclosing.Service.reset_counter()
345-
346335
result_1 = resourceclosing.test_function(service=Closing[Provide["service"]])
347336
assert isinstance(result_1, resourceclosing.Service)
348337
assert result_1.init_counter == 1
@@ -358,7 +347,6 @@ def test_closing_resource_bypass_marker_injection():
358347

359348
@mark.usefixtures("resourceclosing_container")
360349
def test_closing_resource_context():
361-
resourceclosing.Service.reset_counter()
362350
service = resourceclosing.Service()
363351

364352
result_1 = resourceclosing.test_function(service=service)

0 commit comments

Comments
 (0)