11"""Warning checker."""
2- from typing import Dict , List , Sequence
2+ from collections import defaultdict
3+ from itertools import groupby
4+ from typing import Dict , List , NamedTuple , Sequence
35from warnings import warn
46
57from .spy_events import (
810 SpyEvent ,
911 VerifyRehearsal ,
1012 WhenRehearsal ,
13+ SpyRehearsal ,
1114 match_event ,
1215)
1316from .warnings import DecoyWarning , MiscalledStubWarning , RedundantVerifyWarning
@@ -23,44 +26,62 @@ def check(all_calls: Sequence[AnySpyEvent]) -> None:
2326 _check_no_redundant_verify (all_calls )
2427
2528
29+ class _Call (NamedTuple ):
30+ event : SpyEvent
31+ all_rehearsals : List [SpyRehearsal ]
32+ matching_rehearsals : List [SpyRehearsal ]
33+
34+
2635def _check_no_miscalled_stubs (all_events : Sequence [AnySpyEvent ]) -> None :
2736 """Ensure every call matches a rehearsal, if the spy has rehearsals."""
28- all_calls_by_id : Dict [int , List [AnySpyEvent ]] = {}
37+ all_events_by_id : Dict [int , List [AnySpyEvent ]] = defaultdict (list )
38+ all_calls_by_id : Dict [int , List [_Call ]] = defaultdict (list )
2939
3040 for event in all_events :
31- if isinstance (event .payload , SpyCall ):
32- spy_id = event .spy .id
33- spy_calls = all_calls_by_id .get (spy_id , [])
34- all_calls_by_id [spy_id ] = [* spy_calls , event ]
41+ all_events_by_id [event .spy .id ].append (event )
42+
43+ for events in all_events_by_id .values ():
44+ for index , event in enumerate (events ):
45+ if isinstance (event , SpyEvent ) and isinstance (event .payload , SpyCall ):
46+ when_rehearsals = [
47+ rehearsal
48+ for rehearsal in events [0 :index ]
49+ if isinstance (rehearsal , WhenRehearsal )
50+ and isinstance (rehearsal .payload , SpyCall )
51+ ]
52+ verify_rehearsals = [
53+ rehearsal
54+ for rehearsal in events [index + 1 :]
55+ if isinstance (rehearsal , VerifyRehearsal )
56+ and isinstance (rehearsal .payload , SpyCall )
57+ ]
58+
59+ all_rehearsals : List [SpyRehearsal ] = [
60+ * when_rehearsals ,
61+ * verify_rehearsals ,
62+ ]
63+ matching_rehearsals = [
64+ rehearsal
65+ for rehearsal in all_rehearsals
66+ if match_event (event , rehearsal )
67+ ]
68+
69+ all_calls_by_id [event .spy .id ].append (
70+ _Call (event , all_rehearsals , matching_rehearsals )
71+ )
3572
3673 for spy_calls in all_calls_by_id .values ():
37- unmatched : List [SpyEvent ] = []
38-
39- for index , call in enumerate (spy_calls ):
40- past_stubs = [
41- wr for wr in spy_calls [0 :index ] if isinstance (wr , WhenRehearsal )
42- ]
43-
44- matched_past_stubs = [wr for wr in past_stubs if match_event (call , wr )]
45-
46- matched_future_verifies = [
47- vr
48- for vr in spy_calls [index + 1 :]
49- if isinstance (vr , VerifyRehearsal ) and match_event (call , vr )
50- ]
51-
52- if (
53- isinstance (call , SpyEvent )
54- and len (past_stubs ) > 0
55- and len (matched_past_stubs ) == 0
56- and len (matched_future_verifies ) == 0
57- ):
58- unmatched = [* unmatched , call ]
59- if index == len (spy_calls ) - 1 :
60- _warn (MiscalledStubWarning (calls = unmatched , rehearsals = past_stubs ))
61- elif isinstance (call , WhenRehearsal ) and len (unmatched ) > 0 :
62- _warn (MiscalledStubWarning (calls = unmatched , rehearsals = past_stubs ))
63- unmatched = []
74+ for rehearsals , grouped_calls in groupby (spy_calls , lambda c : c .all_rehearsals ):
75+ calls = list (grouped_calls )
76+ is_stubbed = any (isinstance (r , WhenRehearsal ) for r in rehearsals )
77+
78+ if is_stubbed and all (len (c .matching_rehearsals ) == 0 for c in calls ):
79+ _warn (
80+ MiscalledStubWarning (
81+ calls = [c .event for c in calls ],
82+ rehearsals = rehearsals ,
83+ )
84+ )
6485
6586
6687def _check_no_redundant_verify (all_calls : Sequence [AnySpyEvent ]) -> None :
0 commit comments