
Commit 694c1db

flipbit03 and erikwrede authored
Vendor DataLoader from aiodataloader and move get_event_loop() out of __init__ function. (#1459)
* Vendor DataLoader from aiodataloader, and move the get_event_loop() behavior from `__init__` to a property that is only resolved when actually needed. This fixes pytest-related issues caused by get_event_loop() being called too early (see the sketch below).
* Added DataLoader-specific tests.
* Plug the `loop` parameter into `self._loop`, so we still have the ability to pass in a custom event loop if needed.

Co-authored-by: Erik Wrede <erikwrede2@gmail.com>
1 parent 20219fd commit 694c1db
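The practical effect of the lazy `loop` property is that a `DataLoader` can now be constructed before any event loop exists (for example at module import time, or during pytest collection); `get_event_loop()` is only called on first use. A minimal sketch, assuming the vendored module path introduced by this commit (`load_users` and its data are illustrative, not part of the commit):

import asyncio
from graphene.utils.dataloader import DataLoader

async def load_users(keys):
    # Illustrative batch function: return one value per key, in key order.
    return [{"id": key} for key in keys]

# Safe at import time: __init__ no longer calls get_event_loop().
user_loader = DataLoader(load_users)
# A custom event loop can still be supplied via the kwarg: DataLoader(load_users, loop=my_loop)

async def main():
    # The loop is resolved lazily, on the first load() call.
    print(await user_loader.load(1))

asyncio.run(main())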

File tree

5 files changed: +737 −80 lines


graphene/utils/dataloader.py (new file, +281 lines)
from asyncio import (
    gather,
    ensure_future,
    get_event_loop,
    iscoroutine,
    iscoroutinefunction,
)
from collections import namedtuple
from collections.abc import Iterable
from functools import partial

from typing import List  # flake8: noqa

Loader = namedtuple("Loader", "key,future")


def iscoroutinefunctionorpartial(fn):
    return iscoroutinefunction(fn.func if isinstance(fn, partial) else fn)


class DataLoader(object):
    batch = True
    max_batch_size = None  # type: int
    cache = True

    def __init__(
        self,
        batch_load_fn=None,
        batch=None,
        max_batch_size=None,
        cache=None,
        get_cache_key=None,
        cache_map=None,
        loop=None,
    ):

        self._loop = loop

        if batch_load_fn is not None:
            self.batch_load_fn = batch_load_fn

        assert iscoroutinefunctionorpartial(
            self.batch_load_fn
        ), "batch_load_fn must be a coroutine. Received: {}".format(self.batch_load_fn)

        if not callable(self.batch_load_fn):
            raise TypeError(  # pragma: no cover
                (
                    "DataLoader must have a batch_load_fn which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but got: {}."
                ).format(batch_load_fn)
            )

        if batch is not None:
            self.batch = batch  # pragma: no cover

        if max_batch_size is not None:
            self.max_batch_size = max_batch_size

        if cache is not None:
            self.cache = cache  # pragma: no cover

        self.get_cache_key = get_cache_key or (lambda x: x)

        self._cache = cache_map if cache_map is not None else {}
        self._queue = []  # type: List[Loader]

    @property
    def loop(self):
        if not self._loop:
            self._loop = get_event_loop()

        return self._loop

    def load(self, key=None):
        """
        Loads a key, returning a `Future` for the value represented by that key.
        """
        if key is None:
            raise TypeError(  # pragma: no cover
                (
                    "The loader.load() function must be called with a value, "
                    "but got: {}."
                ).format(key)
            )

        cache_key = self.get_cache_key(key)

        # If caching and there is a cache-hit, return cached Future.
        if self.cache:
            cached_result = self._cache.get(cache_key)
            if cached_result:
                return cached_result

        # Otherwise, produce a new Future for this value.
        future = self.loop.create_future()
        # If caching, cache this Future.
        if self.cache:
            self._cache[cache_key] = future

        self.do_resolve_reject(key, future)
        return future

    def do_resolve_reject(self, key, future):
        # Enqueue this Future to be dispatched.
        self._queue.append(Loader(key=key, future=future))
        # Determine if a dispatch of this queue should be scheduled.
        # A single dispatch should be scheduled per queue at the time when the
        # queue changes from "empty" to "full".
        if len(self._queue) == 1:
            if self.batch:
                # If batching, schedule a task to dispatch the queue.
                enqueue_post_future_job(self.loop, self)
            else:
                # Otherwise dispatch the (queue of one) immediately.
                dispatch_queue(self)  # pragma: no cover

    def load_many(self, keys):
        """
        Loads multiple keys, returning a list of values.

        >>> a, b = await my_loader.load_many([ 'a', 'b' ])

        This is equivalent to the more verbose:

        >>> a, b = await gather(
        >>>     my_loader.load('a'),
        >>>     my_loader.load('b')
        >>> )
        """
        if not isinstance(keys, Iterable):
            raise TypeError(  # pragma: no cover
                (
                    "The loader.load_many() function must be called with Iterable<key> "
                    "but got: {}."
                ).format(keys)
            )

        return gather(*[self.load(key) for key in keys])

    def clear(self, key):
        """
        Clears the value at `key` from the cache, if it exists. Returns itself for
        method chaining.
        """
        cache_key = self.get_cache_key(key)
        self._cache.pop(cache_key, None)
        return self

    def clear_all(self):
        """
        Clears the entire cache. To be used when some event results in unknown
        invalidations across this particular `DataLoader`. Returns itself for
        method chaining.
        """
        self._cache.clear()
        return self

    def prime(self, key, value):
        """
        Adds the provided key and value to the cache. If the key already exists, no
        change is made. Returns itself for method chaining.
        """
        cache_key = self.get_cache_key(key)

        # Only add the key if it does not already exist.
        if cache_key not in self._cache:
            # Cache a rejected future if the value is an Error, in order to match
            # the behavior of load(key).
            future = self.loop.create_future()
            if isinstance(value, Exception):
                future.set_exception(value)
            else:
                future.set_result(value)

            self._cache[cache_key] = future

        return self


def enqueue_post_future_job(loop, loader):
    async def dispatch():
        dispatch_queue(loader)

    loop.call_soon(ensure_future, dispatch())


def get_chunks(iterable_obj, chunk_size=1):
    chunk_size = max(1, chunk_size)
    return (
        iterable_obj[i : i + chunk_size]
        for i in range(0, len(iterable_obj), chunk_size)
    )


def dispatch_queue(loader):
    """
    Given the current state of a Loader instance, perform a batch load
    from its current queue.
    """
    # Take the current loader queue, replacing it with an empty queue.
    queue = loader._queue
    loader._queue = []

    # If a max_batch_size was provided and the queue is longer, then segment the
    # queue into multiple batches, otherwise treat the queue as a single batch.
    max_batch_size = loader.max_batch_size

    if max_batch_size and max_batch_size < len(queue):
        chunks = get_chunks(queue, max_batch_size)
        for chunk in chunks:
            ensure_future(dispatch_queue_batch(loader, chunk))
    else:
        ensure_future(dispatch_queue_batch(loader, queue))


async def dispatch_queue_batch(loader, queue):
    # Collect all keys to be loaded in this dispatch
    keys = [loaded.key for loaded in queue]

    # Call the provided batch_load_fn for this loader with the loader queue's keys.
    batch_future = loader.batch_load_fn(keys)

    # Assert the expected response from batch_load_fn
    if not batch_future or not iscoroutine(batch_future):
        return failed_dispatch(  # pragma: no cover
            loader,
            queue,
            TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the function did "
                    "not return a Coroutine: {}."
                ).format(batch_future)
            ),
        )

    try:
        values = await batch_future
        if not isinstance(values, Iterable):
            raise TypeError(  # pragma: no cover
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the function did "
                    "not return a Future of an Iterable: {}."
                ).format(values)
            )

        values = list(values)
        if len(values) != len(keys):
            raise TypeError(  # pragma: no cover
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the function did "
                    "not return a Future of an Iterable with the same length as the Iterable "
                    "of keys."
                    "\n\nKeys:\n{}"
                    "\n\nValues:\n{}"
                ).format(keys, values)
            )

        # Step through the values, resolving or rejecting each Future in the
        # loaded queue.
        for loaded, value in zip(queue, values):
            if isinstance(value, Exception):
                loaded.future.set_exception(value)
            else:
                loaded.future.set_result(value)

    except Exception as e:
        return failed_dispatch(loader, queue, e)


def failed_dispatch(loader, queue, error):
    """
    Do not cache individual loads if the entire batch dispatch fails,
    but still reject each request so they do not hang.
    """
    for loaded in queue:
        loader.clear(loaded.key)
        loaded.future.set_exception(error)
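As a usage sketch of the vendored class (the batch function and key values here are illustrative, not part of the commit): `load()` calls made in the same event-loop tick are coalesced into a single `batch_load_fn` invocation, and repeated keys are answered from the per-loader cache.

import asyncio
from graphene.utils.dataloader import DataLoader

calls = []

async def batch_load(keys):
    calls.append(list(keys))          # record each batch for demonstration
    return [key * 2 for key in keys]  # one result per key, in the same order

async def main():
    loader = DataLoader(batch_load)
    # Both loads are queued in the same tick, so they dispatch as one batch.
    a, b = await asyncio.gather(loader.load(1), loader.load(2))
    assert (a, b) == (2, 4)
    assert calls == [[1, 2]]          # a single batched call was made
    # A repeated key hits the cache and does not trigger another batch.
    assert await loader.load(1) == 2
    assert calls == [[1, 2]]

asyncio.run(main())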
