Skip to content

Commit 5d6ee74

Browse files
committed
Use background thread to generate BLB subsets
1 parent cdd4f0d commit 5d6ee74

File tree

1 file changed

+101
-22
lines changed

1 file changed

+101
-22
lines changed

src/lenskit/stats/_blb.py

Lines changed: 101 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,18 @@
77
from __future__ import annotations
88

99
import warnings
10-
from collections.abc import Callable
10+
from collections import deque
11+
from collections.abc import Callable, Generator
1112
from dataclasses import dataclass
12-
from typing import Any, ClassVar, Literal, Protocol, TypeAlias, TypeVar
13+
from threading import Condition, Lock, Thread
14+
from typing import Any, ClassVar, Deque, Literal, Protocol, TypeAlias, TypeVar
1315

1416
import numpy as np
1517
import pandas as pd
1618
from numpy.typing import NDArray
1719

1820
from lenskit.diagnostics import DataWarning
19-
from lenskit.logging import Tracer, get_logger, get_tracer
21+
from lenskit.logging import Tracer, get_logger, get_tracer, trace
2022
from lenskit.random import RNGInput, random_generator
2123

2224
F = TypeVar("F", bound=np.floating, covariant=True)
@@ -139,6 +141,7 @@ class _BLBootstrapper:
139141
r_window: int
140142
b_factor: float
141143
rng: np.random.Generator
144+
_rep_generator: ReplicateGenerator
142145

