77from __future__ import annotations
88
99import warnings
10- from collections .abc import Callable
10+ from collections import deque
11+ from collections .abc import Callable , Generator
1112from dataclasses import dataclass
12- from typing import Any , ClassVar , Literal , Protocol , TypeAlias , TypeVar
13+ from threading import Condition , Lock , Thread
14+ from typing import Any , ClassVar , Deque , Literal , Protocol , TypeAlias , TypeVar
1315
1416import numpy as np
1517import pandas as pd
1618from numpy .typing import NDArray
1719
1820from lenskit .diagnostics import DataWarning
19- from lenskit .logging import Tracer , get_logger , get_tracer
21+ from lenskit .logging import Tracer , get_logger , get_tracer , trace
2022from lenskit .random import RNGInput , random_generator
2123
2224F = TypeVar ("F" , bound = np .floating , covariant = True )
@@ -139,6 +141,7 @@ class _BLBootstrapper:
139141 r_window : int
140142 b_factor : float
141143 rng : np .random .Generator
144+ _rep_generator : ReplicateGenerator
142145
143146 def __init__ (
144147 self ,
@@ -165,7 +168,10 @@ def __init__(
165168 self ._tracer = get_tracer (_log , stat = stat .__name__ ) # type: ignore
166169
167170 def run_bootstraps (self , xs : NDArray [F ]) -> _BootResult :
168- self ._tracer .add_bindings (n = len (xs ))
171+ n = len (xs )
172+ b = int (n ** self .b_factor )
173+
174+ self ._tracer .add_bindings (n = n , b = b )
169175 _log .debug ("starting bootstrap" , stat = self .statistic .__name__ , n = len (xs )) # type: ignore
170176 ss_frames = {}
171177
@@ -174,17 +180,23 @@ def run_bootstraps(self, xs: NDArray[F]) -> _BootResult:
174180 lbs = StatAccum (np .mean )
175181 ubs = StatAccum (np .mean )
176182
177- for i , ss in enumerate (self .blb_subsets (xs )):
178- self ._tracer .add_bindings (subset = i )
179- self ._tracer .trace ("starting subset" )
180- res = self .measure_subset (xs , ss )
181- ss_frames [i ] = res .samples
182- means .record (res .rep_mean )
183- vars .record (res .rep_var )
184- lbs .record (res .ci_lower )
185- ubs .record (res .ci_upper )
186- if self ._check_convergence (means , vars , lbs , ubs , tol = self .tolerance , w = self .s_window ):
187- break
183+ self ._rep_generator = ReplicateGenerator (n , b , self .rng )
184+ self ._tracer .trace ("let's go!" )
185+
186+ with self ._rep_generator :
187+ for i , ss in enumerate (self .blb_subsets (n , b )):
188+ self ._tracer .add_bindings (subset = i )
189+ self ._tracer .trace ("starting subset" )
190+ res = self .measure_subset (xs , ss )
191+ ss_frames [i ] = res .samples
192+ means .record (res .rep_mean )
193+ vars .record (res .rep_var )
194+ lbs .record (res .ci_lower )
195+ ubs .record (res .ci_upper )
196+ if self ._check_convergence (
197+ means , vars , lbs , ubs , tol = self .tolerance , w = self .s_window
198+ ):
199+ break
188200
189201 return _BootResult (
190202 means .statistic ,
@@ -194,12 +206,9 @@ def run_bootstraps(self, xs: NDArray[F]) -> _BootResult:
194206 pd .concat (ss_frames , names = ["subset" ]),
195207 )
196208
197- def blb_subsets (self , xs : NDArray [F ]):
198- b = int (len (xs ) ** self .b_factor )
199- self ._tracer .add_bindings (b = b )
200-
209+ def blb_subsets (self , n : int , b : int ):
201210 while True :
202- yield self .rng .choice (len ( xs ) , b , replace = False )
211+ yield self .rng .choice (n , b , replace = False )
203212
204213 def measure_subset (self , xs : NDArray [F ], ss : NDArray [np .int64 ]) -> _BootResult :
205214 b = len (ss )
@@ -211,7 +220,8 @@ def measure_subset(self, xs: NDArray[F], ss: NDArray[np.int64]) -> _BootResult:
211220 lbs = StatAccum (lambda a : np .quantile (a , self ._ci_qmin ))
212221 ubs = StatAccum (lambda a : np .quantile (a , self ._ci_qmax ))
213222
214- for i , weights in enumerate (self .miniboot_weights (n , b )):
223+ loop = self ._rep_generator .subsets ()
224+ for i , weights in enumerate (loop ):
215225 self ._tracer .add_bindings (rep = i )
216226 self ._tracer .trace ("starting replicate" )
217227 assert weights .shape == (b ,)
@@ -223,7 +233,7 @@ def measure_subset(self, xs: NDArray[F], ss: NDArray[np.int64]) -> _BootResult:
223233 ubs .record (stat )
224234
225235 if self ._check_convergence (means , vars , lbs , ubs , tol = self .tolerance , w = self .r_window ):
226- break
236+ loop . close ()
227237
228238 df = pd .DataFrame ({"statistic" : means .values })
229239 df .index .name = "iter"
@@ -251,6 +261,75 @@ def _check_convergence(self, *arrays: StatAccum, tol: float, w: int) -> bool:
251261 return np .all (gaps < tol ).item ()
252262
253263
264+ class ReplicateGenerator :
265+ """
266+ Generate the subset samples for a bootstrap in a background thread.
267+ """
268+
269+ n : int
270+ b : int
271+
272+ _rng : np .random .Generator
273+ _flat : NDArray [np .float64 ]
274+ _lock : Lock
275+ _notify : Condition
276+ _running : bool = True
277+ _queue : Deque
278+ _thread : Thread
279+
280+ def __init__ (self , n : int , b : int , rng : np .random .Generator ):
281+ self .n = n
282+ self .b = b
283+ self ._rng = rng .spawn (1 )[0 ]
284+ self ._queue = deque ()
285+ self ._flat = np .full (b , 1.0 / b )
286+ self ._lock = Lock ()
287+ self ._notify = Condition (self ._lock )
288+
289+ def subsets (self ) -> Generator [NDArray [np .int64 ], None , None ]:
290+ while True :
291+ with self ._notify :
292+ while self ._thread .is_alive () and len (self ._queue ) == 0 :
293+ self ._notify .wait ()
294+
295+ try :
296+ val = self ._queue .popleft ()
297+ self ._notify .notify_all ()
298+ except IndexError :
299+ break # things have shut down, loop is over
300+ except GeneratorExit :
301+ break # we've been asked to close
302+
303+ yield val
304+
305+ def _generate (self ):
306+ with self ._notify :
307+ while True :
308+ # check if we need to wake up
309+ while self ._running and len (self ._queue ) >= 5 :
310+ trace (_log , "waiting for queue" , len = len (self ._queue ))
311+ self ._notify .wait ()
312+
313+ # are we done?
314+ if not self ._running :
315+ break
316+
317+ # generate a new value
318+ val = self ._rng .multinomial (self .n , self ._flat )
319+ self ._queue .append (val )
320+ self ._notify .notify_all ()
321+
322+ def __enter__ (self ):
323+ self ._thread = Thread (target = self ._generate )
324+ self ._thread .start ()
325+ return self
326+
327+ def __exit__ (self , * args : Any ):
328+ with self ._notify :
329+ self ._running = False
330+ self ._notify .notify_all ()
331+
332+
254333class StatAccum :
255334 INIT_SIZE : ClassVar [int ] = 100
256335
0 commit comments