forked from haonan-yuan/RAG-GFM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcentrality_utils.py
More file actions
167 lines (133 loc) · 5.79 KB
/
centrality_utils.py
File metadata and controls
167 lines (133 loc) · 5.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import torch
import torch.nn.functional as F
import math
from torch_geometric.utils import scatter, to_dense_adj
from torch_geometric.utils.num_nodes import maybe_num_nodes
import time
from tqdm import tqdm
import os
import hashlib
import pickle
from typing import Optional, Dict, Any
def _get_cache_dir():
cache_dir = "cse_cache"
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
return cache_dir
def _generate_cache_key(edge_index, num_nodes, ksteps, k, normalize_cse):
key_data = {
'edge_index_shape': edge_index.shape,
'edge_index_hash': hashlib.md5(edge_index.cpu().numpy().tobytes()).hexdigest(),
'num_nodes': num_nodes,
'ksteps': sorted(ksteps) if ksteps else None,
'k': k,
'normalize_cse': normalize_cse
}
key_str = str(sorted(key_data.items()))
return hashlib.md5(key_str.encode()).hexdigest()
def _get_cache_path(cache_key):
    """Map a cache key to its pickle file path inside the CSE cache directory."""
    return os.path.join(_get_cache_dir(), f"cse_{cache_key}.pkl")
def _load_cse_cache(cache_key):
    """Load previously cached CSE results for *cache_key*.

    Returns:
        The unpickled results dict, or None when no cache file exists or the
        file cannot be read/unpickled (best-effort cache: errors are logged,
        never raised).
    """
    cache_path = _get_cache_path(cache_key)
    if not os.path.exists(cache_path):
        return None
    try:
        with open(cache_path, 'rb') as fh:
            cached_data = pickle.load(fh)
        print(f" - Loaded CSE features from cache: {cache_path}")
        return cached_data
    except Exception as e:
        # A corrupt or unreadable cache file just forces recomputation.
        print(f" - Failed to load CSE features from cache: {e}")
        return None
def _save_cse_cache(cache_key, results):
    """Persist a CSE results dict to disk under *cache_key* (best-effort).

    Failures are logged and swallowed so caching problems never break the
    main computation path.
    """
    cache_path = _get_cache_path(cache_key)
    try:
        with open(cache_path, 'wb') as fh:
            pickle.dump(results, fh)
    except Exception as e:
        print(f" - Failed to save CSE features to cache: {e}")
    else:
        print(f" - CSE features saved to cache: {cache_path}")
def compute_walk_based_centrality(ksteps, edge_index, edge_weight=None, num_nodes=None, space_dim=0):
    """Compute k-step random-walk return probabilities (diagonal of P^k) per node.

    Args:
        ksteps: iterable of walk lengths k to extract. Intermediate powers
            between min(ksteps) and max(ksteps) are computed but only the
            requested steps are kept.
        edge_index: (2, E) COO edge index tensor.
        edge_weight: optional (E,) edge weights; defaults to all-ones.
        num_nodes: optional node count; inferred from edge_index when None.
        space_dim: accepted for interface compatibility but currently unused —
            TODO(review): confirm whether a k**(space_dim/2) rescaling was intended.

    Returns:
        Tensor of shape (num_nodes, len(set(ksteps))): per-node probability of
        returning to the start node after exactly k steps, columns in
        ascending k order.
    """
    if edge_weight is None:
        edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    print(f" - Building adjacency matrix ({num_nodes}×{num_nodes})...")
    P = to_dense_adj(edge_index, max_num_nodes=num_nodes, edge_attr=edge_weight)[0]
    print(f" - Normalizing transition matrix...")
    # Row-normalize into a transition matrix; isolated nodes (degree 0) would
    # give inf from 0**-1, so their rows are zeroed instead.
    deg_inv = P.sum(dim=1).pow(-1.)
    deg_inv.masked_fill_(deg_inv == float('inf'), 0)
    P = P * deg_inv.unsqueeze(1)
    wanted_steps = set(ksteps)
    rws = []
    print(f" - Computing matrix power (steps: {min(ksteps)}-{max(ksteps)})...")
    try:
        Pk = torch.matrix_power(P, min(ksteps))
    except TypeError:
        # Fallback for torch builds where matrix_power rejects this input.
        Pk = P.clone()
        for _ in range(min(ksteps) - 1):
            Pk = torch.matmul(Pk, P)
    k_range = range(min(ksteps), max(ksteps) + 1)
    for k in tqdm(k_range, desc=" - Computing walk probabilities", unit="step"):
        # BUG FIX: keep only the steps actually requested. The original
        # appended the diagonal of every intermediate power, so a
        # non-contiguous ksteps (e.g. [2, 4, 8]) produced extra columns,
        # contradicting the documented (num_nodes, len(ksteps)) shape.
        if k in wanted_steps:
            rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1).unsqueeze(1))
        if k < max(ksteps):
            Pk = torch.matmul(Pk, P)
    rw_landing = torch.cat(rws, dim=1)  # (num_nodes, len(set(ksteps)))
    return rw_landing
