Skip to content

Commit 069684e

Browse files
McPatateCopilot
andauthored
feat(ci): add continuous batching to benchmarks (#41916)
* feat(ci): add continuous batching to benchmarks * refactor(ci): PR comments * refactor(cb): when stopping, block by default * fix(benchmarks): `stream` -> `streaming` * fix(benchmarks): invalid configuration when cb has attn_impl == sdpa * tests(cb): fix attn impl * fix(benchmarks): update `get_throughput` formula * fix(benchmarks): prevent version conflicts and ensure proper cleanup in continuous batching (#42063) * Initial plan * fix(benchmarks): ensure proper cleanup and remove transformers from requirements - Remove transformers from benchmark_v2/requirements.txt to prevent version conflicts - Add try-finally block to ensure ContinuousBatchingManager.stop() is always called - This fixes TypeError about unexpected 'streaming' argument and prevents OOM from improper cleanup Co-authored-by: McPatate <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: McPatate <[email protected]> * fix(benchmarks): raise the exception on failure instead of ignoring we catch the exception later on and raising it here helps debugging because it will be logged * test(cb): comment out failing tests for now added a `FIXME` mark * fix(benchmarks): revert `finally` removal but keep raising exception * test(cb): fix missing `require_read_token` import * refactor(benchmarks): error if no benchmarks were run * refactor(benchmarks): change default lvls of cb bench config --------- Co-authored-by: Copilot <[email protected]> Co-authored-by: McPatate <[email protected]>
1 parent a127710 commit 069684e

File tree

10 files changed

+190
-107
lines changed

10 files changed

+190
-107
lines changed

.github/workflows/benchmark.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,16 @@ jobs:
3232
options: --gpus all --privileged --ipc host
3333
steps:
3434
- name: Get repo
35-
uses: actions/checkout@v4
35+
uses: actions/checkout@v5
3636
with:
37-
ref: ${{ github.event.pull_request.head.sha || github.sha }}
37+
fetch-depth: 1
3838

3939
- name: Install benchmark script dependencies
4040
run: python3 -m pip install -r benchmark_v2/requirements.txt kernels
4141

4242
- name: Reinstall transformers in edit mode (remove the one installed during docker image build)
4343
working-directory: /transformers
44-
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e ".[torch]" && python3 -m pip uninstall -y torchvision # temp fix
44+
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e ".[torch]"
4545

4646
- name: Run benchmark
4747
run: |

benchmark/requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
gpustat==1.1.1
22
psutil==6.0.0
33
psycopg2==2.9.9
4-
torch>=2.4.0
54
hf_xet
6-
pandas>=1.5.0
5+
pandas>=1.5.0

benchmark_v2/framework/benchmark_config.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(
3636
warmup_iterations: int = 5,
3737
measurement_iterations: int = 20,
3838
gpu_monitoring: bool = True, # NOTE: you may want to disable this at times as we have obsvered it could heavily slow down benchmarks on AMD
39+
continuous_batching: bool = False,
3940
batch_size: int = 1,
4041
sequence_length: int = 128,
4142
num_tokens_to_generate: int = 128,
@@ -51,6 +52,7 @@ def __init__(
5152
self.warmup_iterations = warmup_iterations
5253
self.measurement_iterations = measurement_iterations
5354
self.gpu_monitoring = gpu_monitoring
55+
self.continuous_batching = continuous_batching
5456
# Input parameters
5557
self.batch_size = batch_size
5658
self.sequence_length = sequence_length
@@ -85,6 +87,22 @@ def check_validity(self, skip_validity_check: bool = False) -> None:
8587
if is_fa:
8688
logger.warning("Flash attention does not support compile mode. Turning off compile mode.")
8789
self.compile_mode = None
90+
# Handle SDPA backend if not determined by the config (needs to be done before skipping duplicates)
91+
if self.attn_implementation == "sdpa" and self.sdpa_backend is None:
92+
default_backend = "flash_attention" # FIXME: torch has a _cur_sdpa_kernel_backends but it fails
93+
logger.warning(f"No SDPA backend provided, using {default_backend} instead.")
94+
self.sdpa_backend = default_backend
95+
if self.continuous_batching:
96+
if self.attn_implementation == "flex_attention":
97+
logger.error(
98+
"disabling continuous batching because of invalid configuration: flex attention is not supported"
99+
)
100+
self.continuous_batching = False
101+
elif self.attn_implementation == "sdpa" and self.sdpa_backend is not None:
102+
logger.warning(
103+
"when continuous batching is enabled, sdpa_backend must be None because of the attention mask, setting it to None"
104+
)
105+
self.sdpa_backend = "math"
88106

89107
@property
90108
def hash(self) -> str:
@@ -100,6 +118,7 @@ def infer_name(self, compact: bool = True) -> str:
100118
attn_code += f"_{self.sdpa_backend}" if self.attn_implementation == "sdpa" else ""
101119
compile_str = f"compiled_{self.compile_mode}" if self.compile_mode is not None else "uncompiled"
102120
kernelize_str = "kernelized" if self.kernelize else "unkernelized"
121+
continuous_batching_str = "cb" if self.continuous_batching else "generate"
103122
sep = "-"
104123
else:
105124
iter_str = f"{self.warmup_iterations} warmup, {self.measurement_iterations} iterations"
@@ -109,15 +128,19 @@ def infer_name(self, compact: bool = True) -> str:
109128
attn_code += f" with {self.sdpa_backend} backend" if self.attn_implementation == "sdpa" else ""
110129
compile_str = "compiled" if self.compile_mode is not None else "not compiled"
111130
kernelize_str = "kernelized" if self.kernelize else "not kernelized"
131+
continuous_batching_str = "continuous batching" if self.continuous_batching else "regular generate"
112132
sep = ", "
113-
return sep.join([iter_str, gpu_monitor_str, dimensions_str, attn_code, compile_str, kernelize_str])
133+
return sep.join(
134+
[iter_str, gpu_monitor_str, dimensions_str, attn_code, compile_str, kernelize_str, continuous_batching_str]
135+
)
114136

115137
def to_dict(self) -> dict[str, Any]:
116138
return {
117139
"name": self.name,
118140
"warmup_iterations": self.warmup_iterations,
119141
"measurement_iterations": self.measurement_iterations,
120142
"gpu_monitoring": self.gpu_monitoring,
143+
"continuous_batching": self.continuous_batching,
121144
"batch_size": self.batch_size,
122145
"sequence_length": self.sequence_length,
123146
"num_tokens_to_generate": self.num_tokens_to_generate,
@@ -134,6 +157,7 @@ def from_dict(cls, data: dict[str, Any], skip_validity_check: bool = False) -> "
134157
warmup_iterations=data.get("warmup_iterations", 5),
135158
measurement_iterations=data.get("measurement_iterations", 20),
136159
gpu_monitoring=data.get("gpu_monitoring", False),
160+
continuous_batching=data.get("continuous_batching", False),
137161
batch_size=data.get("batch_size", 1),
138162
sequence_length=data.get("sequence_length", 128),
139163
num_tokens_to_generate=data.get("num_tokens_to_generate", 128),
@@ -191,24 +215,28 @@ def get_config_by_level(level: int) -> list[BenchmarkConfig]:
191215
# Usually there is not much to gain by compiling with other modes, but we allow it for level 4
192216
compile_modes = BenchmarkConfig.all_compiled_modes if level >= 4 else [None, "default"]
193217
for cm in compile_modes:
194-
for kernelize_on in [False, KERNELIZATION_AVAILABLE]:
195-
configs.append(
196-
BenchmarkConfig(
197-
attn_implementation=attn_implementation,
198-
sdpa_backend=sdpa_backend,
199-
compile_mode=cm,
200-
kernelize=kernelize_on,
218+
for kernelize_on in {False, KERNELIZATION_AVAILABLE}:
219+
for cb_on in [False, True]:
220+
configs.append(
221+
BenchmarkConfig(
222+
attn_implementation=attn_implementation,
223+
sdpa_backend=sdpa_backend,
224+
compile_mode=cm,
225+
kernelize=kernelize_on,
226+
continuous_batching=cb_on,
227+
)
201228
)
202-
)
203229
return configs
204230
# Otherwise, we add the configs for the given level
205231
if level >= 0:
206232
configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default"))
207233
if level >= 1:
208234
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2"))
209235
configs.append(BenchmarkConfig(attn_implementation="eager", compile_mode="default"))
236+
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", continuous_batching=True))
210237
if level >= 2:
211238
configs.append(BenchmarkConfig(attn_implementation="sdpa", compile_mode="default"))
212239
configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", kernelize=True))
213240
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", kernelize=True))
241+
configs.append(BenchmarkConfig(attn_implementation="paged|sdpa", continuous_batching=True))
214242
return configs

benchmark_v2/framework/benchmark_runner.py

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,9 @@ def run_benchmark(
234234
self.logger.info(f"Running benchmark scenario: {config.name}")
235235

236236
# Quick validation: try one measurement first to see if this scenario works
237+
generate_fn = self.time_generate_batch if config.continuous_batching else self.time_generate
237238
flush_memory()
238-
e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = self.time_generate(
239+
e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = generate_fn(
239240
max_new_tokens=1, gpu_monitor=None
240241
)
241242
if e2e_latency < 0:
@@ -245,14 +246,14 @@ def run_benchmark(
245246
# Warmup runs
246247
self.logger.info(f"Warming up with {config.warmup_iterations} iterations...")
247248
for _ in trange(config.warmup_iterations):
248-
_ = self.time_generate(max_new_tokens=config.num_tokens_to_generate)
249+
_ = generate_fn(max_new_tokens=config.num_tokens_to_generate)
249250
self.logger.info("Warmup over.")
250251

251252
# Measurement runs
252253
result = BenchmarkResult()
253254
self.logger.info(f"Benchmarking with {config.measurement_iterations} iterations.")
254255
for _ in trange(config.measurement_iterations):
255-
e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = self.time_generate(
256+
e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = generate_fn(
256257
max_new_tokens=config.num_tokens_to_generate,
257258
gpu_monitor=(GPUMonitor(logger=self.logger) if config.gpu_monitoring else None),
258259
)
@@ -274,6 +275,58 @@ def run_benchmark(
274275
"config": config,
275276
}
276277

278+
# TODO: refactor `generate_batch` to handle streaming so we can use it here
279+
def time_generate_batch(
280+
self,
281+
max_new_tokens: int,
282+
gpu_monitor: GPUMonitor | None = None,
283+
) -> tuple[float, list[float], str, GPURawMetrics | None]:
284+
if gpu_monitor is not None:
285+
gpu_monitor.start()
286+
config = GenerationConfig(
287+
max_new_tokens=max_new_tokens,
288+
eos_token_id=self.tokenizer.eos_token_id,
289+
pad_token_id=self.tokenizer.pad_token_id,
290+
do_sample=True,
291+
)
292+
manager = self.model.init_continuous_batching(config)
293+
manager.start()
294+
try:
295+
first_req_results = []
296+
timestamps = []
297+
wall_time_0 = time.perf_counter()
298+
inputs = self.inputs["input_ids"].tolist()
299+
manager.add_requests(inputs, max_new_tokens=max_new_tokens, streaming=True)
300+
first_req_id = None
301+
num_requests = len(inputs)
302+
finished_requests = 0
303+
while finished_requests < num_requests:
304+
# NOTE: I don't like having the extra if stmt here, but hopefully won't degrade perf too much
305+
result = manager.get_result()
306+
if result:
307+
timestamps.append(time.perf_counter() - wall_time_0)
308+
if result.is_finished():
309+
finished_requests += 1
310+
if first_req_id is None:
311+
first_req_id = result.request_id
312+
if result.request_id == first_req_id:
313+
first_req_results.append(result)
314+
else:
315+
if not manager.is_running():
316+
raise RuntimeError("Generation thread exited unexpectedly")
317+
wall_time_1 = time.perf_counter()
318+
gpu_metrics = gpu_monitor.stop_and_collect() if gpu_monitor is not None else None
319+
decoded_output = self.tokenizer.decode(
320+
[res.generated_tokens[0] for res in first_req_results], skip_special_tokens=True
321+
)
322+
shape_and_decoded_output = f"{(1, len(first_req_results))} | {decoded_output}"
323+
e2e_latency = wall_time_1 - wall_time_0
324+
return e2e_latency, timestamps, shape_and_decoded_output, gpu_metrics
325+
except Exception as e:
326+
raise e
327+
finally:
328+
manager.stop()
329+
277330
def time_generate(
278331
self,
279332
max_new_tokens: int,
@@ -339,12 +392,6 @@ def run_benchmarks(
339392

340393
n_configs = len(benchmark_configs)
341394
for i, config in enumerate(benchmark_configs):
342-
# Handle SDPA backend if not determined by the config (needs to be done before skipping duplicates)
343-
if config.attn_implementation == "sdpa" and config.sdpa_backend is None:
344-
default_backend = "flash_attention" # FIXME: torch has a _cur_sdpa_kernel_backends but it fails
345-
self.logger.warning(f"No SDPA backend provided, using {default_backend} instead.")
346-
config.sdpa_backend = default_backend
347-
348395
# Skip if already run
349396
if config.hash in all_results:
350397
self.logger.info(f"Skipping duplicate config {config.name} for model {model_id} ({i + 1}/{n_configs})")
@@ -368,21 +415,27 @@ def run_benchmarks(
368415
self.cleanup()
369416
self.save_results(model_id, all_results, timestamp=timestamp)
370417

418+
if len(all_results) < 1:
419+
raise RuntimeError("No benchmark was run succesfully")
420+
371421
if pretty_print_summary:
372422
print()
373423
print("=" * 100)
374424
print(f"Finished benchmarks in {time.perf_counter() - start_time:.2f} seconds")
375425
print(f"Total number of benchmarks: {len(all_results)}")
376-
if len(all_results) > 0:
377-
print("First run metadata:")
378-
first_key = list(all_results.keys())[0]
379-
first_metadata = all_results[first_key]["metadata"].to_dict()
380-
hardware_info = first_metadata.pop("hardware_info")
381-
pretty_print_dict(first_metadata | hardware_info, tabs=1)
426+
print("First run metadata:")
427+
first_key = list(all_results.keys())[0]
428+
first_metadata = all_results[first_key]["metadata"].to_dict()
429+
hardware_info = first_metadata.pop("hardware_info")
430+
pretty_print_dict(first_metadata | hardware_info, tabs=1)
382431
for result in all_results.values():
383432
print("=" * 100)
384433
print(f"Config: {result['config'].infer_name(compact=False)}\n")
385-
result["measurements"].pprint(batch_size=result["config"].batch_size, tabs=1)
434+
result["measurements"].pprint(
435+
batch_size=result["config"].batch_size,
436+
num_generated_tokens=result["config"].num_tokens_to_generate,
437+
tabs=1,
438+
)
386439
print("=" * 100)
387440

388441
return (timestamp, all_results)

benchmark_v2/framework/data_classes.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,17 @@ def add_unit_to_duration(stats: dict[str, float]) -> dict[str, str]:
3636
return stats
3737

3838

39-
def equalize_lengths_and_collate(stats: list[dict[str, str]]) -> list[str]:
39+
def equalize_lengths_and_collate(stats: dict[str, dict[str, str]]) -> dict[str, str]:
40+
"""Note: This operation is destructive as it will update values in place before returning a new correctly formatted dict"""
4041
keys = ["avg", "std", "min", "med", "max", "p95"]
4142
for key in keys:
42-
max_length = max(len(stat[key]) for stat in stats)
43-
for stat in stats:
43+
max_length = max(len(stat[key]) for stat in stats.values())
44+
for stat in stats.values():
4445
stat[key] = stat[key].ljust(max_length, " ")
45-
return [" ".join([f"{key}={stat[key]}" for key in keys]) for stat in stats]
46+
return {name: " ".join([f"{key}={stat[key]}" for key in keys]) for name, stat in stats.items()}
4647

4748

48-
def pretty_print_dict(data: dict[str, Any], tabs: int = 0) -> None:
49+
def pretty_print_dict(data: dict[str, str], tabs: int = 0) -> None:
4950
max_key_length = max([len(key) for key in data.keys()])
5051
for key, value in data.items():
5152
tabs_str = " " * tabs
@@ -141,27 +142,19 @@ def get_measured_ttft(self) -> list[float]:
141142
def get_measured_itl(self) -> list[float]:
142143
return [(dt[-1] - dt[0]) / (len(dt) - 1) for dt in self.token_generation_times if len(dt) > 1]
143144

144-
def get_throughput(self, batch_size: int) -> float:
145-
return [
146-
batch_size * len(dt) / e2e_latency
147-
for e2e_latency, dt in zip(self.e2e_latency, self.token_generation_times)
148-
]
149-
150-
def pprint(self, batch_size: int = 0, tabs: int = 0) -> None:
151-
stats_to_collate = [
152-
add_unit_to_duration(compute_basic_statistics(self.e2e_latency)),
153-
add_unit_to_duration(compute_basic_statistics(self.get_measured_ttft())),
154-
add_unit_to_duration(compute_basic_statistics(self.get_measured_itl())),
155-
]
156-
if batch_size > 0:
157-
throughput_stats = compute_basic_statistics(self.get_throughput(batch_size))
158-
stats_to_collate.append({key: f"{value:.2f}tok/s" for key, value in throughput_stats.items()})
159-
collated_stats = equalize_lengths_and_collate(stats_to_collate)
160-
dict_to_pprint = {
161-
"E2E Latency": collated_stats[0],
162-
"Time to First Token": collated_stats[1],
163-
"Inter-Token Latency": collated_stats[2],
145+
def get_throughput(self, total_generated_tokens: int) -> list[float]:
146+
return [total_generated_tokens / e2e_latency for e2e_latency in self.e2e_latency]
147+
148+
def pprint(self, batch_size: int = 0, num_generated_tokens: int = 0, tabs: int = 0) -> None:
149+
measurements = {
150+
"E2E Latency": add_unit_to_duration(compute_basic_statistics(self.e2e_latency)),
151+
"Time to First Token": add_unit_to_duration(compute_basic_statistics(self.get_measured_ttft())),
164152
}
153+
itl_values = self.get_measured_itl()
154+
if len(itl_values) > 0:
155+
measurements["Inter-Token Latency"] = add_unit_to_duration(compute_basic_statistics(itl_values))
165156
if batch_size > 0:
166-
dict_to_pprint["Throughput"] = collated_stats[3]
157+
throughput_stats = compute_basic_statistics(self.get_throughput(batch_size * num_generated_tokens))
158+
measurements["Throughput"] = {key: f"{value:.2f}tok/s" for key, value in throughput_stats.items()}
159+
dict_to_pprint = equalize_lengths_and_collate(measurements)
167160
pretty_print_dict(dict_to_pprint, tabs=tabs)

benchmark_v2/requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,5 @@ numpy>=1.21.0
22
psutil>=5.8.0
33
gpustat>=1.0.0
44
torch>=2.0.0
5-
transformers>=4.30.0
65
datasets>=2.10.0
76
huggingface_hub>=0.16.0

benchmark_v2/run_benchmarks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@
8080
logger.info(f"Benchmark run UUID: {benchmark_run_uuid}")
8181
logger.info(f"Output directory: {args.output_dir}")
8282

83+
# We cannot compute ITL if we don't have at least two measurements
84+
if any(n <= 1 for n in args.num_tokens_to_generate):
85+
raise ValueError("--num_tokens_to_generate arguments should be larger than 1")
86+
8387
# Error out if one of the arguments is not provided
8488
if len(args.batch_size) * len(args.sequence_length) * len(args.num_tokens_to_generate) == 0:
8589
raise ValueError(

0 commit comments

Comments
 (0)