Skip to content

Commit df1ef3d

Browse files
committed
Add pre-commit as a dev dependency and run it
Ran "uv pre-commit run --all-files" which reads from .pre-commit-config.yaml Unfortunately pre-commit does not respect tool settings in pyproject.toml, so right now there's conflicting informations in pyproject.toml and .pre-commit-config.yaml and so different settings and tool versions used depending on how we run tools.
1 parent ed73c98 commit df1ef3d

File tree

5 files changed

+66
-53
lines changed

5 files changed

+66
-53
lines changed

bergson/hessians/data_filtering_ekfac.ipynb

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,17 @@
1313
"metadata": {},
1414
"outputs": [],
1515
"source": [
16+
"import json\n",
1617
"import os\n",
17-
"from typing import Literal\n",
1818
"\n",
1919
"import matplotlib.pyplot as plt\n",
2020
"import numpy as np\n",
2121
"import pandas as pd\n",
2222
"import torch\n",
23-
"from datasets import load_dataset\n",
24-
"from tqdm.notebook import tqdm\n",
25-
"\n",
26-
"from bergson.data import load_gradients\n",
2723
"from safetensors.torch import load_file\n",
28-
"\n",
29-
"import numpy as np\n",
30-
"import matplotlib.pyplot as plt\n",
3124
"from scipy.stats import spearmanr\n",
3225
"\n",
33-
"import json"
26+
"from bergson.data import load_gradients"
3427
]
3528
},
3629
{
@@ -373,8 +366,6 @@
373366
],
374367
"source": [
375368
"import numpy as np\n",
376-
"import matplotlib.pyplot as plt\n",
377-
"from scipy.stats import spearmanr\n",
378369
"\n",
379370
"# Calculate Spearman correlation\n",
380371
"mask = ~(np.isnan(attributions_scores) | np.isnan(attributions_ekfac_scores))\n",
@@ -461,9 +452,9 @@
461452
")\n",
462453
"\n",
463454
"\n",
464-
"plt.plot(np.array(top_percentages) * len(index), intersections, label=f\"Query without ekfac\")\n",
465-
"plt.plot(np.array(top_percentages) * len(index), intersections_ekfac, label=f\"Query with ekfac\")\n",
466-
"plt.plot(np.array(top_percentages) * len(index), intersections_random, label=f\"Random baseline\")\n",
455+
"plt.plot(np.array(top_percentages) * len(index), intersections, label=\"Query without ekfac\")\n",
456+
"plt.plot(np.array(top_percentages) * len(index), intersections_ekfac, label=\"Query with ekfac\")\n",
457+
"plt.plot(np.array(top_percentages) * len(index), intersections_random, label=\"Random baseline\")\n",
467458
"plt.xlabel(\"Number of elements removed\")\n",
468459
"plt.ylabel('Number of elements in the \"correct\" half')\n",
469460
"plt.title(\"EK-FAC, no attn, on train set\")\n",
@@ -638,7 +629,6 @@
638629
"# load the saved attributions\n",
639630
"import json\n",
640631
"\n",
641-
"\n",
642632
"all_attributions = {}\n",
643633
"\n",
644634
"for path in all_query_paths:\n",
@@ -849,7 +839,7 @@
849839
],
850840
"source": [
851841
"# plot intersection\n",
852-
"plt.plot(np.array(top_percentages) * len(index), intersection_12, label=f\"Intersection\")\n",
842+
"plt.plot(np.array(top_percentages) * len(index), intersection_12, label=\"Intersection\")\n",
853843
"plt.plot(\n",
854844
" [0, len(index) // 2, len(index)],\n",
855845
" [0, len(index) // 2, len(index)],\n",
@@ -1218,9 +1208,9 @@
12181208
")\n",
12191209
"\n",
12201210
"\n",
1221-
"plt.plot(np.array(top_percentages) * len(index), intersections, label=f\"Query without ekfac\")\n",
1222-
"plt.plot(np.array(top_percentages) * len(index), intersections_ekfac, label=f\"Query with ekfac\")\n",
1223-
"plt.plot(np.array(top_percentages) * len(index), intersections_random, label=f\"Random baseline\")\n",
1211+
"plt.plot(np.array(top_percentages) * len(index), intersections, label=\"Query without ekfac\")\n",
1212+
"plt.plot(np.array(top_percentages) * len(index), intersections_ekfac, label=\"Query with ekfac\")\n",
1213+
"plt.plot(np.array(top_percentages) * len(index), intersections_random, label=\"Random baseline\")\n",
12241214
"plt.xlabel(\"Number of elements removed\")\n",
12251215
"plt.ylabel('Number of elements in the \"correct\" half')\n",
12261216
"plt.legend()\n",
@@ -1374,9 +1364,10 @@
13741364
}
13751365
],
13761366
"source": [
1377-
"import torch\n",
13781367
"import os\n",
13791368
"\n",
1369+
"import torch\n",
1370+
"\n",
13801371
"# Set the debug flag - this is the correct way\n",
13811372
"os.environ[\"TORCH_COMPILE_DEBUG\"] = \"1\"\n",
13821373
"\n",

bergson/hessians/misaligned_datasets.ipynb

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,16 @@
1313
"metadata": {},
1414
"outputs": [],
1515
"source": [
16+
"import json\n",
1617
"import os\n",
17-
"from typing import Literal\n",
18-
"import joblib\n",
18+
"\n",
1919
"import matplotlib.pyplot as plt\n",
2020
"import numpy as np\n",
2121
"import pandas as pd\n",
2222
"import torch\n",
23-
"from datasets import load_dataset\n",
24-
"from tqdm.notebook import tqdm\n",
25-
"import json\n",
26-
"from datasets import Dataset\n",
27-
"from bergson.data import load_gradients\n",
28-
"from safetensors.torch import load_file\n",
29-
"from sklearn.metrics import roc_auc_score\n",
30-
"import numpy as np\n",
31-
"\n",
32-
"from sklearn.metrics import roc_auc_score, precision_recall_curve, auc\n"
23+
"from datasets import Dataset, load_dataset\n",
24+
"from sklearn.metrics import auc, precision_recall_curve, roc_auc_score\n",
25+
"\n"
3326
]
3427
},
3528
{
@@ -439,14 +432,14 @@
439432
"print(f\"PR AUC: {pr_auc:.4f}\")\n",
440433
"\n",
441434
"# Additional metrics for analysis\n",
442-
"print(f\"\\nDataset composition:\")\n",
435+
"print(\"\\nDataset composition:\")\n",
443436
"print(f\"Correct examples: {len(sorted_correct_scores)}\")\n",
444437
"print(f\"Incorrect examples: {len(sorted_incorrect_scores)}\")\n",
445438
"print(f\"Subtle incorrect examples: {len(sorted_subtle_scores)}\")\n",
446439
"print(f\"Total examples: {len(all_scores)}\")\n",
447440
"print(f\"Problematic ratio: {(len(sorted_incorrect_scores) + len(sorted_subtle_scores)) / len(all_scores):.3f}\")\n",
448441
"\n",
449-
"print(f\"\\nScore statistics:\")\n",
442+
"print(\"\\nScore statistics:\")\n",
450443
"print(f\"Correct scores - Mean: {sorted_correct_scores.mean():.4f}, Std: {sorted_correct_scores.std():.4f}\")\n",
451444
"print(f\"Incorrect scores - Mean: {sorted_incorrect_scores.mean():.4f}, Std: {sorted_incorrect_scores.std():.4f}\")\n",
452445
"print(f\"Subtle scores - Mean: {sorted_subtle_scores.mean():.4f}, Std: {sorted_subtle_scores.std():.4f}\")\n"

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ line-length = 120
5656

5757
[dependency-groups]
5858
dev = [
59+
"pre-commit>=4.2.0",
60+
"pre-commit-uv>=4.1.5",
5961
"pyright>=1.1.406",
6062
"pytest>=8.4.2",
6163
]

tests/ekfac_tests/apply_ekfac_ground_truth.ipynb

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,35 +23,19 @@
2323
"metadata": {},
2424
"outputs": [],
2525
"source": [
26-
"import gc\n",
2726
"import hashlib\n",
2827
"import json\n",
2928
"import os\n",
30-
"import random\n",
31-
"from contextlib import nullcontext\n",
32-
"from typing import Literal, Optional\n",
29+
"from typing import Literal\n",
3330
"\n",
34-
"import numpy as np\n",
3531
"import torch\n",
36-
"import torch.distributed as dist\n",
37-
"import torch.nn.functional as F\n",
3832
"from datasets import Dataset\n",
39-
"from jaxtyping import Float\n",
40-
"from safetensors import safe_open\n",
4133
"from safetensors.torch import load_file, save_file\n",
4234
"from torch import Tensor\n",
4335
"\n",
44-
"from tqdm.auto import tqdm\n",
45-
"from transformers import PreTrainedModel\n",
46-
"\n",
4736
"from bergson.collection import collect_gradients\n",
48-
"from bergson.data import DataConfig, IndexConfig, create_index, load_gradients, pad_and_tensor\n",
49-
"from bergson.distributed import distributed_computing, setup_data_pipeline\n",
50-
"from bergson.gradients import (\n",
51-
" GradientProcessor,\n",
52-
")\n",
53-
"from bergson.hessians.collector import EkfacCollector\n",
54-
"from bergson.hessians.logger import get_logger"
37+
"from bergson.data import DataConfig, IndexConfig, load_gradients\n",
38+
"from bergson.distributed import distributed_computing, setup_data_pipeline"
5539
]
5640
},
5741
{

uv.lock

Lines changed: 43 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)