@@ -148,7 +148,7 @@ def allocate_batches_test(
148148
149149
150150# %%
151- def parse_config () -> tuple [Precision , str , str , int ]:
151+ def parse_config () -> tuple [Precision , str , str , int , bool ]:
152152 """Parse command-line arguments or return defaults."""
153153 parser = argparse .ArgumentParser (
154154 description = "Compute EKFAC ground truth for testing"
@@ -181,6 +181,12 @@ def parse_config() -> tuple[Precision, str, str, int]:
181181 default = 1 ,
182182 help = "Number of workers for simulated distributed computation (default: 1)" ,
183183 )
184+ parser .add_argument (
185+ "--overwrite" ,
186+ action = "store_true" ,
187+ default = False ,
188+ help = "Overwrite existing ground truth data and config" ,
189+ )
184190
185191 # For interactive mode (Jupyter/IPython) or no args, use defaults
186192 if len (sys .argv ) > 1 and not hasattr (builtins , "__IPYTHON__" ):
@@ -191,11 +197,11 @@ def parse_config() -> tuple[Precision, str, str, int]:
191197 # Set random seeds for reproducibility
192198 set_all_seeds (42 )
193199
194- return args .precision , args .output_dir , args .model_name , args .world_size
200+ return args .precision , args .output_dir , args .model_name , args .world_size , args . overwrite
195201
196202
197203if __name__ == "__main__" or TYPE_CHECKING :
198- precision , test_path , model_name , world_size_arg = parse_config ()
204+ precision , test_path , model_name , world_size_arg , overwrite_arg = parse_config ()
199205
200206
201207# %%
@@ -204,6 +210,7 @@ def setup_paths_and_config(
204210 test_path : str ,
205211 model_name : str ,
206212 world_size : int ,
213+ overwrite : bool = False ,
207214) -> tuple [IndexConfig , int , torch .device , Any , torch .dtype ]:
208215 """Setup paths and configuration object."""
209216 os .makedirs (test_path , exist_ok = True )
@@ -240,9 +247,37 @@ def setup_paths_and_config(
240247 subset .save_to_disk (data_str )
241248 print (f"Generated pile-100 in { data_str } " )
242249
243- # Save config
244- with open (os .path .join (test_path , "index_config.json" ), "w" ) as f :
245- json .dump (asdict (cfg ), f , indent = 4 )
250+ config_path = os .path .join (test_path , "index_config.json" )
251+ if os .path .exists (config_path ):
252+ if not overwrite :
253+ # Load existing config and compare
254+ with open (config_path , "r" ) as f :
255+ existing_cfg_dict = json .load (f )
256+
257+ new_cfg_dict = asdict (cfg )
258+
259+ if existing_cfg_dict != new_cfg_dict :
260+ # Show differences for debugging
261+ diffs = [
262+ f" { k } : { existing_cfg_dict [k ]} != { new_cfg_dict [k ]} "
263+ for k in new_cfg_dict
264+ if k in existing_cfg_dict and existing_cfg_dict [k ] != new_cfg_dict [k ]
265+ ]
266+ raise RuntimeError (
267+ f"Existing config at { config_path } differs from requested config:\n "
268+ + "\n " .join (diffs )
269+ + "\n \n Use --overwrite to replace the existing config."
270+ )
271+
272+ print (f"Using existing config from { config_path } " )
273+ else :
274+ print (f"Overwriting existing config at { config_path } " )
275+ with open (config_path , "w" ) as f :
276+ json .dump (asdict (cfg ), f , indent = 4 )
277+ else :
278+ # Save new config
279+ with open (config_path , "w" ) as f :
280+ json .dump (asdict (cfg ), f , indent = 4 )
246281
247282 # Setup
248283 workers = world_size
@@ -271,7 +306,7 @@ def setup_paths_and_config(
271306
272307if __name__ == "__main__" or TYPE_CHECKING :
273308 cfg , workers , device , target_modules , dtype = setup_paths_and_config (
274- precision , test_path , model_name , world_size_arg
309+ precision , test_path , model_name , world_size_arg , overwrite_arg
275310 )
276311
277312
0 commit comments