143146
def __init__(
144147
self,
@@ -165,7 +168,10 @@ def __init__(
165168
self._tracer = get_tracer(_log, stat=stat.__name__) # type: ignore
166169

167170
def run_bootstraps(self, xs: NDArray[F]) -> _BootResult:
168-
self._tracer.add_bindings(n=len(xs))
171+
n = len(xs)
172+
b = int(n**self.b_factor)
173+
174+
self._tracer.add_bindings(n=n, b=b)
169175
_log.debug("starting bootstrap", stat=self.statistic.__name__, n=len(xs)) # type: ignore
170176
ss_frames = {}
171177

@@ -174,17 +180,23 @@ def run_bootstraps(self, xs: NDArray[F]) -> _BootResult:
174180
lbs = StatAccum(np.mean)
175181
ubs = StatAccum(np.mean)
176182

177-
for i, ss in enumerate(self.blb_subsets(xs)):
178-
self._tracer.add_bindings(subset=i)
179-
self._tracer.trace("starting subset")
180-
res = self.measure_subset(xs, ss)
181-
ss_frames[i] = res.samples
182-
means.record(res.rep_mean)
183-
vars.record(res.rep_var)
184-
lbs.record(res.ci_lower)
185-
ubs.record(res.ci_upper)
186-
if self._check_convergence(means, vars, lbs, ubs, tol=self.tolerance, w=self.s_window):
187-
break
183+
self._rep_generator = ReplicateGenerator(n, b, self.rng)
184+
self._tracer.trace("let's go!")
185+
186+
with self._rep_generator:
187+
for i, ss in enumerate(self.blb_subsets(n, b)):
188+
self._tracer.add_bindings(subset=i)
189+
self._tracer.trace("starting subset")
190+
res = self.measure_subset(xs, ss)
191+
ss_frames[i] = res.samples
192+
means.record(res.rep_mean)
193+
vars.record(res.rep_var)
194+
lbs.record(res.ci_lower)
195+
ubs.record(res.ci_upper)
196+
if self._check_convergence(
197+
means, vars, lbs, ubs, tol=self.tolerance, w=self.s_window
198+
):
199+
break
188200

189201
return _BootResult(
190202
means.statistic,
@@ -194,12 +206,9 @@ def run_bootstraps(self, xs: NDArray[F]) -> _BootResult:
194206
pd.concat(ss_frames, names=["subset"]),
195207
)
196208

197-
def blb_subsets(self, xs: NDArray[F]):
198-
b = int(len(xs) ** self.b_factor)
199-
self._tracer.add_bindings(b=b)
200-
209+
def blb_subsets(self, n: int, b: int):
201210
while True:
202-
yield self.rng.choice(len(xs), b, replace=False)
211+
yield self.rng.choice(n, b, replace=False)
203212

204213
def measure_subset(self, xs: NDArray[F], ss: NDArray[np.int64]) -> _BootResult:
205214
b = len(ss)
@@ -211,7 +220,8 @@ def measure_subset(self, xs: NDArray[F], ss: NDArray[np.int64]) -> _BootResult:
211220
lbs = StatAccum(lambda a: np.quantile(a, self._ci_qmin))
212221
ubs = StatAccum(lambda a: np.quantile(a, self._ci_qmax))
213222

214-
for i, weights in enumerate(self.miniboot_weights(n, b)):
223+
loop = self._rep_generator.subsets()
224+
for i, weights in enumerate(loop):
215225
self._tracer.add_bindings(rep=i)
216226
self._tracer.trace("starting replicate")
217227
assert weights.shape == (b,)
@@ -223,7 +233,7 @@ def measure_subset(self, xs: NDArray[F], ss: NDArray[np.int64]) -> _BootResult:
223233
ubs.record(stat)
224234

225235
if self._check_convergence(means, vars, lbs, ubs, tol=self.tolerance, w=self.r_window):
226-
break
236+
loop.close()
227237

228238
df = pd.DataFrame({"statistic": means.values})
229239
df.index.name = "iter"
@@ -251,6 +261,75 @@ def _check_convergence(self, *arrays: StatAccum, tol: float, w: int) -> bool:
251261
return np.all(gaps < tol).item()
252262

253263

264+
class ReplicateGenerator:
265+
"""
266+
Generate the subset samples for a bootstrap in a background thread.
267+
"""
268+
269+
n: int
270+
b: int
271+
272+
_rng: np.random.Generator
273+
_flat: NDArray[np.float64]
274+
_lock: Lock
275+
_notify: Condition
276+
_running: bool = True
277+
_queue: Deque
278+
_thread: Thread
279+
280+
def __init__(self, n: int, b: int, rng: np.random.Generator):
281+
self.n = n
282+
self.b = b
283+
self._rng = rng.spawn(1)[0]
284+
self._queue = deque()
285+
self._flat = np.full(b, 1.0 / b)
286+
self._lock = Lock()
287+
self._notify = Condition(self._lock)
288+
289+
def subsets(self) -> Generator[NDArray[np.int64], None, None]:
290+
while True:
291+
with self._notify:
292+
while self._thread.is_alive() and len(self._queue) == 0:
293+
self._notify.wait()
294+
295+
try:
296+
val = self._queue.popleft()
297+
self._notify.notify_all()
298+
except IndexError:
299+
break # things have shut down, loop is over
300+
except GeneratorExit:
301+
break # we've been asked to close
302+
303+
yield val
304+
305+
def _generate(self):
306+
with self._notify:
307+
while True:
308+
# check if we need to wake up
309+
while self._running and len(self._queue) >= 5:
310+
trace(_log, "waiting for queue", len=len(self._queue))
311+
self._notify.wait()
312+
313+
# are we done?
314+
if not self._running:
315+
break
316+
317+
# generate a new value
318+
val = self._rng.multinomial(self.n, self._flat)
319+
self._queue.append(val)
320+
self._notify.notify_all()
321+
322+
def __enter__(self):
323+
self._thread = Thread(target=self._generate)
324+
self._thread.start()
325+
return self
326+
327+
def __exit__(self, *args: Any):
328+
with self._notify:
329+
self._running = False
330+
self._notify.notify_all()
331+
332+
254333
class StatAccum:
255334
INIT_SIZE: ClassVar[int] = 100
256335

0 commit comments

Comments
 (0)