def select_top_k_nodes(centralities, k):
    """Rank nodes by total centrality and select the top-k.

    Args:
        centralities: (num_nodes, num_steps) per-node centrality features;
            rows are summed to get a scalar score per node.
        k: number of nodes to select (clamped to num_nodes).

    Returns:
        order: (num_nodes,) node indices sorted by descending total centrality.
        top_k_indices: first min(k, num_nodes) entries of `order`.
        top_k_mask: (num_nodes,) bool mask, True at selected nodes.
    """
    num_nodes = centralities.size(0)
    total_centrality = centralities.sum(dim=1)
    order = torch.argsort(total_centrality, descending=True)
    # Clamp so callers may pass k larger than the graph.
    num_to_select = min(num_nodes, k)
    top_k_indices = order[:num_to_select]
    # NOTE: the original also built a one-hot `coloring` matrix here that was
    # never returned or used — dead O(num_nodes * k) work, removed.
    top_k_mask = torch.zeros(num_nodes, dtype=torch.bool, device=centralities.device)
    top_k_mask[top_k_indices] = True
    return order, top_k_indices, top_k_mask
def extract_cse_encodings(centralities, normalize=True):
    """Turn raw centrality features into CSE encodings.

    Args:
        centralities: (num_nodes, num_steps) centrality matrix.
        normalize: when True, z-score each column (per-step feature) across
            nodes; near-constant columns (std < 1e-8) are left unscaled to
            avoid division blow-up.

    Returns:
        Tensor of the same shape as `centralities`.
    """
    if not normalize:
        return centralities
    mean = centralities.mean(dim=0, keepdim=True)
    std = centralities.std(dim=0, keepdim=True)
    # Guard degenerate columns: substitute std=1 so they map to ~0, not inf.
    safe_std = torch.where(std < 1e-8, torch.ones_like(std), std)
    return (centralities - mean) / safe_std
def compute_centrality_and_cse(edge_index, ksteps=None, k=None, edge_weight=None,
                               num_nodes=None, space_dim=0, normalize_cse=True,
                               use_cache=True):
    """End-to-end CSE pipeline: centralities, top-k ranking, and encodings.

    Args:
        edge_index: (2, E) edge index tensor.
        ksteps: walk lengths; defaults to 1..16.
        k: top-k size; defaults to 10% of num_nodes, or 100 if num_nodes is None.
        edge_weight: optional edge weights forwarded to the walk computation.
        num_nodes: optional node count.
        space_dim: forwarded to compute_walk_based_centrality (see there).
        normalize_cse: z-score the encodings when True.
        use_cache: when True, try an on-disk cache keyed by all inputs before
            computing, and store the results afterwards.

    Returns:
        dict with keys 'centralities', 'centrality_ranking', 'top_k_indices',
        'top_k_mask', and 'cse_encodings'.
    """
    ksteps = list(range(1, 17)) if ksteps is None else ksteps
    if k is None:
        k = int(0.1 * num_nodes) if num_nodes else 100

    cache_key = None
    if use_cache:
        cache_key = _generate_cache_key(edge_index, num_nodes, ksteps, k, normalize_cse)
        cached_results = _load_cse_cache(cache_key)
        if cached_results is not None:
            return cached_results
        print(f" - Cache miss, starting to compute CSE features...")

    centralities = compute_walk_based_centrality(
        ksteps=ksteps, edge_index=edge_index, edge_weight=edge_weight,
        num_nodes=num_nodes, space_dim=space_dim
    )
    centrality_ranking, top_k_indices, top_k_mask = select_top_k_nodes(centralities, k)
    cse_encodings = extract_cse_encodings(centralities, normalize=normalize_cse)

    results = {
        'centralities': centralities,
        'centrality_ranking': centrality_ranking,
        'top_k_indices': top_k_indices,
        'top_k_mask': top_k_mask,
        'cse_encodings': cse_encodings,
    }
    if use_cache:
        _save_cse_cache(cache_key, results)
    return results
def integrate_with_pyg_data(data, ksteps=None, k=None, normalize_cse=True, use_cache=True):
    """Attach CSE features to a PyG Data object in place and return it.

    Runs compute_centrality_and_cse on data.edge_index (using data.edge_weight
    when present) and stores 'cse_encodings', 'centrality_ranking', and
    'top_k_indices' as attributes on `data`. The 'top_k_mask' result is not
    attached, matching the existing callers' expectations.
    """
    edge_weight = getattr(data, 'edge_weight', None)
    results = compute_centrality_and_cse(
        edge_index=data.edge_index, ksteps=ksteps, k=k,
        edge_weight=edge_weight,
        num_nodes=data.num_nodes, normalize_cse=normalize_cse, use_cache=use_cache
    )
    for attr in ('cse_encodings', 'centrality_ranking', 'top_k_indices'):
        setattr(data, attr, results[attr])
    return data