From 19d9f8ce830adee5ad4401e7ecb4da57e460a706 Mon Sep 17 00:00:00 2001 From: Yinzuo Jiang Date: Fri, 17 May 2024 15:42:48 +0800 Subject: [PATCH 1/2] fix: convert X to numpy.array in glove dataset function --- ann_benchmarks/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ann_benchmarks/datasets.py b/ann_benchmarks/datasets.py index 21e6efb60..590fbb4fb 100644 --- a/ann_benchmarks/datasets.py +++ b/ann_benchmarks/datasets.py @@ -213,7 +213,7 @@ def glove(out_fn: str, d: int) -> None: for line in z.open(z_fn): v = [float(x) for x in line.strip().split()[1:]] X.append(numpy.array(v)) - X_train, X_test = train_test_split(X) + X_train, X_test = train_test_split(numpy.array(X)) write_output(numpy.array(X_train), numpy.array(X_test), out_fn, "angular") From 3982de0c267c9e3de65044b5c55f143917df7023 Mon Sep 17 00:00:00 2001 From: Yinzuo Jiang Date: Fri, 17 May 2024 15:45:03 +0800 Subject: [PATCH 2/2] add `--count` and `--batch` args for data_export.py --- ann_benchmarks/results.py | 4 ++-- data_export.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/ann_benchmarks/results.py b/ann_benchmarks/results.py index e7e5b2e7f..b2e288f0d 100644 --- a/ann_benchmarks/results.py +++ b/ann_benchmarks/results.py @@ -84,7 +84,7 @@ def load_all_results(dataset: Optional[str] = None, Yields: tuple: A tuple containing properties as a dictionary and an h5py file object. """ - for root, _, files in os.walk(build_result_filepath(dataset, count)): + for root, _, files in os.walk(build_result_filepath(dataset, count, batch_mode=batch_mode)): for filename in files: if os.path.splitext(filename)[-1] != ".hdf5": continue @@ -110,4 +110,4 @@ def get_unique_algorithms() -> Set[str]: for batch_mode in [False, True]: for properties, _ in load_all_results(batch_mode=batch_mode): algorithms.add(properties["algo"]) - return algorithms \ No newline at end of file + return algorithms diff --git a/data_export.py b/data_export.py index 343f3acc3..416472a84 100644 --- a/data_export.py +++ b/data_export.py @@ -9,14 +9,21 @@ parser = argparse.ArgumentParser() parser.add_argument("--output", help="Path to the output file", required=True) parser.add_argument("--recompute", action="store_true", help="Recompute metrics") + parser.add_argument( + "-k", "--count", default=10, type=int, help="The number of near neighbours to search for" + ) + parser.add_argument("--batch", action="store_true", help="Batch mode") args = parser.parse_args() datasets = DATASETS.keys() dfs = [] for dataset_name in datasets: print("Looking at dataset", dataset_name) - if len(list(load_all_results(dataset_name))) > 0: - results = load_all_results(dataset_name) + if len(list(load_all_results(dataset_name, + count=args.count, + batch_mode=args.batch + ))) > 0: + results = load_all_results(dataset_name, count=args.count, batch_mode=args.batch) dataset, _ = get_dataset(dataset_name) results = compute_metrics_all_runs(dataset, results, args.recompute) for res in results: