33# Licensed under the MIT license as detailed in LICENSE.txt
44
55import asyncio
6- import time
76from pathlib import Path
87from unittest import mock
98
1211
1312from aiolimiter import AsyncLimiter
1413
15- WAIT_LIMIT = 2 # seconds before we declare the test failed
14+ # max number of wait_for rounds when waiting for all events to settle
15+ MAX_WAIT_FOR_ITER = 5
1616
1717
1818def test_version ():
@@ -32,14 +32,16 @@ def test_version():
3232
3333async def wait_for_n_done (tasks , n ):
3434 """Wait for n (or more) tasks to have completed"""
35- start = time .time ()
36- pending = tasks
35+ iteration = 0
3736 remainder = len (tasks ) - n
38- while time .time () <= start + WAIT_LIMIT and len (pending ) > remainder :
37+ while iteration <= MAX_WAIT_FOR_ITER :
38+ iteration += 1
3939 _ , pending = await asyncio .wait (
40- tasks , timeout = WAIT_LIMIT , return_when = asyncio .FIRST_COMPLETED
40+ tasks , timeout = 0 , return_when = asyncio .FIRST_COMPLETED
4141 )
42- assert len (pending ) >= remainder
42+ if len (pending ) <= remainder :
43+ break
44+ assert len (pending ) <= remainder
4345 return pending
4446
4547
@@ -73,33 +75,61 @@ async def async_contextmanager_task(limiter):
7375 pass
7476
7577
76- @pytest .mark .parametrize ("task" , [acquire_task , async_contextmanager_task ])
77- async def test_acquire (task ):
78- current_time = 0
78+ class MockLoopTime :
79+ def __init__ (self ):
80+ self .current_time = 0
81+ event_loop = asyncio .get_running_loop ()
82+ self .patch = mock .patch .object (event_loop , "time" , self .mocked_time )
83+
84+ def mocked_time (self ):
85+ return self .current_time
7986
80- def mocked_time ():
81- return current_time
87+ def __enter__ (self ):
88+ self .patch .start ()
89+ return self
8290
91+ def __exit__ (self , * _ ):
92+ self .patch .stop ()
93+
94+
95+ @pytest .mark .parametrize ("task" , [acquire_task , async_contextmanager_task ])
96+ async def test_acquire (task ):
8397 # capacity released every 2 seconds
8498 limiter = AsyncLimiter (5 , 10 )
8599
86- event_loop = asyncio .get_running_loop ()
87- with mock .patch .object (event_loop , "time" , mocked_time ):
100+ with MockLoopTime () as mocked_time :
88101 tasks = [asyncio .ensure_future (task (limiter )) for _ in range (10 )]
89102
90103 pending = await wait_for_n_done (tasks , 5 )
91104 assert len (pending ) == 5
92105
93- current_time = 3 # releases capacity for one and some buffer
106+ mocked_time . current_time = 3 # releases capacity for one and some buffer
94107 assert limiter .has_capacity ()
95108
96109 pending = await wait_for_n_done (pending , 1 )
97110 assert len (pending ) == 4
98111
99- current_time = 7 # releases capacity for two more, plus buffer
112+ mocked_time . current_time = 7 # releases capacity for two more, plus buffer
100113 pending = await wait_for_n_done (pending , 2 )
101114 assert len (pending ) == 2
102115
103- current_time = 11 # releases the remainder
116+ mocked_time . current_time = 11 # releases the remainder
104117 pending = await wait_for_n_done (pending , 2 )
105118 assert len (pending ) == 0
119+
120+
121+ async def test_acquire_wait_time ():
122+ limiter = AsyncLimiter (3 , 3 )
123+
124+ with MockLoopTime () as mocked_time :
125+ # Fill the bucket with an amount of 1
126+ await limiter .acquire (1 )
127+
128+ # Acquiring an amount of 3 now should take 1 second
129+ task = asyncio .ensure_future (limiter .acquire (3 ))
130+ pending = await wait_for_n_done ([task ], 0 )
131+ assert pending
132+
133+ mocked_time .current_time = 1
134+ pending = await wait_for_n_done ([task ], 1 )
135+ assert not pending
0 commit comments