-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathanalysis.py
More file actions
executable file
·125 lines (112 loc) · 5.27 KB
/
Copy pathanalysis.py
File metadata and controls
executable file
·125 lines (112 loc) · 5.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import math
import operator
from functools import reduce
import torch
from settings import BATCH_SIZE
def sdc(baseline, target, over_approximate=False, sdc_injection_ids=None):
    """Estimate the silent-data-corruption (SDC) rate of ``target`` vs ``baseline``.

    An image counts as corrupted when the baseline top-1 prediction equals the
    label but the target top-1 prediction does not. The rate is the number of
    corrupted images divided by the number of images the baseline got right.

    Args:
        baseline: iterable of result dicts accepted by ``merge`` (each with
            ``'top5'`` and ``'label'`` tensors).
        target: sequence of result dicts, each with a ``'top5'`` tensor and a
            ``['config']['injection']`` entry. Target images are compared with
            baseline images by position; the baseline is cycled when the target
            covers more images than it does.
        over_approximate: if true, count at least one corrupted image so the
            estimated rate is never exactly zero.
        sdc_injection_ids: optional list; when given, for every corrupted image
            (in image order) the injection of the target entry that produced it
            is appended.

    Returns:
        ``(rate, error)`` where ``error`` is the half-width of a 95% normal-
        approximation confidence interval, or ``(nan, nan)`` when the baseline
        classified no image correctly.

    Raises:
        ValueError: if ``target`` is empty or ``baseline`` contains no images.
    """
    if not target:
        raise ValueError("target must contain at least one result")
    baseline_top1, label = merge(baseline)
    if baseline_top1 is None or len(baseline_top1) == 0:
        raise ValueError("baseline must contain at least one image")

    top1_chunks = []
    image_injections = []  # injection of the entry that produced each target image
    for e in target:
        top1 = e['top5'].T[0]
        top1_chunks.append(top1)
        # Track the injection per image instead of assuming every entry holds
        # exactly BATCH_SIZE images (a short final batch would break that).
        image_injections.extend([e['config']['injection']] * len(top1))
    target_top1 = torch.cat(top1_chunks, dim=0)

    # Cycle the baseline to cover every target image, then trim to the same
    # length so the element-wise comparisons below have matching shapes.
    total = len(target_top1)
    reps = -(-total // len(baseline_top1))  # ceiling division
    baseline_top1 = baseline_top1.repeat(reps)[:total]
    label = label.repeat(reps)[:total]

    correct = label == target_top1
    base_correct = label == baseline_top1
    corrupted = torch.logical_and(torch.logical_not(correct), base_correct)

    if sdc_injection_ids is not None:
        for image_id in torch.nonzero(corrupted).flatten().tolist():
            sdc_injection_ids.append(image_injections[image_id])

    base_total = torch.sum(base_correct)
    if base_total == 0:
        # The rate is undefined; avoid math.sqrt on inf when over-approximating.
        return float('nan'), float('nan')
    observed = torch.sum(corrupted)
    if over_approximate:
        observed = torch.clamp(observed, min=1)
    rate = observed / base_total

    # NOTE(review): n counts target entries (one injection each), not images,
    # so each entry is treated as one trial in the interval -- confirm intended.
    n = len(target)
    z = 1.96  # 95% confidence interval
    error = z * math.sqrt(rate * (1 - rate) / n)
    return float(rate), error
def elapsed_time(baseline, target):
    """Return the summed ``'elapsed_time'`` of all target entries.

    ``baseline`` is accepted only so the signature matches the other metric
    functions; it is not used. The second element of the returned pair is an
    error term, always 0 here.
    """
    total = 0
    for entry in target:
        total += entry['elapsed_time']
    return total, 0
# Baseline image positions evaluated by sc_detection_hit_rate by default.
# TODO(review): document how this subset was chosen.
_SC_DETECTION_IMAGE_INDICES = (
    4, 10, 14, 16, 23, 27, 39, 51, 53, 64, 68, 109, 111, 120, 124, 131, 139,
    143, 162, 215, 236, 242, 276, 284, 303, 332, 374, 384, 397, 405, 408, 413,
    419, 420, 423, 424, 431, 432, 447, 448, 462, 466, 485, 502, 503, 511, 532,
    536, 538, 540, 563, 581, 621, 662, 673, 677, 690, 693, 701, 733, 767, 774,
    784, 789, 806, 808, 828, 851, 872, 877, 885, 907, 912, 915, 928, 929, 934,
    948, 966, 998,
)


def sc_detection_hit_rate(baseline, target, indices=None):
    """Count target batches with an SDC and how many of those were detected.

    A batch has an SDC when at least one image is classified correctly by the
    baseline but incorrectly by the target. It counts as detected when some
    non-None entry of its ``'detection'`` list is entirely truthy.

    Args:
        baseline: iterable of result dicts accepted by ``merge``.
        target: iterable of result dicts, each with ``'batch'``,
            ``'batch_size'``, ``'label'``, ``'top5'`` and ``'detection'``.
        indices: baseline image positions to compare against; defaults to
            ``_SC_DETECTION_IMAGE_INDICES``. Target image ``k`` (its global
            offset ``batch * batch_size + i``) is matched with subset position
            ``k % len(indices)``.

    Returns:
        ``((detected_sdc_batches, sdc_batches), 0)``; the trailing 0 is an
        error term, kept for parity with the other metric functions.
    """
    if indices is None:
        indices = _SC_DETECTION_IMAGE_INDICES
    baseline_top1, _ = merge(baseline)
    subset_top1 = baseline_top1[list(indices)]
    subset_size = len(subset_top1)

    sdc_batches = 0
    detected_sdc_batches = 0
    for e in target:
        batch_label = e['label']
        start = int(e['batch']) * int(e['batch_size'])
        # Modular positions wrap around the subset instead of truncating the
        # chunk, and use the real batch length (the last batch may be short).
        positions = torch.arange(
            start, start + len(batch_label), device=subset_top1.device
        ) % subset_size
        baseline_chunk = subset_top1[positions]
        target_top1 = e['top5'].T[0]
        corrupted = torch.logical_and(baseline_chunk == batch_label,
                                      target_top1 != batch_label)
        if not torch.any(corrupted):
            continue
        sdc_batches += 1
        # Entries may be None (as handled in detection()); treat those as
        # "not detected" rather than crashing in all().
        if any(d is not None and all(d) for d in e['detection']):
            detected_sdc_batches += 1
    return (detected_sdc_batches, sdc_batches), 0
def detection(baseline, target, term='sdc'):
    """Compute the SDC rate or the detected-corruption rate of ``target``.

    An image is corrupted when the baseline top-1 prediction equals the label
    but the target's does not. ``'sdc'`` is corrupted images divided by
    baseline-correct images. ``'due'`` is corrupted images that were also
    flagged by detection, divided by baseline-correct images.

    Args:
        baseline: iterable of result dicts accepted by ``merge``.
        target: sequence of result dicts, each with a ``'top5'`` tensor and a
            ``'detection'`` list. Only the first non-None item of that list is
            used; each element ``channel`` in it flags image ``channel[0]`` of
            the batch.
        term: which rate to return, ``'sdc'`` or ``'due'``.

    Returns:
        ``(rate, error)`` with ``error`` the half-width of a 95% normal-
        approximation confidence interval (nan when nothing was baseline-correct).

    Raises:
        ValueError: for an unknown ``term``, an empty ``target`` or a baseline
            with no images.
    """
    # Explicit lookup instead of eval() on a caller-supplied string.
    if term not in ('sdc', 'due'):
        raise ValueError(f"term must be 'sdc' or 'due', got {term!r}")
    if not target:
        raise ValueError("target must contain at least one result")
    baseline_top1, label = merge(baseline)
    if baseline_top1 is None or len(baseline_top1) == 0:
        raise ValueError("baseline must contain at least one image")

    top1_chunks = []
    count_chunks = []
    for e in target:
        top1 = e['top5'].T[0]
        detected_channels = next((d for d in e['detection'] if d is not None), None)
        counts = torch.zeros(top1.shape, device=top1.device)
        if detected_channels is not None:
            for channel in detected_channels:
                counts[channel[0]] += 1
        top1_chunks.append(top1)
        count_chunks.append(counts)
    target_top1 = torch.cat(top1_chunks, dim=0)
    detection_counts = torch.cat(count_chunks, dim=0)

    # Cycle the baseline to cover every target image, then trim to the same
    # length so the element-wise comparisons below have matching shapes.
    total = len(target_top1)
    reps = -(-total // len(baseline_top1))  # ceiling division
    baseline_top1 = baseline_top1.repeat(reps)[:total]
    label = label.repeat(reps)[:total]

    correct = label == target_top1
    base_correct = label == baseline_top1
    corrupted = torch.logical_and(torch.logical_not(correct), base_correct)
    detected_corrupted = torch.logical_and(detection_counts > 0, corrupted)

    base_total = torch.sum(base_correct)
    rates = {
        'sdc': torch.sum(corrupted) / base_total,
        'due': torch.sum(detected_corrupted) / base_total,
    }
    rate = rates[term]

    # NOTE(review): n counts target entries, not images -- confirm intended.
    n = len(target)
    z = 1.96  # 95% confidence interval
    error = z * math.sqrt(rate * (1 - rate) / n)
    return float(rate), error
def merge(baseline):
    """Concatenate per-batch results into flat top-1 and label tensors.

    Args:
        baseline: iterable of result dicts, each with a ``'top5'`` tensor
            (assumed 2-D, one row per image; column 0 is taken as the top-1
            prediction) and a ``'label'`` tensor.

    Returns:
        ``(top1, label)`` concatenated along dim 0 in iteration order, or
        ``(None, None)`` when ``baseline`` yields no entries.
    """
    top1_chunks = []
    label_chunks = []
    for e in baseline:
        top1_chunks.append(e['top5'].T[0])
        label_chunks.append(e['label'])
    if not top1_chunks:
        return None, None
    # A single cat avoids re-copying the accumulated tensor on every iteration.
    return torch.cat(top1_chunks, dim=0), torch.cat(label_chunks, dim=0)