diff --git a/README.md b/README.md index 5ba172b..e559c80 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ -# Sharp Frame Extractor +# Sharp Frame Extractor [![PyPI](https://img.shields.io/pypi/v/sharp-frame-extractor)](https://pypi.org/project/sharp-frame-extractor/) Sharp Frame Extractor is a command line utility for sampling videos into still images using sharpness scoring. It processes the input in short time windows and writes the highest scoring frame from each window to disk, which is useful for photogrammetry, volumetric capture, and similar pipelines. Version 2 focuses on: - A simpler command line interface - Better sharpness scoring -- Faster, smoother processing via improved parallelism +- Two-pass architecture for memory-efficient analysis and faster processing ## Example @@ -96,19 +96,26 @@ Outputs: ### Performance tuning -There are two layers of parallelism: +By default, the extractor automatically chooses performance settings based on the workload and the available hardware. The options below let you override those defaults when you want more direct control. -- `-j/--jobs` = how many videos are processed at the same time. Each job is mainly an orchestrator: it drives ffmpeg frame decoding and feeds blocks into the analysis pipeline. -- `--workers` = how many analysis workers run in parallel. Workers are separate processes that run the sharpness scoring. +There are three main tuning knobs, with two layers of parallelism: + +* `-j/--jobs` (`max_video_jobs`) = how many videos are processed at the same time. Each job mainly acts as an orchestrator: it drives frame decoding and hands blocks to the analysis stage. +* `-w/--workers` (`max_workers`) = how many analysis workers run in parallel. Workers are separate processes that perform the CPU intensive sharpness scoring and are shared across all jobs. +* `-m/--memory-limit` (`memory_limit_mb`) = the total memory budget for frame buffers. This limit is split across active jobs, so increasing `--jobs` reduces the buffer size available per video. How the pipeline behaves: -- A job processes a video block by block. -- Each block needs an available worker to be analyzed. -- If no worker is available, the job waits and does not keep decoding more blocks. + +* A job processes a video block by block. +* Each block needs an available worker to be analyzed. +* If no worker is available, the job waits and does not keep decoding more blocks. +* Frame buffering is bounded by the global memory limit, preventing unbounded memory growth when many jobs are active. Practical guidance: -- Processing a single video: keep `--jobs 1` and tune `--workers` (this usually controls total throughput). -- Processing many videos: pick a sensible `--workers` value first (often around your CPU core count), then increase `--jobs` until the workers stay busy. If CPU is already pegged, raising `--jobs` will mostly add overhead without speeding things up. + +* Processing a single video: keep `--jobs 1` and tune `--workers`. This usually controls total throughput. +* Processing many videos: pick a sensible `--workers` value first, often close to your CPU core count, then increase `--jobs` until the workers stay busy. If the CPU is already fully utilized, increasing `--jobs` will mostly add overhead without speeding things up. +* If you run many jobs at once and see increased waiting or reduced throughput, consider raising the memory limit so each job has enough buffering. 
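+
+As a rough starting point, an invocation on a 16-core machine might look like this (the values are illustrative, not recommendations; pick them based on your core count and RAM):
+
+```bash
+# 4 concurrent video jobs, 16 shared analysis workers,
+# and a ~16 GB global budget for frame buffers (split across the active jobs).
+sharp-frame-extractor *.mp4 --count 200 --jobs 4 --workers 16 --memory-limit 16384
+```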
Example: @@ -133,28 +140,28 @@ sharp-frame-extractor --help ``` ```text -usage: sharp-frame-extractor [-h] [-o DIR] (--count N | --every SECONDS) - [-j N] [--workers N] +Usage: sharp-frame-extractor [-h] [-o DIR] (--count N | --every SECONDS) [-j N] [-w N] + [-m MEMORY_MB] VIDEO [VIDEO ...] -Extract the sharpest frame from regular blocks of a video. -Choose exactly one sampling mode: --count or --every. +Extract sharp frames from a video by scoring frames within blocks. Choose exactly one +sampling mode: --count or --every. -positional arguments: +Positional Arguments: VIDEO One or more input video files. -options: +Options: -h, --help show this help message and exit - -o DIR, --output DIR Base output directory. If omitted, outputs are written - to "//". If set, outputs are - written to "//". - --count N Target number of frames to extract per input video. - --every SECONDS Extract one sharp frame every N seconds. Supports - decimals, for example 0.25. - -j N, --jobs N Max number of videos processed in parallel. Default: - 4. - --workers N Max number of frame analyzer workers. Default: 5. - + -o, --output DIR Base output directory. If omitted, outputs are written to + "//". If set, outputs are written to + "//". (default: None) + --count N Target number of frames to extract per input video. (default: None) + --every SECONDS Extract one sharp frame every N seconds. Supports decimals, for example 0.25. (default: None) + -j, --jobs N Max number of videos processed in parallel (video jobs). (default: 4) + -w, --workers N Total analysis worker processes shared across all video jobs. (default: 8) + -m, --memory-limit MEMORY_MB + Global memory limit for frame buffers in MB (shared across jobs). (default: 52428) + Examples: Extract frames by target count: sharp-frame-extractor input.mp4 --count 300 diff --git a/pyproject.toml b/pyproject.toml index 9ccff7c..f13ab1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sharp-frame-extractor" -version = "2.0.0-rc.1" +version = "2.0.0-rc.2" description = "Extracts sharp frames from a video." 
readme = "README.md" license = { file = "LICENSE" } @@ -33,6 +33,7 @@ dependencies = [ "ffmpegio>=0.11.1", "numpy>=2.4.0", "opencv-python>=4.11.0.86", + "psutil>=7.2.1", "rich>=14.2.0", "rich-argparse>=1.7.2", ] @@ -77,4 +78,4 @@ exclude = [] packages = ["sharp_frame_extractor"] [project.scripts] -sharp-frame-extractor = "sharp_frame_extractor.__main__:main" \ No newline at end of file +sharp-frame-extractor = "sharp_frame_extractor.__main__:main" diff --git a/sharp_frame_extractor/SharpFrameExtractor.py b/sharp_frame_extractor/SharpFrameExtractor.py new file mode 100644 index 0000000..4943cd2 --- /dev/null +++ b/sharp_frame_extractor/SharpFrameExtractor.py @@ -0,0 +1,338 @@ +import math +from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import partial +from itertools import chain +from typing import Self, Sequence + +import ffmpegio +import numpy as np + +from sharp_frame_extractor.analyzer.frame_analyzer_base import FrameAnalyzerResult, FrameAnalyzerTask +from sharp_frame_extractor.analyzer.frame_analyzer_pool import FrameAnalyzerWorkerPool +from sharp_frame_extractor.args_utils import MIN_MEMORY_LIMIT, default_concurrency, default_memory_limit_mb +from sharp_frame_extractor.event import Event +from sharp_frame_extractor.memory.shared_ndarray import SharedNDArrayRef, SharedNDArrayStoreBase +from sharp_frame_extractor.memory.shared_ndarray_pool import PooledSharedNDArrayStore +from sharp_frame_extractor.models import ( + BlockAnalyzedEvent, + BlockEvent, + BlockFrameExtracted, + ExtractionResult, + ExtractionTask, + TaskAnalyzedEvent, + TaskEvent, + TaskFinishedEvent, + TaskPreparedEvent, + TaskStartedEvent, + VideoFrameInfo, +) +from sharp_frame_extractor.output.frame_output_handler_base import FrameOutputHandlerBase +from sharp_frame_extractor.worker.Future import Future + + +class SharpFrameExtractor: + def __init__( + self, + output_handlers: Sequence[FrameOutputHandlerBase], + max_video_jobs: int | None = None, + max_workers: int | None = None, + memory_limit_mb: int | None = None, + ): + default_jobs, default_workers = default_concurrency() + default_memory_limit = default_memory_limit_mb() + + self._output_handlers = output_handlers + + self._max_video_jobs = max_video_jobs or default_jobs + self._max_workers = max_workers or default_workers + self._total_memory_limit_mb = memory_limit_mb or default_memory_limit + self.memory_limit_per_job_mb = max( + MIN_MEMORY_LIMIT, math.ceil(self._total_memory_limit_mb / self._max_video_jobs) + ) + + self._analyzer_pool = FrameAnalyzerWorkerPool(self._max_workers) + + # callbacks + self.on_task_event: Event[TaskEvent] = Event() + self.on_block_event: Event[BlockEvent] = Event() + + # internal defaults + self._preferred_block_size = 32 + self._analysis_pixel_format = "gray" + self._analysis_channels = 1 + self._extraction_pixel_format = "rgb24" + self._extraction_channels = 3 + + def start(self): + self._analyzer_pool.start() + + for handler in self._output_handlers: + handler.open() + + def process(self, tasks: list[ExtractionTask]) -> list[ExtractionResult]: + results: list[ExtractionResult] = [] + + # Sequential execution for debugging or single worker + if self._max_video_jobs <= 1: + for task in tasks: + result = self._process_extraction_task(task) + results.append(result) + return results + + # Parallel threaded execution with ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=self._max_video_jobs) as executor: + futures = {} + for task in tasks: + # Submit tasks to executor and track their futures + 
future = executor.submit(self._process_extraction_task, task) + futures[future] = task + + # Process tasks as workers become available + for future in as_completed(futures): + # Wait for the future to complete + result = future.result() + results.append(result) + + # order results by input id + results.sort(key=lambda r: r.task_id) + + return results + + def stop(self): + self._analyzer_pool.stop() + + for handler in self._output_handlers: + handler.close() + + def _process_extraction_task(self, task: ExtractionTask) -> ExtractionResult: + self.on_task_event(TaskStartedEvent(task)) + + video_path = task.video_path + options = task.options + + # read stream info + video_streams = ffmpegio.probe.video_streams_basic(str(video_path)) + video_info = video_streams[0] + + # extract video information + video_duration_seconds = float(video_info["duration"]) + video_fps = float(video_info["frame_rate"]) + + video_width = int(video_info["width"]) + video_height = int(video_info["height"]) + + if "nb_frames" in video_info: + total_video_frames = int(video_info["nb_frames"]) + else: + total_video_frames = math.ceil(video_duration_seconds * video_fps) + + # calculate frame interval for selecting the number of output frames + if options.frame_interval_seconds is not None: + frame_interval = max(1, int(round(options.frame_interval_seconds * video_fps))) + elif options.total_frame_count is not None: + frame_interval = max(1, int(math.ceil(total_video_frames / options.total_frame_count))) + else: + raise ValueError('Please provide either "--every" or "--count".') + + # total frames to extract + total_frames = int(math.ceil(total_video_frames / frame_interval)) + + # calculate stream block size + possible_block_size = self._calculate_block_size( + video_width, video_height, self._extraction_channels, self.memory_limit_per_job_mb + ) + + # Distribute memory among the worker buffers + max_block_size_per_worker = max(1, possible_block_size // self._max_workers) + stream_block_size = min(self._preferred_block_size, max_block_size_per_worker) + + # setup progress bar for analysis + total_sub_tasks = int(math.ceil(total_video_frames / stream_block_size)) + self.on_task_event(TaskPreparedEvent(task, total_blocks=total_sub_tasks, total_frames=total_frames)) + + # prepare shared memory store + buffer_size = video_width * video_height * self._analysis_channels * stream_block_size + + # limit buffers to max workers to prevent over-allocation + with PooledSharedNDArrayStore(item_size=buffer_size, n_buffers=self._max_workers) as store: + # analyze video first + interval_ids, frame_ids, scores = self._analyze_frames(task, stream_block_size, frame_interval, store) + self.on_task_event(TaskAnalyzedEvent(task, total_blocks=total_sub_tasks, total_frames=total_frames)) + + # extraction run + self._extract_frames(task, stream_block_size, interval_ids, frame_ids, scores) + + self.on_task_event(TaskFinishedEvent(task)) + return ExtractionResult(task.task_id) + + def _analyze_frames( + self, task: ExtractionTask, stream_block_size: int, frame_interval: int, store: SharedNDArrayStoreBase + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + submitted_tasks: list[Future] = [] + analysis_results: list[FrameAnalyzerResult] = [] + + # analysis run + block_index = 0 + with ffmpegio.open( + str(task.video_path), "rv", blocksize=stream_block_size, pix_fmt=self._analysis_pixel_format + ) as fin: + for frames in fin: + # create shared memory + shared_memory_ref = store.put(frames, worker_writeable=False) + + # analyze video block +
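# Only the lightweight SharedNDArrayRef (segment name, shape, dtype) is pickled to the worker process; the decoded frames stay in the shared memory segment filled above, so the block is not copied a second time across the process boundary. +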
worker_task = self._analyzer_pool.submit_task(FrameAnalyzerTask(block_index, shared_memory_ref)) + worker_task.add_done_callback( + partial( + self._on_block_finished, + results=analysis_results, + task=task, + shared_memory_ref=shared_memory_ref, + store=store, + ) + ) + submitted_tasks.append(worker_task) + + block_index += 1 + + # wait for all tasks to be done + for worker_task in submitted_tasks: + worker_task.wait() + worker_task.clear() + + # select best frames per interval + analysis_results.sort(key=lambda e: e.block_index) + raw_frame_scores = list(chain.from_iterable(r.scores for r in analysis_results)) + best_frames_per_interval = self._select_best_frames_per_interval(raw_frame_scores, frame_interval) + + return best_frames_per_interval + + def _on_block_finished( + self, + future: Future[FrameAnalyzerResult], + results: list[FrameAnalyzerResult], + task: ExtractionTask, + shared_memory_ref: SharedNDArrayRef, + store: SharedNDArrayStoreBase, + ): + # append result to results list + result = future.result() + + # todo: do we have to be careful here (regarding thread-safety)? + results.append(result) + + # release memory + store.release(shared_memory_ref) + self.on_block_event(BlockAnalyzedEvent(task, result.block_index, result)) + + def _extract_frames( + self, + task: ExtractionTask, + stream_block_size: int, + interval_ids: np.ndarray, + frame_ids: np.ndarray, + scores: np.ndarray, + ): + # setup output handlers for this task + for handler in self._output_handlers: + handler.prepare_task(task) + + global_start = 0 # first global frame index in current chunk + + with ffmpegio.open( + str(task.video_path), + "rv", + blocksize=stream_block_size, + pix_fmt=self._extraction_pixel_format, + ) as fin: + for block_index, frames in enumerate(fin): + block_len = len(frames) + if block_len == 0: + continue + + block_end = global_start + block_len # exclusive + + i0 = np.searchsorted(frame_ids, global_start, side="left") + i1 = np.searchsorted(frame_ids, block_end, side="left") + + if i0 == i1: + global_start = block_end + continue + + local_idxs = frame_ids[i0:i1] - global_start + + for k, local_idx in zip(range(i0, i1), local_idxs): + frame_id = int(frame_ids[k]) + interval_id = int(interval_ids[k]) + score = float(scores[k]) + + frame = frames[int(local_idx)] + + frame_info = VideoFrameInfo( + interval_index=interval_id, frame_index=frame_id, score=score, frame=frame + ) + for handler in self._output_handlers: + handler.handle_block(task, frame_info) + + self.on_block_event(BlockFrameExtracted(task=task, frame_info=frame_info)) + + global_start = block_end + + if i1 >= frame_ids.size: + break + + @staticmethod + def _calculate_block_size( + width: int, height: int, channels: int, memory_limit_mb: int, safe_factor: float = 0.8 + ) -> int: + # RGB24 = 3 bytes per pixel + frame_size_bytes = width * height * channels + memory_limit_bytes = memory_limit_mb * 1024 * 1024 + + # Allow using up to n% of the limit for the buffer to be safe + safe_memory_bytes = memory_limit_bytes * safe_factor + + count = int(safe_memory_bytes / frame_size_bytes) + return max(1, count) + + @staticmethod + def _select_best_frames_per_interval( + raw_frame_scores: list[float], + frame_interval: int, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + if frame_interval <= 0: + raise ValueError("frame_interval must be > 0") + + scores = np.asarray(raw_frame_scores, dtype=np.float32) + n = int(scores.size) + if n == 0: + return ( + np.empty((0,), dtype=np.int64), # interval_index + np.empty((0,), dtype=np.int64), # 
frame_index + np.empty((0,), dtype=np.float32), # score + ) + + interval_index = np.arange(n, dtype=np.int64) // frame_interval + num_intervals = int(interval_index[-1]) + 1 + + best_frame_index = np.zeros(num_intervals, dtype=np.int64) + best_score = np.full(num_intervals, -np.inf, dtype=np.float32) + + for frame_index in range(n): + ii = interval_index[frame_index] + s = scores[frame_index] + if s > best_score[ii]: + best_score[ii] = s + best_frame_index[ii] = frame_index + + out_interval_index = np.arange(num_intervals, dtype=np.int64) + + order = np.argsort(best_frame_index, kind="stable") + return out_interval_index[order], best_frame_index[order], best_score[order] + + def __enter__(self) -> Self: + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() diff --git a/sharp_frame_extractor/__main__.py b/sharp_frame_extractor/__main__.py index e5174e6..149919c 100644 --- a/sharp_frame_extractor/__main__.py +++ b/sharp_frame_extractor/__main__.py @@ -1,144 +1,48 @@ import argparse -import math -import os import time -from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import dataclass from datetime import timedelta +from functools import partial from pathlib import Path -import cv2 -import ffmpegio -import numpy as np from rich.console import Console -from rich.progress import Progress, TextColumn, BarColumn, TimeRemainingColumn, MofNCompleteColumn, TimeElapsedColumn - -from sharp_frame_extractor.analyzer.frame_analyzer_base import FrameAnalyzerTask, FrameAnalyzerResult -from sharp_frame_extractor.analyzer.frame_analyzer_pool import FrameAnalyzerWorkerPool -from sharp_frame_extractor.args_utils import positive_int, positive_float, default_concurrency -from sharp_frame_extractor.worker.Future import Future - -analyzer_pool: FrameAnalyzerWorkerPool | None = None - - -@dataclass -class ExtractionOptions: - # either one of the two have ot be set - frame_interval_seconds: float | None = None - total_frame_count: int | None = None - - -@dataclass -class ExtractionTask: - video_path: Path - result_path: Path - options: ExtractionOptions - - -def process_extraction_task(task: ExtractionTask, progress: Progress) -> None: - task_id = progress.add_task(description=f"analyzing {task.video_path.name}", total=None) - - video_path = task.video_path - result_path = task.result_path - options = task.options - - # read stream info - video_streams = ffmpegio.probe.video_streams_basic(str(video_path)) - video_info = video_streams[0] - - # extract video information - video_duration_seconds = float(video_info["duration"]) - video_fps = float(video_info["frame_rate"]) - - if "nb_frames" in video_info: - total_video_frames = int(video_info["nb_frames"]) - else: - total_video_frames = math.ceil(video_duration_seconds * video_fps) - - # calculate stream block size - if options.frame_interval_seconds is not None: - stream_block_size = max(1, int(round(options.frame_interval_seconds * video_fps))) - elif options.total_frame_count is not None: - stream_block_size = max(1, int(math.ceil(total_video_frames / options.total_frame_count))) - else: - progress.print('Please provide either "--every" or "--count".', style="bold yellow") - progress.stop_task(task_id) - return - - # ensure output path exists - result_path.mkdir(parents=True, exist_ok=True) - - # setup progress bar - total_sub_tasks = int(math.ceil(total_video_frames / stream_block_size)) - progress.update(task_id, total=total_sub_tasks, description=f"processing {task.video_path.name}") - - 
submitted_tasks: list[Future] = [] - - def on_task_finished(future: Future[FrameAnalyzerResult]): - result = future.result() - output_file_name = task.result_path / f"frame-{result.block_index:05d}.png" - - if output_file_name.exists(): - output_file_name.unlink(missing_ok=True) - - cv2.imwrite(str(output_file_name.absolute()), result.frame) - result.frame = None - progress.update(task_id, advance=1) - - # start reading video file - block_index = 0 - with ffmpegio.open(str(video_path), "rv", blocksize=stream_block_size, pix_fmt="rgb24") as fin: - for frames in fin: - # convert rgb to bgr frames - frames_bgr = np.empty_like(frames) - for i in range(frames.shape[0]): - frames_bgr[i] = cv2.cvtColor(frames[i], cv2.COLOR_RGB2BGR) - - # analyze video block - worker_task = analyzer_pool.submit_task(FrameAnalyzerTask(block_index, frames_bgr)) - worker_task.add_done_callback(on_task_finished) - submitted_tasks.append(worker_task) - - block_index += 1 - - # wait for all tasks to be done - for worker_task in submitted_tasks: - worker_task.result() - - progress.update(task_id, completed=total_sub_tasks) - progress.stop_task(task_id) - - -def cpu_count_fraction(factor: float, min_value: int = 1) -> int: - return max(min_value, int(os.cpu_count() * factor)) +from rich.progress import ( + MofNCompleteColumn, + Progress, + TaskID, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) +from rich_argparse import ArgumentDefaultsRichHelpFormatter + +from sharp_frame_extractor.args_utils import default_concurrency, default_memory_limit_mb, positive_float, positive_int +from sharp_frame_extractor.models import ( + BlockAnalyzedEvent, + BlockEvent, + BlockFrameExtracted, + ExtractionOptions, + TaskAnalyzedEvent, + TaskEvent, + TaskFinishedEvent, + TaskPreparedEvent, + TaskStartedEvent, +) +from sharp_frame_extractor.output.file_output_handler import FileOutputHandler +from sharp_frame_extractor.SharpFrameExtractor import ExtractionTask, SharpFrameExtractor +from sharp_frame_extractor.ui.progress_bar import StatefulBarColumn def parse_args() -> argparse.Namespace: - examples = """ -Examples: - Extract frames by target count: - sharp-frame-extractor input.mp4 --count 300 - - Extract one sharp frame every 0.25 seconds: - sharp-frame-extractor input.mp4 --every 0.25 - - Process multiple videos, outputs next to each input: - sharp-frame-extractor a.mp4 b.mp4 --count 100 - - Write outputs into a single base folder (per input subfolder): - sharp-frame-extractor a.mp4 b.mp4 -o out --every 2 -""" - default_jobs, default_workers = default_concurrency() + default_memory_limit = default_memory_limit_mb() parser = argparse.ArgumentParser( prog="sharp-frame-extractor", description=( - "Extract the sharpest frame from regular blocks of a video.\n" + "Extract sharp frames from a video by scoring frames within blocks.\n" "Choose exactly one sampling mode: --count or --every." ), - epilog=examples, - formatter_class=argparse.RawDescriptionHelpFormatter, + formatter_class=partial(ArgumentDefaultsRichHelpFormatter, width=90), ) parser.add_argument( @@ -179,16 +83,27 @@ def parse_args() -> argparse.Namespace: type=positive_int, default=default_jobs, metavar="N", - help=f"Max number of videos processed in parallel. Default: {default_jobs}.", + help="Max number of videos processed in parallel (video jobs).", ) parser.add_argument( + "-w", "--workers", dest="workers", type=positive_int, default=default_workers, metavar="N", - help=f"Max number of frame analyzer workers. 
Default: {default_workers}.", + help="Total analysis worker processes shared across all video jobs.", + ) + + parser.add_argument( + "-m", + "--memory-limit", + dest="memory_limit", + type=positive_int, + default=default_memory_limit, + metavar="MEMORY_MB", + help="Global memory limit for frame buffers in MB (shared across jobs).", ) return parser.parse_args() @@ -206,6 +121,7 @@ def main(): max_video_threads: int = int(args.jobs) max_workers: int = int(args.workers) + max_memory_limit_mb: int = int(args.memory_limit) if output_base_dir is not None: output_paths: list[Path] = [output_base_dir / p.stem for p in input_paths] @@ -213,9 +129,9 @@ def main(): output_paths = [p.parent / p.stem for p in input_paths] if every_seconds is not None: - default_options = ExtractionOptions(frame_interval_seconds=every_seconds, total_frame_count=None) + default_options = ExtractionOptions.from_interval(every_seconds) else: - default_options = ExtractionOptions(frame_interval_seconds=None, total_frame_count=count) + default_options = ExtractionOptions.from_count(count) # create tasks with console.status("creating tasks..."): @@ -229,46 +145,78 @@ def main(): max_video_threads = min(task_count, max_video_threads) # print processing info - console.print(f"Running {task_count} tasks with {max_video_threads} jobs and {max_workers} workers.") + console.print( + f"Running {task_count} tasks " + f"with {max_video_threads} jobs, " + f"{max_workers} workers " + f"and a memory limit of ~{max_memory_limit_mb / 1024:.1f} GB." + ) - # create pool - global analyzer_pool - analyzer_pool = FrameAnalyzerWorkerPool(max_workers) + # create output handler + output_handlers = [FileOutputHandler()] # run processing start_time = time.time() - analyzer_pool.start() - with Progress( - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TimeElapsedColumn(), - TimeRemainingColumn(), - MofNCompleteColumn(), - ) as progress: - # Create an overall progress bar - overall_task_id = progress.add_task(description="Sharp Frame Extractor", total=task_count) - - # Sequential execution for debugging or single worker - if max_video_threads <= 1: - for task in tasks: - process_extraction_task(task, progress) - progress.advance(overall_task_id) - else: - # Parallel threaded execution with ThreadPoolExecutor - with ThreadPoolExecutor(max_workers=max_video_threads) as executor: - futures = {} - for task in tasks: - # Submit tasks to executor and track their futures - future = executor.submit(process_extraction_task, task, progress) - futures[future] = task - - # Process tasks as workers become available - for future in as_completed(futures): - # Wait for the future to complete - future.result() - progress.advance(overall_task_id) - - analyzer_pool.stop() + with SharpFrameExtractor(output_handlers, max_video_threads, max_workers, max_memory_limit_mb) as sfe: + with Progress( + TextColumn("[progress.description]{task.description}"), + StatefulBarColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + MofNCompleteColumn(), + ) as progress: + # Create an overall progress bar + main_task_id = progress.add_task( + description="[bold]Sharp Frame Extractor[/bold]", + total=task_count, + bar_complete_style="bright_white", + bar_finished_style="bright_white", + bar_pulse_style="bright_white", + ) + + task_to_progress_lut: dict[int, TaskID] = {} + + def _set_state(t: ExtractionTask, label: str, color: str) -> None: + progress_task_id = task_to_progress_lut[t.task_id] + progress.update( + progress_task_id, + 
description=f"[{color}]{label}[/{color}] [bold]{t.video_path.name}[/bold]", + bar_complete_style=color, + bar_finished_style=color, + bar_pulse_style=color, + ) + + # handle progress events + @sfe.on_task_event.register + def _on_task_event(event: TaskEvent): + if isinstance(event, TaskStartedEvent): + task_to_progress_lut[event.task.task_id] = progress.add_task( + description=f"{event.task.task_id}", total=None + ) + _set_state(event.task, "preparing", "gold1") + elif isinstance(event, TaskPreparedEvent): + progress.update( + task_to_progress_lut[event.task.task_id], total=event.total_blocks + event.total_frames + ) + _set_state(event.task, "analyzing", "slate_blue1") + elif isinstance(event, TaskAnalyzedEvent): + progress.update( + task_to_progress_lut[event.task.task_id], total=event.total_blocks + event.total_frames + ) + _set_state(event.task, "extracting", "dodger_blue1") + elif isinstance(event, TaskFinishedEvent): + _set_state(event.task, "done", "spring_green1") + progress.stop_task(task_to_progress_lut[event.task.task_id]) + progress.advance(main_task_id) + + @sfe.on_block_event.register + def _on_block_event(event: BlockEvent): + if isinstance(event, BlockAnalyzedEvent) or isinstance(event, BlockFrameExtracted): + progress.advance(task_to_progress_lut[event.task.task_id]) + + # run process + _ = sfe.process(tasks) + end_time = time.time() console.print(f"It took {str(timedelta(seconds=end_time - start_time))} to process {task_count} tasks.") diff --git a/sharp_frame_extractor/analyzer/frame_analyzer_base.py b/sharp_frame_extractor/analyzer/frame_analyzer_base.py index 79dec3a..3ca6287 100644 --- a/sharp_frame_extractor/analyzer/frame_analyzer_base.py +++ b/sharp_frame_extractor/analyzer/frame_analyzer_base.py @@ -3,19 +3,19 @@ import numpy as np +from sharp_frame_extractor.memory.shared_ndarray import SharedNDArrayRef + @dataclass class FrameAnalyzerTask: block_index: int - frames: np.ndarray + frames_ref: SharedNDArrayRef @dataclass class FrameAnalyzerResult: block_index: int - frame_index: int - frame: np.ndarray - score: float + scores: list[float] class FrameAnalyzerBase(ABC): @@ -24,5 +24,5 @@ def reset_states(self): pass @abstractmethod - def process(self, task: FrameAnalyzerTask) -> FrameAnalyzerResult: + def process(self, task: FrameAnalyzerTask, frames: np.ndarray) -> FrameAnalyzerResult: pass diff --git a/sharp_frame_extractor/analyzer/frame_analyzer_pool.py b/sharp_frame_extractor/analyzer/frame_analyzer_pool.py index ffd4ce6..a2ceac5 100644 --- a/sharp_frame_extractor/analyzer/frame_analyzer_pool.py +++ b/sharp_frame_extractor/analyzer/frame_analyzer_pool.py @@ -1,7 +1,8 @@ import logging -from sharp_frame_extractor.analyzer.frame_analyzer_base import FrameAnalyzerTask, FrameAnalyzerResult, FrameAnalyzerBase +from sharp_frame_extractor.analyzer.frame_analyzer_base import FrameAnalyzerBase, FrameAnalyzerResult, FrameAnalyzerTask from sharp_frame_extractor.analyzer.tenegrad_frame_analyzer import TenengradFrameAnalyzer +from sharp_frame_extractor.memory.shared_ndarray import SharedNDArray from sharp_frame_extractor.worker.BaseWorker import BaseWorker from sharp_frame_extractor.worker.BaseWorkerPool import BaseWorkerPool from sharp_frame_extractor.worker.Future import Future @@ -19,7 +20,12 @@ def setup(self): def handle_task(self, task: FrameAnalyzerTask) -> FrameAnalyzerResult: self.analyzer.reset_states() - return self.analyzer.process(task) + + with SharedNDArray.attach(task.frames_ref) as shared: + frames = shared.ndarray # view into shared memory + result = 
self.analyzer.process(task, frames) + + return result def cleanup(self): self.analyzer = None diff --git a/sharp_frame_extractor/analyzer/tenegrad_frame_analyzer.py b/sharp_frame_extractor/analyzer/tenegrad_frame_analyzer.py index 93b064c..14205ad 100644 --- a/sharp_frame_extractor/analyzer/tenegrad_frame_analyzer.py +++ b/sharp_frame_extractor/analyzer/tenegrad_frame_analyzer.py @@ -8,8 +8,8 @@ from sharp_frame_extractor.analyzer.frame_analyzer_base import ( FrameAnalyzerBase, - FrameAnalyzerTask, FrameAnalyzerResult, + FrameAnalyzerTask, ) @@ -66,8 +66,7 @@ def reset_states(self) -> None: self._cached_weight_key = None self._cached_weight = None - def process(self, task: FrameAnalyzerTask) -> FrameAnalyzerResult: - frames = task.frames + def process(self, task: FrameAnalyzerTask, frames: np.ndarray) -> FrameAnalyzerResult: if frames.ndim not in (3, 4): raise ValueError(f"Expected frames with shape (N,H,W) or (N,H,W,C), got {frames.shape}") @@ -79,13 +78,10 @@ def process(self, task: FrameAnalyzerTask) -> FrameAnalyzerResult: gray = self._to_gray(frames[i]) raw_scores[i] = self._tenengrad(gray, weights) - best_idx = int(np.argmax(raw_scores)) - best_frame = frames[best_idx] - score = float(self._score_01(raw_scores, best_idx)) - - return FrameAnalyzerResult(block_index=task.block_index, frame_index=best_idx, frame=best_frame, score=score) + return FrameAnalyzerResult(block_index=task.block_index, scores=raw_scores.tolist()) - def _to_gray(self, frame: np.ndarray) -> np.ndarray: + @staticmethod + def _to_gray(frame: np.ndarray) -> np.ndarray: if frame.ndim == 2: return frame @@ -147,18 +143,3 @@ def _tenengrad(self, gray: np.ndarray, weights: np.ndarray | None) -> np.float32 return np.float32(g2.mean()) return np.float32((g2 * weights).sum()) - - def _score_01(self, raw_scores: np.ndarray, best_idx: int) -> float: - eps = float(self._cfg.eps) - best = float(raw_scores[best_idx]) - - if self._cfg.normalize is ScoreNormalization.MINMAX: - mn = float(raw_scores.min()) - mx = float(raw_scores.max()) - return float(np.clip((best - mn) / (mx - mn + eps), 0.0, 1.0)) - - med = float(np.median(raw_scores)) - mad = float(np.median(np.abs(raw_scores - med))) - scale = 1.4826 * mad + eps - z = (best - med) / scale - return float(1.0 / (1.0 + np.exp(-z))) diff --git a/sharp_frame_extractor/args_utils.py b/sharp_frame_extractor/args_utils.py index 783fe88..b126b9f 100644 --- a/sharp_frame_extractor/args_utils.py +++ b/sharp_frame_extractor/args_utils.py @@ -1,6 +1,10 @@ import argparse import os +import psutil + +MIN_MEMORY_LIMIT = 4096 + def positive_int(value: str) -> int: try: @@ -36,3 +40,21 @@ def default_concurrency() -> tuple[int, int]: workers = max(1, int(cpu * 0.8)) return jobs, workers + + +def default_memory_limit_mb(safe_factor: float = 0.8) -> int: + memory_info = psutil.virtual_memory() + + total_bytes = 0 + + try: + total_bytes = memory_info.total + except AttributeError: + pass + + # Fallback to 4GB if detection failed or 0 + if total_bytes <= 0: + return MIN_MEMORY_LIMIT + + # Return the safe_factor fraction of total memory, in MB + return int((total_bytes * safe_factor) / (1024 * 1024)) diff --git a/sharp_frame_extractor/event.py b/sharp_frame_extractor/event.py new file mode 100644 index 0000000..83d737e --- /dev/null +++ b/sharp_frame_extractor/event.py @@ -0,0 +1,210 @@ +import threading +from typing import TypeVar, Generic, Callable, List, Optional, Iterator + +T = TypeVar("T") +H = Callable[[T], None] + + +class Event(Generic[T]): + """ + A generic event class that allows you to register and trigger event
handlers, + and also provides a way to wait for the next event to be fired. + + Attributes: + _handlers (List[H]): A list to store event handlers. + """ + + def __init__(self): + """ + Initialize the Event instance with an empty list of handlers + and a threading event to allow waiting for events. + """ + self._handlers: List[H] = [] + self._latest_value: Optional[T] = None + self._event_trigger = threading.Event() + + def append(self, handler: H) -> None: + """ + Append an event handler to the list of handlers. + + Args: + handler (H): The event handler function to add. + """ + self._handlers.append(handler) + + def remove(self, handler: H) -> None: + """ + Remove an event handler from the list of handlers. + + Args: + handler (H): The event handler function to remove. + """ + self._handlers.remove(handler) + + def contains(self, handler: H) -> bool: + """ + Check if a specific event handler is already registered. + + Args: + handler (H): The event handler function to check for. + + Returns: + bool: True if the handler is in the list, False otherwise. + """ + return handler in self._handlers + + def invoke(self, value: T) -> None: + """ + Invoke all registered event handlers with the provided value. + Also set the threading event to allow waiting mechanisms to proceed. + + Args: + value (T): The value to pass to the event handlers. + """ + self._latest_value = value + for handler in self._handlers: + handler(value) + + # Trigger the event for waiting threads + self._event_trigger.set() + + def invoke_latest(self, value: T) -> None: + """ + Invoke the most recently added event handler with the provided value. + + If no event handlers are registered, this method does nothing. + + Args: + value (T): The value to pass to the latest event handler. + """ + if len(self._handlers) == 0: + return + self._handlers[-1](value) + + def clear(self) -> None: + """ + Clear all registered event handlers, removing them from the list. + """ + self._handlers.clear() + + def register(self, handler: H) -> H: + """ + Append an event handler to the list of handlers and return it. + This method should be used as a decorator. + + Args: + handler (H): The event handler function to add. + Returns: + H: Returns the handler given as an argument. + """ + self.append(handler) + return handler + + @property + def handler_size(self) -> int: + """ + Get the number of registered event handlers. + + Returns: + int: The number of event handlers currently registered. + """ + return len(self._handlers) + + def __iadd__(self, other): + """ + Allow the use of '+=' to add an event handler. + + Args: + other (H): The event handler function to add. + + Returns: + Event[T]: The updated Event instance. + """ + self.append(other) + return self + + def __isub__(self, other): + """ + Allow the use of '-=' to remove an event handler. + + Args: + other (H): The event handler function to remove. + + Returns: + Event[T]: The updated Event instance. + """ + self.remove(other) + return self + + def __contains__(self, item) -> bool: + """ + Check if a specific event handler is already registered using 'in' operator. + + Args: + item (H): The event handler function to check for. + + Returns: + bool: True if the handler is in the list, False otherwise. + """ + return self.contains(item) + + def __call__(self, value: T): + """ + Allow the instance to be called as a function, invoking all event handlers. + + Args: + value (T): The value to pass to the event handlers.
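+ + Example (illustrative): + event: Event[str] = Event() + event += print + event("ready") # equivalent to event.invoke("ready")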
+ """ + self.invoke(value) + + def wait(self, timeout: Optional[float] = None) -> Optional[T]: + """ + Wait for the next event to be fired, with an optional timeout. + + Args: + timeout (Optional[float]): The maximum time (in seconds) to wait. + If None, wait indefinitely. + + Returns: + Optional[T]: The value passed when the event was triggered, + or None if the timeout was reached. + """ + event_occurred = self._event_trigger.wait(timeout) + + # If the event occurred, clear the event and return the latest value + if event_occurred: + self._event_trigger.clear() + return self._latest_value + else: + # Return None if the timeout is reached + return None + + def stream(self, timeout: Optional[float] = None) -> Iterator[Optional[T]]: + """ + Continuously yield the value whenever the event is triggered, with an optional timeout. + + Args: + timeout (Optional[float]): The maximum time (in seconds) to wait + between yielding values. If None, wait indefinitely. + + Yields: + Optional[T]: The value passed each time the event is triggered, + or None if the timeout was reached. + """ + while True: + yield self.wait(timeout) + + def __getstate__(self): + """ + Custom method to remove the _event_trigger from the state when pickling. + """ + state = self.__dict__.copy() + state["_event_trigger"] = None # Exclude the event trigger from pickling + return state + + def __setstate__(self, state): + """ + Custom method to restore the _event_trigger after unpickling. + """ + self.__dict__.update(state) + self._event_trigger = threading.Event() # Reinitialize the event diff --git a/sharp_frame_extractor/memory/__init__.py b/sharp_frame_extractor/memory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sharp_frame_extractor/memory/shared_ndarray.py b/sharp_frame_extractor/memory/shared_ndarray.py new file mode 100644 index 0000000..eddfe91 --- /dev/null +++ b/sharp_frame_extractor/memory/shared_ndarray.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from multiprocessing import shared_memory +from typing import Dict, Optional, Self, Tuple + +import numpy as np + + +@dataclass(frozen=True, slots=True) +class SharedNDArrayRef: + name: str + shape: Tuple[int, ...] 
+ dtype: str + order: str = "C" + writeable: bool = False + + +class SharedNDArray: + def __init__(self, shm: shared_memory.SharedMemory, ref: SharedNDArrayRef): + self._shm = shm + self.ref = ref + self.ndarray = np.ndarray( + ref.shape, + dtype=np.dtype(ref.dtype), + buffer=shm.buf, + order=ref.order, + ) + if not ref.writeable: + self.ndarray.setflags(write=False) + + @staticmethod + def _nbytes(shape: Tuple[int, ...], dtype: np.dtype) -> int: + if len(shape) == 0: + raise ValueError("shape must not be empty") + count = int(np.prod(shape)) + if count <= 0: + raise ValueError(f"invalid shape {shape}") + return count * int(dtype.itemsize) + + @classmethod + def create( + cls, + shape: Tuple[int, ...], + dtype: np.dtype | str, + *, + order: str = "C", + writeable: bool = True, + name: Optional[str] = None, + ) -> SharedNDArray: + dtype = np.dtype(dtype) + nbytes = cls._nbytes(shape, dtype) + shm = shared_memory.SharedMemory(create=True, size=nbytes, name=name) + ref = SharedNDArrayRef( + name=shm.name, + shape=tuple(shape), + dtype=dtype.str, + order=order, + writeable=writeable, + ) + return cls(shm, ref) + + @classmethod + def attach(cls, ref: SharedNDArrayRef) -> SharedNDArray: + shm = shared_memory.SharedMemory(name=ref.name, create=False) + return cls(shm, ref) + + def close(self) -> None: + self._shm.close() + + def unlink(self) -> None: + self._shm.unlink() + + def __enter__(self) -> SharedNDArray: + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.close() + + +class SharedNDArrayStoreBase(ABC): + def open(self) -> None: + """Optional setup method.""" + pass + + def close(self) -> None: + """Closes the store and releases all resources.""" + self.release_all() + + def __enter__(self) -> Self: + self.open() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + @abstractmethod + def put(self, arr: np.ndarray, *, order: str = "C", worker_writeable: bool = False) -> SharedNDArrayRef: + pass + + @abstractmethod + def release(self, ref: SharedNDArrayRef) -> None: + pass + + @abstractmethod + def release_all(self) -> None: + pass + + +class SharedNDArrayStore(SharedNDArrayStoreBase): + """ + Owns shared memory blocks in the producer process. + Call release(ref) when you are done to avoid leaked segments. + """ + + def __init__(self) -> None: + self._owned: Dict[str, SharedNDArray] = {} + + def put(self, arr: np.ndarray, *, order: str = "C", worker_writeable: bool = False) -> SharedNDArrayRef: + if order != "C": + raise ValueError("only C order is implemented in this helper") + if not arr.flags["C_CONTIGUOUS"]: + arr = np.ascontiguousarray(arr) + + shared = SharedNDArray.create(arr.shape, arr.dtype, order="C", writeable=True) + shared.ndarray[...] 
= arr + self._owned[shared.ref.name] = shared + + return SharedNDArrayRef( + name=shared.ref.name, + shape=shared.ref.shape, + dtype=shared.ref.dtype, + order=shared.ref.order, + writeable=worker_writeable, + ) + + def release(self, ref: SharedNDArrayRef) -> None: + shared = self._owned.pop(ref.name, None) + if shared is None: + return + + shared.close() + try: + shared.unlink() + except FileNotFoundError: + pass + + def release_all(self) -> None: + names = list(self._owned.keys()) + for name in names: + shared = self._owned.pop(name, None) + if shared is None: + continue + shared.close() + try: + shared.unlink() + except FileNotFoundError: + pass diff --git a/sharp_frame_extractor/memory/shared_ndarray_pool.py b/sharp_frame_extractor/memory/shared_ndarray_pool.py new file mode 100644 index 0000000..1dd63eb --- /dev/null +++ b/sharp_frame_extractor/memory/shared_ndarray_pool.py @@ -0,0 +1,75 @@ +import threading +from itertools import chain +from multiprocessing.shared_memory import SharedMemory + +import numpy as np + +from sharp_frame_extractor.memory.shared_ndarray import SharedNDArrayRef, SharedNDArrayStoreBase + + +class PooledSharedNDArrayStore(SharedNDArrayStoreBase): + """ + A shared memory store that reuses memory segments. + + It is initialized with a fixed item size and a buffer limit. + """ + + def __init__(self, item_size: int, n_buffers: int): + self._item_size = item_size + self._n_buffers = n_buffers + self._pool: list[SharedMemory] = [] + self._active: dict[str, SharedMemory] = {} + self._lock = threading.Lock() + # Semaphore limits the number of active shared memory segments + # This provides backpressure if the consumers (workers) are slower than the producer + self._semaphore = threading.Semaphore(n_buffers) + + def put(self, arr: np.ndarray, *, order: str = "C", worker_writeable: bool = False) -> SharedNDArrayRef: + if order != "C": + raise ValueError("only C order is implemented in this helper") + + # Ensure we don't exceed the buffer size + if arr.nbytes > self._item_size: + raise ValueError(f"Array size {arr.nbytes} exceeds configured pool item size {self._item_size}") + + # Block until a buffer is available (backpressure) + self._semaphore.acquire() + + shm = None + with self._lock: + # Try to find a free buffer in the pool + if self._pool: + shm = self._pool.pop() + + # If no buffer in pool (but semaphore acquired), create a new one + if shm is None: + shm = SharedMemory(create=True, size=self._item_size) + + self._active[shm.name] = shm + + # Copy data into shared memory + shm_array = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf) + shm_array[:] = arr[:] + + return SharedNDArrayRef(shm.name, arr.shape, arr.dtype.str, order=order, writeable=worker_writeable) + + def release(self, ref: SharedNDArrayRef) -> None: + with self._lock: + if ref.name in self._active: + shm = self._active.pop(ref.name) + self._pool.append(shm) + + # Signal that a buffer is free + self._semaphore.release() + + def release_all(self) -> None: + with self._lock: + # Combine active and pool, close and unlink everything + for shm in chain(self._active.values(), self._pool): + try: + shm.close() + shm.unlink() + except Exception: + pass + self._active.clear() + self._pool.clear() diff --git a/sharp_frame_extractor/models.py b/sharp_frame_extractor/models.py new file mode 100644 index 0000000..b7802c0 --- /dev/null +++ b/sharp_frame_extractor/models.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from abc import ABC +from dataclasses import dataclass, field +from pathlib 
import Path +from threading import Lock +from typing import ClassVar, Self + +import numpy as np + +from sharp_frame_extractor.analyzer.frame_analyzer_base import FrameAnalyzerResult + + +@dataclass +class AutoTaskIdMixin: + _task_id: int = field(init=False, repr=False) + + _next_id: ClassVar[int] + _id_lock: ClassVar[Lock] + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + cls._next_id = 1 + cls._id_lock = Lock() + + def __post_init__(self) -> None: + cls = type(self) + with cls._id_lock: + self._task_id = cls._next_id + cls._next_id += 1 + + post_init = getattr(super(), "__post_init__", None) + if post_init is not None: + post_init() + + @property + def task_id(self) -> int: + return self._task_id + + +@dataclass +class ExtractionOptions: + # exactly one of the two has to be set + frame_interval_seconds: float | None = None + total_frame_count: int | None = None + + @classmethod + def from_interval(cls, frame_interval_seconds: float) -> Self: + return cls(frame_interval_seconds=frame_interval_seconds) + + @classmethod + def from_count(cls, total_frame_count: int) -> Self: + return cls(total_frame_count=total_frame_count) + + +@dataclass +class ExtractionTask(AutoTaskIdMixin): + video_path: Path + result_path: Path + options: ExtractionOptions + + +@dataclass +class ExtractionResult: + task_id: int + + +@dataclass +class VideoFrameInfo: + interval_index: int + frame_index: int + score: float + frame: np.ndarray + + +# event models + + +@dataclass +class TaskEvent(ABC): + task: ExtractionTask + + +@dataclass +class TaskStartedEvent(TaskEvent): + pass + + +@dataclass +class TaskPreparedEvent(TaskEvent): + total_blocks: int + total_frames: int + + +@dataclass +class TaskAnalyzedEvent(TaskEvent): + total_blocks: int + total_frames: int + + +@dataclass +class TaskFinishedEvent(TaskEvent): + pass + + +@dataclass +class BlockEvent(ABC): + task: ExtractionTask + + +@dataclass +class BlockAnalyzedEvent(BlockEvent): + block_id: int + result: FrameAnalyzerResult + + +@dataclass +class BlockFrameExtracted(BlockEvent): + frame_info: VideoFrameInfo diff --git a/sharp_frame_extractor/output/__init__.py b/sharp_frame_extractor/output/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sharp_frame_extractor/output/file_output_handler.py b/sharp_frame_extractor/output/file_output_handler.py new file mode 100644 index 0000000..eb4ac21 --- /dev/null +++ b/sharp_frame_extractor/output/file_output_handler.py @@ -0,0 +1,57 @@ +import threading +from concurrent.futures import Future, ThreadPoolExecutor +from pathlib import Path + +import cv2 +import numpy as np + +from sharp_frame_extractor.models import ExtractionTask, VideoFrameInfo +from sharp_frame_extractor.output.frame_output_handler_base import FrameOutputHandlerBase + + +class FileOutputHandler(FrameOutputHandlerBase): + def __init__(self, max_workers: int = 4, max_queue_size: int = 32): + self._max_workers = max_workers + self._writer_pool: ThreadPoolExecutor | None = None + + # Semaphore to prevent unbounded memory usage if writing is slower than extraction + self._queue_semaphore = threading.Semaphore(max_queue_size) + + def open(self): + self._writer_pool = ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="writer") + + def prepare_task(self, task: ExtractionTask): + # make sure the output directory exists + task.result_path.mkdir(parents=True, exist_ok=True) + + def handle_block(self, task: ExtractionTask, frame_info: VideoFrameInfo): + output_file_name =
task.result_path / f"frame-{frame_info.interval_index:05d}.png" + + if output_file_name.exists(): + output_file_name.unlink(missing_ok=True) + + # Create a copy of the frame to detach it from the larger memory block + # This ensures the large buffer from ffmpegio can be GC'd even if writing is pending + frame_copy = frame_info.frame.copy() + + # Block if queue is full (backpressure) + self._queue_semaphore.acquire() + + future = self._writer_pool.submit(self._write_output, output_file_name, frame_copy) + future.add_done_callback(self._on_task_done) + + def _on_task_done(self, future: Future): + self._queue_semaphore.release() + try: + future.result() + except Exception as e: + print(f"Error writing frame: {e}") + + @staticmethod + def _write_output(output_file_name: Path, frame: np.ndarray): + # convert frame to bgr + bgr_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + cv2.imwrite(str(output_file_name.absolute()), bgr_frame) + + def close(self): + self._writer_pool.shutdown(wait=True) diff --git a/sharp_frame_extractor/output/frame_output_handler_base.py b/sharp_frame_extractor/output/frame_output_handler_base.py new file mode 100644 index 0000000..ffebf36 --- /dev/null +++ b/sharp_frame_extractor/output/frame_output_handler_base.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod + +from sharp_frame_extractor.models import ExtractionTask, VideoFrameInfo + + +class FrameOutputHandlerBase(ABC): + @abstractmethod + def open(self): + pass + + @abstractmethod + def prepare_task(self, task: ExtractionTask): + pass + + @abstractmethod + def handle_block(self, task: ExtractionTask, frame_info: VideoFrameInfo): + pass + + @abstractmethod + def close(self): + pass diff --git a/sharp_frame_extractor/ui/__init__.py b/sharp_frame_extractor/ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sharp_frame_extractor/ui/progress_bar.py b/sharp_frame_extractor/ui/progress_bar.py new file mode 100644 index 0000000..59659c5 --- /dev/null +++ b/sharp_frame_extractor/ui/progress_bar.py @@ -0,0 +1,33 @@ +from rich.progress import BarColumn, Task +from rich.progress_bar import ProgressBar +from rich.style import StyleType + + +class StatefulBarColumn(BarColumn): + """ + Read bar styles from task.fields so they can be changed at runtime via Progress.update(...). 
+ + Supported task fields: + - bar_style + - bar_complete_style + - bar_finished_style + - bar_pulse_style + """ + + def render(self, task: Task) -> ProgressBar: + style: StyleType = task.fields.get("bar_style", self.style) + complete_style: StyleType = task.fields.get("bar_complete_style", self.complete_style) + finished_style: StyleType = task.fields.get("bar_finished_style", self.finished_style) + pulse_style: StyleType = task.fields.get("bar_pulse_style", self.pulse_style) + + return ProgressBar( + total=max(0, int(task.total)) if task.total is not None else None, + completed=max(0, int(task.completed)), + width=None if self.bar_width is None else max(1, self.bar_width), + pulse=not task.started, + animation_time=task.get_time(), + style=style, + complete_style=complete_style, + finished_style=finished_style, + pulse_style=pulse_style, + ) diff --git a/sharp_frame_extractor/worker/BaseWorker.py b/sharp_frame_extractor/worker/BaseWorker.py index ecd2de2..df47be7 100644 --- a/sharp_frame_extractor/worker/BaseWorker.py +++ b/sharp_frame_extractor/worker/BaseWorker.py @@ -1,11 +1,11 @@ import logging import threading from abc import ABC, abstractmethod -from multiprocessing import Process, Queue, Event, current_process -from typing import Generic, Dict +from multiprocessing import Event, Process, Queue, current_process +from typing import Dict, Generic from .Future import Future -from .types import TTask, TResult +from .types import TResult, TTask logger = logging.getLogger(__name__) @@ -76,7 +76,11 @@ def _result_listener(self): and resolves the corresponding Future. """ while True: - task_id, result = self.results.get() + item = self.results.get() + if item is None: # Sentinel value signals shutdown + break + + task_id, result = item future = self._futures.pop(task_id, None) if future: if isinstance(result, Exception): diff --git a/sharp_frame_extractor/worker/BaseWorkerPool.py b/sharp_frame_extractor/worker/BaseWorkerPool.py index c4ee88a..4d87c14 100644 --- a/sharp_frame_extractor/worker/BaseWorkerPool.py +++ b/sharp_frame_extractor/worker/BaseWorkerPool.py @@ -1,7 +1,7 @@ import logging from abc import ABC from multiprocessing import Queue -from typing import TypeVar, Generic, Callable, List +from typing import Callable, Generic, List, TypeVar from .BaseWorker import BaseWorker @@ -56,4 +56,7 @@ def stop(self): logger.debug("Pool: Stopping all workers.") for worker in self.workers: worker.stop() + + for worker in self.workers: + worker.join() logger.debug("Pool: All workers stopped.") diff --git a/sharp_frame_extractor/worker/Future.py b/sharp_frame_extractor/worker/Future.py index 37cb6df..afbef99 100644 --- a/sharp_frame_extractor/worker/Future.py +++ b/sharp_frame_extractor/worker/Future.py @@ -2,7 +2,7 @@ import logging import threading -from typing import Optional, Generic, Callable, List +from typing import Callable, Generic, List, Optional from .types import TResult @@ -16,6 +16,7 @@ class Future(Generic[TResult]): def __init__(self): self._done = threading.Event() + self._callbacks_done = threading.Event() self._result: Optional[TResult] = None self._exception: Optional[Exception] = None self._callbacks: List[Callable[[Future[TResult]], None]] = [] @@ -27,18 +28,25 @@ def _invoke_callbacks(self): callback(self) except Exception as e: logger.error(f"Error in Future callback: {e}") + self._callbacks_done.set() def set_result(self, result: TResult): with self._lock: self._result = result self._done.set() - self._invoke_callbacks() + if self._callbacks: + self._invoke_callbacks() 
+ else: + self._callbacks_done.set() def set_exception(self, exception: Exception): with self._lock: self._exception = exception self._done.set() - self._invoke_callbacks() + if self._callbacks: + self._invoke_callbacks() + else: + self._callbacks_done.set() def add_done_callback(self, fn: Callable[[Future[TResult]], None]): """ @@ -52,6 +60,10 @@ def add_done_callback(self, fn: Callable[[Future[TResult]], None]): self._callbacks.append(fn) def result(self, timeout: Optional[float] = None) -> TResult: + """ + Waits for the result to be available and returns it. + Raises the exception if one was set. + """ if self._done.wait(timeout): if self._exception: raise self._exception @@ -59,5 +71,22 @@ def result(self, timeout: Optional[float] = None) -> TResult: else: raise TimeoutError("Future result not available within timeout.") + def wait(self, timeout: Optional[float] = None) -> bool: + """ + Waits for the future to complete AND all callbacks to finish. + Returns True if completed within timeout, False otherwise. + """ + if not self._done.wait(timeout): + return False + return self._callbacks_done.wait(timeout) + + def clear(self): + """ + Clear the result and callbacks to allow garbage collection. + """ + self._result = None + self._exception = None + self._callbacks.clear() + def done(self) -> bool: return self._done.is_set() diff --git a/uv.lock b/uv.lock index 07d8a2e..a460906 100644 --- a/uv.lock +++ b/uv.lock @@ -138,6 +138,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "psutil" +version = "7.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/73/cb/09e5184fb5fc0358d110fc3ca7f6b1d033800734d34cac10f4136cfac10e/psutil-7.2.1.tar.gz", hash = "sha256:f7583aec590485b43ca601dd9cea0dcd65bd7bb21d30ef4ddbf4ea6b5ed1bdd3", size = 490253, upload-time = "2025-12-29T08:26:00.169Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/8e/f0c242053a368c2aa89584ecd1b054a18683f13d6e5a318fc9ec36582c94/psutil-7.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:ba9f33bb525b14c3ea563b2fd521a84d2fa214ec59e3e6a2858f78d0844dd60d", size = 129624, upload-time = "2025-12-29T08:26:04.255Z" }, + { url = "https://files.pythonhosted.org/packages/26/97/a58a4968f8990617decee234258a2b4fc7cd9e35668387646c1963e69f26/psutil-7.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:81442dac7abfc2f4f4385ea9e12ddf5a796721c0f6133260687fec5c3780fa49", size = 130132, upload-time = "2025-12-29T08:26:06.228Z" }, + { url = "https://files.pythonhosted.org/packages/db/6d/ed44901e830739af5f72a85fa7ec5ff1edea7f81bfbf4875e409007149bd/psutil-7.2.1-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ea46c0d060491051d39f0d2cff4f98d5c72b288289f57a21556cc7d504db37fc", size = 180612, upload-time = "2025-12-29T08:26:08.276Z" }, + { url = "https://files.pythonhosted.org/packages/c7/65/b628f8459bca4efbfae50d4bf3feaab803de9a160b9d5f3bd9295a33f0c2/psutil-7.2.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:35630d5af80d5d0d49cfc4d64c1c13838baf6717a13effb35869a5919b854cdf", size = 183201, upload-time = "2025-12-29T08:26:10.622Z" }, + { url = 
"https://files.pythonhosted.org/packages/fb/23/851cadc9764edcc18f0effe7d0bf69f727d4cf2442deb4a9f78d4e4f30f2/psutil-7.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:923f8653416604e356073e6e0bccbe7c09990acef442def2f5640dd0faa9689f", size = 139081, upload-time = "2025-12-29T08:26:12.483Z" }, + { url = "https://files.pythonhosted.org/packages/59/82/d63e8494ec5758029f31c6cb06d7d161175d8281e91d011a4a441c8a43b5/psutil-7.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:cfbe6b40ca48019a51827f20d830887b3107a74a79b01ceb8cc8de4ccb17b672", size = 134767, upload-time = "2025-12-29T08:26:14.528Z" }, + { url = "https://files.pythonhosted.org/packages/05/c2/5fb764bd61e40e1fe756a44bd4c21827228394c17414ade348e28f83cd79/psutil-7.2.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:494c513ccc53225ae23eec7fe6e1482f1b8a44674241b54561f755a898650679", size = 129716, upload-time = "2025-12-29T08:26:16.017Z" }, + { url = "https://files.pythonhosted.org/packages/c9/d2/935039c20e06f615d9ca6ca0ab756cf8408a19d298ffaa08666bc18dc805/psutil-7.2.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:3fce5f92c22b00cdefd1645aa58ab4877a01679e901555067b1bd77039aa589f", size = 130133, upload-time = "2025-12-29T08:26:18.009Z" }, + { url = "https://files.pythonhosted.org/packages/77/69/19f1eb0e01d24c2b3eacbc2f78d3b5add8a89bf0bb69465bc8d563cc33de/psutil-7.2.1-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:93f3f7b0bb07711b49626e7940d6fe52aa9940ad86e8f7e74842e73189712129", size = 181518, upload-time = "2025-12-29T08:26:20.241Z" }, + { url = "https://files.pythonhosted.org/packages/e1/6d/7e18b1b4fa13ad370787626c95887b027656ad4829c156bb6569d02f3262/psutil-7.2.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d34d2ca888208eea2b5c68186841336a7f5e0b990edec929be909353a202768a", size = 184348, upload-time = "2025-12-29T08:26:22.215Z" }, + { url = "https://files.pythonhosted.org/packages/98/60/1672114392dd879586d60dd97896325df47d9a130ac7401318005aab28ec/psutil-7.2.1-cp314-cp314t-win_amd64.whl", hash = "sha256:2ceae842a78d1603753561132d5ad1b2f8a7979cb0c283f5b52fb4e6e14b1a79", size = 140400, upload-time = "2025-12-29T08:26:23.993Z" }, + { url = "https://files.pythonhosted.org/packages/fb/7b/d0e9d4513c46e46897b46bcfc410d51fc65735837ea57a25170f298326e6/psutil-7.2.1-cp314-cp314t-win_arm64.whl", hash = "sha256:08a2f175e48a898c8eb8eace45ce01777f4785bc744c90aa2cc7f2fa5462a266", size = 135430, upload-time = "2025-12-29T08:26:25.999Z" }, + { url = "https://files.pythonhosted.org/packages/c5/cf/5180eb8c8bdf6a503c6919f1da28328bd1e6b3b1b5b9d5b01ae64f019616/psutil-7.2.1-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:b2e953fcfaedcfbc952b44744f22d16575d3aa78eb4f51ae74165b4e96e55f42", size = 128137, upload-time = "2025-12-29T08:26:27.759Z" }, + { url = "https://files.pythonhosted.org/packages/c5/2c/78e4a789306a92ade5000da4f5de3255202c534acdadc3aac7b5458fadef/psutil-7.2.1-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:05cc68dbb8c174828624062e73078e7e35406f4ca2d0866c272c2410d8ef06d1", size = 128947, upload-time = "2025-12-29T08:26:29.548Z" }, + { url = "https://files.pythonhosted.org/packages/29/f8/40e01c350ad9a2b3cb4e6adbcc8a83b17ee50dd5792102b6142385937db5/psutil-7.2.1-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5e38404ca2bb30ed7267a46c02f06ff842e92da3bb8c5bfdadbd35a5722314d8", size = 154694, upload-time = "2025-12-29T08:26:32.147Z" }, + { url = 
"https://files.pythonhosted.org/packages/06/e4/b751cdf839c011a9714a783f120e6a86b7494eb70044d7d81a25a5cd295f/psutil-7.2.1-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab2b98c9fc19f13f59628d94df5cc4cc4844bc572467d113a8b517d634e362c6", size = 156136, upload-time = "2025-12-29T08:26:34.079Z" }, + { url = "https://files.pythonhosted.org/packages/44/ad/bbf6595a8134ee1e94a4487af3f132cef7fce43aef4a93b49912a48c3af7/psutil-7.2.1-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:f78baafb38436d5a128f837fab2d92c276dfb48af01a240b861ae02b2413ada8", size = 148108, upload-time = "2025-12-29T08:26:36.225Z" }, + { url = "https://files.pythonhosted.org/packages/1c/15/dd6fd869753ce82ff64dcbc18356093471a5a5adf4f77ed1f805d473d859/psutil-7.2.1-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:99a4cd17a5fdd1f3d014396502daa70b5ec21bf4ffe38393e152f8e449757d67", size = 147402, upload-time = "2025-12-29T08:26:39.21Z" }, + { url = "https://files.pythonhosted.org/packages/34/68/d9317542e3f2b180c4306e3f45d3c922d7e86d8ce39f941bb9e2e9d8599e/psutil-7.2.1-cp37-abi3-win_amd64.whl", hash = "sha256:b1b0671619343aa71c20ff9767eced0483e4fc9e1f489d50923738caf6a03c17", size = 136938, upload-time = "2025-12-29T08:26:41.036Z" }, + { url = "https://files.pythonhosted.org/packages/3e/73/2ce007f4198c80fcf2cb24c169884f833fe93fbc03d55d302627b094ee91/psutil-7.2.1-cp37-abi3-win_arm64.whl", hash = "sha256:0d67c1822c355aa6f7314d92018fb4268a76668a536f133599b91edd48759442", size = 133836, upload-time = "2025-12-29T08:26:43.086Z" }, +] + [[package]] name = "pygments" version = "2.19.2" @@ -200,12 +228,13 @@ wheels = [ [[package]] name = "sharp-frame-extractor" -version = "2.0.0rc1" +version = "2.0.0rc2" source = { editable = "." } dependencies = [ { name = "ffmpegio" }, { name = "numpy" }, { name = "opencv-python" }, + { name = "psutil" }, { name = "rich" }, { name = "rich-argparse" }, ] @@ -223,6 +252,7 @@ requires-dist = [ { name = "ffmpegio", specifier = ">=0.11.1" }, { name = "numpy", specifier = ">=2.4.0" }, { name = "opencv-python", specifier = ">=4.11.0.86" }, + { name = "psutil", specifier = ">=7.2.1" }, { name = "rich", specifier = ">=14.2.0" }, { name = "rich-argparse", specifier = ">=1.7.2" }, ]