Skip to content

Commit c61a2e1

Browse files
committed
Raise an error if existing ground truth would be overwritten with different config params
Overwriting is allowed by passing the --overwrite flag.
1 parent b6b51ae commit c61a2e1

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

tests/ekfac_tests/compute_ekfac_ground_truth.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def allocate_batches_test(
148148

149149

150150
# %%
151-
def parse_config() -> tuple[Precision, str, str, int]:
151+
def parse_config() -> tuple[Precision, str, str, int, bool]:
152152
"""Parse command-line arguments or return defaults."""
153153
parser = argparse.ArgumentParser(
154154
description="Compute EKFAC ground truth for testing"
@@ -181,6 +181,12 @@ def parse_config() -> tuple[Precision, str, str, int]:
181181
default=1,
182182
help="Number of workers for simulated distributed computation (default: 1)",
183183
)
184+
parser.add_argument(
185+
"--overwrite",
186+
action="store_true",
187+
default=False,
188+
help="Overwrite existing ground truth data and config",
189+
)
184190

185191
# For interactive mode (Jupyter/IPython) or no args, use defaults
186192
if len(sys.argv) > 1 and not hasattr(builtins, "__IPYTHON__"):
@@ -191,11 +197,11 @@ def parse_config() -> tuple[Precision, str, str, int]:
191197
# Set random seeds for reproducibility
192198
set_all_seeds(42)
193199

194-
return args.precision, args.output_dir, args.model_name, args.world_size
200+
return args.precision, args.output_dir, args.model_name, args.world_size, args.overwrite
195201

196202

197203
if __name__ == "__main__" or TYPE_CHECKING:
198-
precision, test_path, model_name, world_size_arg = parse_config()
204+
precision, test_path, model_name, world_size_arg, overwrite_arg = parse_config()
199205

200206

201207
# %%
@@ -204,6 +210,7 @@ def setup_paths_and_config(
204210
test_path: str,
205211
model_name: str,
206212
world_size: int,
213+
overwrite: bool = False,
207214
) -> tuple[IndexConfig, int, torch.device, Any, torch.dtype]:
208215
"""Setup paths and configuration object."""
209216
os.makedirs(test_path, exist_ok=True)
@@ -240,9 +247,37 @@ def setup_paths_and_config(
240247
subset.save_to_disk(data_str)
241248
print(f"Generated pile-100 in {data_str}")
242249

243-
# Save config
244-
with open(os.path.join(test_path, "index_config.json"), "w") as f:
245-
json.dump(asdict(cfg), f, indent=4)
250+
config_path = os.path.join(test_path, "index_config.json")
251+
if os.path.exists(config_path):
252+
if not overwrite:
253+
# Load existing config and compare
254+
with open(config_path, "r") as f:
255+
existing_cfg_dict = json.load(f)
256+
257+
new_cfg_dict = asdict(cfg)
258+
259+
if existing_cfg_dict != new_cfg_dict:
260+
# Show differences for debugging
261+
diffs = [
262+
f" {k}: {existing_cfg_dict[k]} != {new_cfg_dict[k]}"
263+
for k in new_cfg_dict
264+
if k in existing_cfg_dict and existing_cfg_dict[k] != new_cfg_dict[k]
265+
]
266+
raise RuntimeError(
267+
f"Existing config at {config_path} differs from requested config:\n"
268+
+ "\n".join(diffs)
269+
+ "\n\nUse --overwrite to replace the existing config."
270+
)
271+
272+
print(f"Using existing config from {config_path}")
273+
else:
274+
print(f"Overwriting existing config at {config_path}")
275+
with open(config_path, "w") as f:
276+
json.dump(asdict(cfg), f, indent=4)
277+
else:
278+
# Save new config
279+
with open(config_path, "w") as f:
280+
json.dump(asdict(cfg), f, indent=4)
246281

247282
# Setup
248283
workers = world_size
@@ -271,7 +306,7 @@ def setup_paths_and_config(
271306

272307
if __name__ == "__main__" or TYPE_CHECKING:
273308
cfg, workers, device, target_modules, dtype = setup_paths_and_config(
274-
precision, test_path, model_name, world_size_arg
309+
precision, test_path, model_name, world_size_arg, overwrite_arg
275310
)
276311

277312

tests/ekfac_tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def ground_truth_base_path(test_dir: str) -> str:
144144

145145

146146
@pytest.fixture(scope="session")
147-
def ground_truth_setup(request, test_dir: str, precision: Precision) -> dict[str, Any]:
147+
def ground_truth_setup(request, test_dir: str, precision: Precision, overwrite: bool) -> dict[str, Any]:
148148
set_all_seeds(seed=42)
149149

150150
# Setup for generation
@@ -163,6 +163,7 @@ def ground_truth_setup(request, test_dir: str, precision: Precision) -> dict[str
163163
test_path=ground_truth_base_path(test_dir),
164164
model_name=model_name,
165165
world_size=world_size,
166+
overwrite=overwrite,
166167
)
167168

168169
model = load_model_step(cfg, dtype)

0 commit comments

Comments
 (0)