from asyncio import (
    gather,
    ensure_future,
    get_event_loop,
    iscoroutine,
    iscoroutinefunction,
)
from collections import namedtuple
from collections.abc import Iterable
from functools import partial

from typing import List, Optional  # noqa: F401

Loader = namedtuple("Loader", "key,future")


def iscoroutinefunctionorpartial(fn):
    # Unwrap functools.partial objects so the underlying function is checked.
    return iscoroutinefunction(fn.func if isinstance(fn, partial) else fn)


class DataLoader(object):
    batch = True
    max_batch_size = None  # type: Optional[int]
    cache = True

    def __init__(
        self,
        batch_load_fn=None,
        batch=None,
        max_batch_size=None,
        cache=None,
        get_cache_key=None,
        cache_map=None,
        loop=None,
    ):

        self._loop = loop

        if batch_load_fn is not None:
            self.batch_load_fn = batch_load_fn

        assert iscoroutinefunctionorpartial(
            self.batch_load_fn
        ), "batch_load_fn must be a coroutine function. Received: {}".format(
            self.batch_load_fn
        )

        if not callable(self.batch_load_fn):
            raise TypeError(  # pragma: no cover
                (
                    "DataLoader must have a batch_load_fn which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but got: {}."
                ).format(batch_load_fn)
            )

        if batch is not None:
            self.batch = batch  # pragma: no cover

        if max_batch_size is not None:
            self.max_batch_size = max_batch_size

        if cache is not None:
            self.cache = cache  # pragma: no cover

        self.get_cache_key = get_cache_key or (lambda x: x)

        self._cache = cache_map if cache_map is not None else {}
        self._queue = []  # type: List[Loader]

    @property
    def loop(self):
        if not self._loop:
            self._loop = get_event_loop()

        return self._loop

    def load(self, key=None):
        """
        Loads a key, returning a `Future` for the value represented by that key.
        """
        if key is None:
            raise TypeError(  # pragma: no cover
                (
                    "The loader.load() function must be called with a value, "
                    "but got: {}."
                ).format(key)
            )

        cache_key = self.get_cache_key(key)

        # If caching and there is a cache-hit, return the cached Future.
        if self.cache:
            cached_result = self._cache.get(cache_key)
            if cached_result:
                return cached_result

        # Otherwise, produce a new Future for this value.
        future = self.loop.create_future()
        # If caching, cache this Future.
        if self.cache:
            self._cache[cache_key] = future

        self.do_resolve_reject(key, future)
        return future

    def do_resolve_reject(self, key, future):
        # Enqueue this Future to be dispatched.
        self._queue.append(Loader(key=key, future=future))
        # Determine if a dispatch of this queue should be scheduled.
        # A single dispatch should be scheduled per queue at the time when the
        # queue changes from empty to non-empty.
        if len(self._queue) == 1:
            if self.batch:
                # If batching, schedule a task to dispatch the queue.
                enqueue_post_future_job(self.loop, self)
            else:
                # Otherwise dispatch the (queue of one) immediately.
                dispatch_queue(self)  # pragma: no cover

    def load_many(self, keys):
        """
        Loads multiple keys, returning a list of values.

        >>> a, b = await my_loader.load_many([ 'a', 'b' ])

        This is equivalent to the more verbose:

        >>> a, b = await gather(
        >>>     my_loader.load('a'),
        >>>     my_loader.load('b')
        >>> )
        """
        if not isinstance(keys, Iterable):
            raise TypeError(  # pragma: no cover
                (
                    "The loader.load_many() function must be called with Iterable<key> "
                    "but got: {}."
                ).format(keys)
            )

        return gather(*[self.load(key) for key in keys])

    def clear(self, key):
        """
        Clears the value at `key` from the cache, if it exists. Returns itself for
        method chaining.
        """
        cache_key = self.get_cache_key(key)
        self._cache.pop(cache_key, None)
        return self

    def clear_all(self):
        """
        Clears the entire cache. To be used when some event results in unknown
        invalidations across this particular `DataLoader`. Returns itself for
        method chaining.
        """
        self._cache.clear()
        return self

    def prime(self, key, value):
        """
        Adds the provided key and value to the cache. If the key already exists, no
        change is made. Returns itself for method chaining.
        """
        cache_key = self.get_cache_key(key)

        # Only add the key if it does not already exist.
        if cache_key not in self._cache:
            # Cache a rejected future if the value is an Error, in order to match
            # the behavior of load(key).
            future = self.loop.create_future()
            if isinstance(value, Exception):
                future.set_exception(value)
            else:
                future.set_result(value)

            self._cache[cache_key] = future

        return self


def enqueue_post_future_job(loop, loader):
    # Schedule the queue dispatch for a later event-loop iteration so that
    # loads requested in the meantime can join the same batch.
    async def dispatch():
        dispatch_queue(loader)

    loop.call_soon(ensure_future, dispatch())


def get_chunks(iterable_obj, chunk_size=1):
    # Split a sequence into consecutive chunks of at most chunk_size items.
    chunk_size = max(1, chunk_size)
    return (
        iterable_obj[i : i + chunk_size]
        for i in range(0, len(iterable_obj), chunk_size)
    )


def dispatch_queue(loader):
    """
    Given the current state of a Loader instance, perform a batch load
    from its current queue.
    """
    # Take the current loader queue, replacing it with an empty queue.
    queue = loader._queue
    loader._queue = []

    # If a max_batch_size was provided and the queue is longer, then segment the
    # queue into multiple batches, otherwise treat the queue as a single batch.
    max_batch_size = loader.max_batch_size

    if max_batch_size and max_batch_size < len(queue):
        chunks = get_chunks(queue, max_batch_size)
        for chunk in chunks:
            ensure_future(dispatch_queue_batch(loader, chunk))
    else:
        ensure_future(dispatch_queue_batch(loader, queue))


async def dispatch_queue_batch(loader, queue):
    # Collect all keys to be loaded in this dispatch.
    keys = [loaded.key for loaded in queue]

    # Call the provided batch_load_fn for this loader with the loader queue's keys.
    batch_future = loader.batch_load_fn(keys)

    # Assert the expected response from batch_load_fn.
    if not batch_future or not iscoroutine(batch_future):
        return failed_dispatch(  # pragma: no cover
            loader,
            queue,
            TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the function did "
                    "not return a Coroutine: {}."
                ).format(batch_future)
            ),
        )

    try:
        values = await batch_future
        if not isinstance(values, Iterable):
            raise TypeError(  # pragma: no cover
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the function did "
                    "not return a Future of an Iterable: {}."
                ).format(values)
            )

        values = list(values)
        if len(values) != len(keys):
            raise TypeError(  # pragma: no cover
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the function did "
                    "not return a Future of an Iterable with the same length as the Iterable "
                    "of keys."
                    "\n\nKeys:\n{}"
                    "\n\nValues:\n{}"
                ).format(keys, values)
            )

        # Step through the values, resolving or rejecting each Future in the
        # loaded queue.
        for loaded, value in zip(queue, values):
            if isinstance(value, Exception):
                loaded.future.set_exception(value)
            else:
                loaded.future.set_result(value)

    except Exception as e:
        return failed_dispatch(loader, queue, e)


def failed_dispatch(loader, queue, error):
    """
    Do not cache individual loads if the entire batch dispatch fails,
    but still reject each request so they do not hang.
    """
    for loaded in queue:
        loader.clear(loaded.key)
        loaded.future.set_exception(error)
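

# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal example of how this DataLoader might be used, assuming Python 3.7+
# for asyncio.run(). The batch_load_users coroutine and user_loader name are
# hypothetical; a batch_load_fn must return one value per key, in key order.
if __name__ == "__main__":
    from asyncio import run

    async def batch_load_users(keys):
        # A real implementation would issue a single backend query for all
        # keys; here we simply fabricate one value per key.
        return ["user-{}".format(key) for key in keys]

    async def main():
        user_loader = DataLoader(batch_load_fn=batch_load_users)
        # Both loads below are collected on the same event-loop tick and
        # dispatched to batch_load_users as a single call.
        a, b = await user_loader.load_many([1, 2])
        print(a, b)  # user-1 user-2
        # A repeated load of key 1 is served from the per-loader cache.
        print(await user_loader.load(1))  # user-1

    run(main())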