diff --git a/python/unblob/cli.py b/python/unblob/cli.py index 3c0d80a28a..42f150f171 100755 --- a/python/unblob/cli.py +++ b/python/unblob/cli.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 import atexit +import shutil import sys +import tempfile from collections.abc import Iterable from importlib.metadata import version from pathlib import Path @@ -363,6 +365,9 @@ def cli( extra_magics_to_skip = () if clear_skip_magics else DEFAULT_SKIP_MAGIC skip_magic = tuple(sorted(set(skip_magic).union(extra_magics_to_skip))) + # Create dedicated unblob temp directory + unblob_tmp_dir = Path(tempfile.mkdtemp(prefix="unblob-", dir=tempfile.gettempdir())) + config = ExtractionConfig( extract_root=extract_root, force_extract=force, @@ -382,6 +387,7 @@ def cli( progress_reporter=NullProgressReporter if verbose else RichConsoleProgressReporter, + tmp_dir=unblob_tmp_dir, ) logger.info("Creating extraction directory", extract_root=extract_root) @@ -389,6 +395,17 @@ def cli( logger.info("Start processing file", file=file) sandbox = Sandbox(config, log_path, report_file) process_results = sandbox.run(process_file, config, file, report_file) + + # Clean up the temp directory we created + try: + shutil.rmtree(unblob_tmp_dir) + except Exception as e: + logger.warning( + "Failed to clean up tmp_dir", + tmp_dir=unblob_tmp_dir, + exc_info=e, + ) + if verbose == 0: if skip_extraction: print_scan_report(process_results) diff --git a/python/unblob/processing.py b/python/unblob/processing.py index 2509de7c83..e9d6928439 100644 --- a/python/unblob/processing.py +++ b/python/unblob/processing.py @@ -1,6 +1,9 @@ import multiprocessing +import os import shutil +import tempfile from collections.abc import Iterable, Sequence +from contextlib import contextmanager from operator import attrgetter from pathlib import Path from typing import Optional, Union @@ -100,6 +103,7 @@ class ExtractionConfig: dir_handlers: DirectoryHandlers = BUILTIN_DIR_HANDLERS verbose: int = 1 progress_reporter: type[ProgressReporter] = NullProgressReporter + tmp_dir: Path = attrs.field(factory=lambda: Path(tempfile.gettempdir())) def _get_output_path(self, path: Path) -> Path: """Return path under extract root.""" @@ -227,6 +231,46 @@ def write_json_report(report_file: Path, process_result: ProcessResult): logger.info("JSON report written", path=report_file) +@contextmanager +def task_tmp_dir(parent_tmp_dir): + """Context manager that creates a task-specific temp subdirectory. + + Creates a subdirectory under parent_tmp_dir, sets all temp env vars to it, + yields the path, then cleans up the subdirectory on exit. + + The parent_tmp_dir itself is NOT cleaned up - caller is responsible for that. + """ + tmp_vars = ("TMP", "TMPDIR", "TEMP", "TEMPDIR") + saved = {} + + # Create task-specific subdirectory + task_temp = Path(tempfile.mkdtemp(dir=parent_tmp_dir, prefix="unblob-task-")) + + try: + # Override env vars to point to task subdirectory + for var in tmp_vars: + saved[var] = os.environ.get(var) + os.environ[var] = str(task_temp) + yield task_temp + finally: + # Restore original env vars + for var, original in saved.items(): + if original is None: + os.environ.pop(var, None) + else: + os.environ[var] = original + + # Clean up the task subdirectory (NOT the parent) + try: + shutil.rmtree(task_temp) + except Exception as e: + logger.warning( + "Failed to clean up task temp dir", + task_temp=task_temp, + exc_info=e, + ) + + class Processor: def __init__(self, config: ExtractionConfig): self._config = config @@ -244,7 +288,8 @@ def __init__(self, config: ExtractionConfig): def process_task(self, task: Task) -> TaskResult: result = TaskResult(task=task) try: - self._process_task(result, task) + with task_tmp_dir(self._config.tmp_dir): + self._process_task(result, task) except Exception as exc: self._process_error(result, exc) return result diff --git a/python/unblob/sandbox.py b/python/unblob/sandbox.py index 61b02b099d..c46c9aa337 100644 --- a/python/unblob/sandbox.py +++ b/python/unblob/sandbox.py @@ -55,6 +55,10 @@ def __init__( AccessFS.remove_file(config.extract_root), AccessFS.make_dir(config.extract_root.parent), AccessFS.read_write(log_path), + # Allow access to the managed temp directory for handlers + AccessFS.read_write(config.tmp_dir), + AccessFS.remove_dir(config.tmp_dir), + AccessFS.remove_file(config.tmp_dir), *extra_passthrough, ] diff --git a/tests/test_cli.py b/tests/test_cli.py index 892499c939..6157e4ae5d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -273,6 +273,7 @@ def test_archive_success( handlers=BUILTIN_HANDLERS, verbose=expected_verbosity, progress_reporter=expected_progress_reporter, + tmp_dir=mock.ANY, ) process_file_mock.assert_called_once_with(config, in_path, None) logger_config_mock.assert_called_once_with(expected_verbosity, tmp_path, log_path) diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py index 7d8511d059..307e210d03 100644 --- a/tests/test_sandbox.py +++ b/tests/test_sandbox.py @@ -21,6 +21,9 @@ def extraction_config(extraction_config, tmp_path): extraction_config.extract_root = tmp_path / "extract" / "root" # parent has to exist extraction_config.extract_root.parent.mkdir() + # Set tmp_dir to a specific directory under tmp_path to avoid conflicts + extraction_config.tmp_dir = tmp_path / "unblob_tmp" + extraction_config.tmp_dir.mkdir() return extraction_config @@ -34,6 +37,8 @@ def test_necessary_resources_can_be_created_in_sandbox( ): directory_in_extract_root = extraction_config.extract_root / "path" / "to" / "dir" file_in_extract_root = directory_in_extract_root / "file" + file_in_tmp_dir = extraction_config.tmp_dir / "tmp_file" + directory_in_tmp_dir = extraction_config.tmp_dir / "tmp_dir" sandbox.run(extraction_config.extract_root.mkdir, parents=True) sandbox.run(directory_in_extract_root.mkdir, parents=True) @@ -45,6 +50,12 @@ def test_necessary_resources_can_be_created_in_sandbox( log_path.touch() sandbox.run(log_path.write_text, "log line") + sandbox.run(directory_in_tmp_dir.mkdir, parents=True) + sandbox.run(file_in_tmp_dir.touch) + sandbox.run(file_in_tmp_dir.write_text, "tmp file content") + sandbox.run(file_in_tmp_dir.unlink) + sandbox.run(directory_in_tmp_dir.rmdir) + def test_access_outside_sandbox_is_not_possible(sandbox: Sandbox, tmp_path: Path): unrelated_dir = tmp_path / "unrelated" / "path"