diff --git a/gcsfs/extended_gcsfs.py b/gcsfs/extended_gcsfs.py index 164c74fd..59dd6b0d 100644 --- a/gcsfs/extended_gcsfs.py +++ b/gcsfs/extended_gcsfs.py @@ -1,7 +1,10 @@ import asyncio +import contextlib import logging import os import uuid +import weakref +from concurrent.futures import ThreadPoolExecutor from enum import Enum from glob import has_magic @@ -16,11 +19,20 @@ AsyncAppendableObjectWriter, ) from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient +from google.cloud.storage.asyncio.async_multi_range_downloader import ( + AsyncMultiRangeDownloader, +) from gcsfs import __version__ as version from gcsfs import zb_hns_utils from gcsfs.core import GCSFile, GCSFileSystem from gcsfs.retry import DEFAULT_RETRY_CONFIG, get_storage_control_retry_config +from gcsfs.zb_hns_utils import ( + DirectMemmoveBuffer, + MRDPool, + PyBytes_AsString, + PyBytes_FromStringAndSize, +) from gcsfs.zonal_file import ZonalFile logger = logging.getLogger("gcsfs") @@ -44,6 +56,31 @@ class BucketType(Enum): } +@contextlib.asynccontextmanager +async def _get_mrd_from_pool_or_mrd(mrd_or_pool): + """ + Helper function to yield an AsyncMultiRangeDownloader + whether a single instance or an MRDPool is provided. + """ + if isinstance(mrd_or_pool, MRDPool): + async with mrd_or_pool.get_mrd() as m: + yield m + elif isinstance(mrd_or_pool, AsyncMultiRangeDownloader): + yield mrd_or_pool + else: + raise TypeError( + f"Expected MRDPool or AsyncMultiRangeDownloader, got {type(mrd_or_pool)}" + ) + + +async def _get_mrd_size(mrd_or_pool): + """Helper to extract the persisted_size from either a pool or a single MRD.""" + if mrd_or_pool is None: + return None + async with _get_mrd_from_pool_or_mrd(mrd_or_pool) as m: + return m.persisted_size + + class ExtendedGcsFileSystem(GCSFileSystem): """ This class will be used when GCSFS_EXPERIMENTAL_ZB_HNS_SUPPORT env variable is set to true. @@ -89,6 +126,10 @@ def __init__(self, *args, finalize_on_close=False, **kwargs): if self.credentials.token == "anon": self.credential = AnonymousCredentials() self._storage_layout_cache = {} + self._memmove_executor = ThreadPoolExecutor( + max_workers=kwargs.get("memmove_max_workers", 8) + ) + weakref.finalize(self, self._memmove_executor.shutdown) @property def _user_project(self): @@ -285,73 +326,163 @@ async def _is_zonal_bucket(self, bucket): return bucket_type == BucketType.ZONAL_HIERARCHICAL async def _fetch_range_split( - self, path, start=None, chunk_lengths=None, mrd=None, size=None, **kwargs + self, + path, + start, + chunk_lengths, + concurrency, + mrd=None, + size=None, + **kwargs, ): """ - Reading multiple reads in one large stream. + Reading multiple adjacent ranges concurrently. - Optimized for Zonal Buckets: - Leverages AsyncMultiRangeDownloader.download_ranges() to fetch all requested - 'chunk_lengths' (chunks) concurrently in a single batch request, significantly - improving performance for ReadAheadV2. + Delegates concurrent fetching of individual chunks directly to `_cat_file`. """ + file_size = size or await _get_mrd_size(mrd) + if file_size is None: + logger.warning( + f"AsyncMultiRangeDownloader (MRD) for {path} has no 'persisted_size'. " + "Falling back to _info() to get the file size." + ) + file_size = (await self._info(path))["size"] + start_offset = start if start is not None else 0 + if start_offset >= file_size or start_offset + sum(chunk_lengths) > file_size: + raise RuntimeError("Request not satisfiable.") + + pool_created_here = False bucket, object_name, generation = self.split_path(path) - mrd_created = False - try: - if mrd is None: - # Check before creating MRD - if not await self._is_zonal_bucket(bucket): - raise RuntimeError( - "Internal error, this method is only supported for zonal buckets!" - ) - await self._get_grpc_client() - mrd = await zb_hns_utils.init_mrd( - self.grpc_client, bucket, object_name, generation - ) - mrd_created = True + if mrd is None: + # If no mrd is provided, we create one with pool size equal to passed concurrency. + pool_size = min(len(chunk_lengths), concurrency) + mrd = zb_hns_utils.MRDPool( + self, bucket, object_name, generation, pool_size=pool_size + ) + await mrd.initialize() + pool_created_here = True - file_size = size or mrd.persisted_size - if file_size is None: - logger.warning( - f"AsyncMultiRangeDownloader (MRD) for {path} has no 'persisted_size'. " - "Falling back to _info() to get the file size." + tasks = [] + try: + current_offset = start_offset + + cat_kwargs = kwargs.copy() + + for length in chunk_lengths: + end_offset = current_offset + length + tasks.append( + asyncio.create_task( + self._cat_file( + path, + start=current_offset, + end=end_offset, + mrd=mrd, + # Distribute the concurrency budget proportionally. + # Since these outer tasks are already concurrent, this is typically 1. + # However, if a large chunk dominates the total size, it receives + # higher concurrency to prevent it from becoming a bottleneck. + concurrency=max( + 1, length * concurrency // sum(chunk_lengths) + ), + **cat_kwargs, + ) + ) ) - file_size = (await self._info(path))["size"] - - if chunk_lengths: - start_offset = start if start is not None else 0 - current_offset = start_offset + current_offset = end_offset - if start_offset >= file_size or ( - chunk_lengths is not None - and start_offset + sum(chunk_lengths) > file_size - ): - raise RuntimeError("Request not satisfiable.") - - read_ranges = [] # To pass to MRD + results = await asyncio.gather(*tasks, return_exceptions=True) - for length in chunk_lengths: - read_ranges.append((current_offset, length)) - current_offset += length + # Bubble up any exceptions encountered during concurrent fetching + for res in results: + if isinstance(res, Exception): + raise res + + return results + except BaseException: + for t in tasks: + if not t.done(): + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + finally: + if pool_created_here: + await mrd.close() - return await zb_hns_utils.download_ranges(read_ranges, mrd) - else: - end = kwargs.get("end") - offset, length = await self._process_limits_to_offset_and_length( - path, start, end, file_size + async def _concurrent_mrd_fetch(self, offset, length, concurrency, mrd_or_pool): + """Helper to handle concurrent chunk downloads into a DirectMemmoveBuffer.""" + concurrency = ( + concurrency if length >= self.MIN_CHUNK_SIZE_FOR_CONCURRENCY else 1 + ) + result_bytes = PyBytes_FromStringAndSize(None, length) + buffer_ptr = PyBytes_AsString(result_bytes) + + part_size = length // concurrency + tasks = [] + buffers = [] + loop = asyncio.get_running_loop() + + # Track if the core download process failed + has_error = False + + async def _download(o, s, b, mrd_or_pool): + async with _get_mrd_from_pool_or_mrd(mrd_or_pool) as m_client: + await m_client.download_ranges([(o, s, b)]) + + for i in range(concurrency): + part_offset = offset + (i * part_size) + actual_size = part_size if i < concurrency - 1 else length - (i * part_size) + + part_address = buffer_ptr + (part_offset - offset) + buf = DirectMemmoveBuffer( + part_address, + part_address + actual_size, + self._memmove_executor, + ) + buffers.append(buf) + tasks.append( + asyncio.create_task( + _download(part_offset, actual_size, buf, mrd_or_pool) ) + ) - data = await zb_hns_utils.download_range( - offset=offset, length=length, mrd=mrd - ) - return [data] + try: + results = await asyncio.gather(*tasks, return_exceptions=True) + for res in results: + if isinstance(res, Exception): + has_error = True + raise res + except BaseException: + has_error = True + for t in tasks: + if not t.done(): + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise finally: - if mrd_created: - await zb_hns_utils.close_mrd(mrd) - - async def _cat_file(self, path, start=None, end=None, mrd=None, **kwargs): + for buf in buffers: + try: + await loop.run_in_executor(None, buf.close) + except BufferError: + # If we are already handling a network/download exception, + # ignore the BufferError (which is just a symptom of the drop). + # If there's no download error, this means the buffer logic + # itself failed, so we must surface the error. + if not has_error: + raise + + return result_bytes + + async def _cat_file( + self, + path, + start=None, + end=None, + mrd=None, + concurrency=zb_hns_utils.DEFAULT_CONCURRENCY, + **kwargs, + ): """Fetch a file's contents as bytes, with an optimized path for Zonal buckets. This method overrides the parent `_cat_file` to read objects in Zonal buckets using gRPC. @@ -360,47 +491,60 @@ async def _cat_file(self, path, start=None, end=None, mrd=None, **kwargs): path (str): The full GCS path to the file (e.g., "bucket/object"). start (int, optional): The starting byte position to read from. end (int, optional): The ending byte position to read to. - mrd (AsyncMultiRangeDownloader, optional): An existing multi-range - downloader instance. If not provided, a new one will be created for Zonal buckets. + mrd (AsyncMultiRangeDownloader, MRDPool, optional): An existing multi-range + downloader instance or a pool of MRD. If not provided, a new one will be created for Zonal buckets. + concurrency (int, optional): The max number of concurrent request to fetch the data. Returns: bytes: The content of the file or file range. """ - try: - mrd_created = False + pool_created_here = False - # A new MRD is required when read is done directly by the - # GCSFilesystem class without creating a GCSFile object first. - if mrd is None: - bucket, object_name, generation = self.split_path(path) + # A new MRDPool is required when read is done directly by the + # GCSFilesystem class without creating a GCSFile object first. + if mrd is None: + bucket, object_name, generation = self.split_path(path) + if not await self._is_zonal_bucket(bucket): # Fall back to default implementation if not a zonal bucket - if not await self._is_zonal_bucket(bucket): - return await super()._cat_file(path, start=start, end=end, **kwargs) - - await self._get_grpc_client() - mrd = await zb_hns_utils.init_mrd( - self.grpc_client, bucket, object_name, generation + return await super()._cat_file( + path, start=start, end=end, concurrency=concurrency, **kwargs ) - mrd_created = True - file_size = mrd.persisted_size + # Instantiate an MRDPool locally for this call + mrd = zb_hns_utils.MRDPool( + self, bucket, object_name, generation, pool_size=concurrency + ) + await mrd.initialize() + pool_created_here = True + + try: + file_size = await _get_mrd_size(mrd) if file_size is None: logger.warning( f"AsyncMultiRangeDownloader (MRD) for {path} has no 'persisted_size'. " "Falling back to _info() to get the file size. " "This may result in incorrect behavior for unfinalized objects." ) + file_size = (await self._info(path))["size"] + offset, length = await self._process_limits_to_offset_and_length( path, start, end, file_size ) - return await zb_hns_utils.download_range( - offset=offset, length=length, mrd=mrd + if length == 0: + return b"" + + return await self._concurrent_mrd_fetch( + offset, + length, + concurrency if length >= self.MIN_CHUNK_SIZE_FOR_CONCURRENCY else 1, + mrd, ) + finally: - # Explicit cleanup if we created the MRD - if mrd_created: - await zb_hns_utils.close_mrd(mrd) + # If we created a temporary pool specifically for this _cat_file call, clean it up + if pool_created_here: + await mrd.close() async def _is_bucket_hns_enabled(self, bucket): """Checks if a bucket has Hierarchical Namespace enabled.""" diff --git a/gcsfs/tests/test_extended_gcsfs.py b/gcsfs/tests/test_extended_gcsfs.py index a86074ff..3598a3a6 100644 --- a/gcsfs/tests/test_extended_gcsfs.py +++ b/gcsfs/tests/test_extended_gcsfs.py @@ -1,4 +1,5 @@ # Integration tests for ExtendedGcsFileSystem +import asyncio import contextlib import io import multiprocessing @@ -23,6 +24,8 @@ from gcsfs.extended_gcsfs import ( BucketType, ExtendedGcsFileSystem, + _get_mrd_from_pool_or_mrd, + _get_mrd_size, initiate_upload, simple_upload, upload_chunk, @@ -35,6 +38,7 @@ ) from gcsfs.tests.settings import TEST_BUCKET, TEST_ZONAL_BUCKET from gcsfs.tests.utils import tempdir, tmpfile +from gcsfs.zb_hns_utils import MRDPool file = "test/accounts.1.json" file_path = f"{TEST_ZONAL_BUCKET}/{file}" @@ -172,11 +176,15 @@ def test_read_block_zb(extended_gcsfs, gcs_bucket_mocks, subtests): ) if expected_data: - call_args = mocks["downloader"].download_ranges.call_args - assert call_args is not None, "download_ranges was not called" + call_args_list = mocks[ + "downloader" + ].download_ranges.await_args_list + assert call_args_list, "download_ranges was not called" - # Get the actual list of ranges passed: [(start, end, buffer), ...] - actual_ranges = call_args[0][0] + # Aggregate all the ranges passed across concurrent calls + actual_ranges = [] + for call in call_args_list: + actual_ranges.extend(call[0][0]) if delimiter: assert len(actual_ranges) >= 1 @@ -387,7 +395,7 @@ def test_multithreaded_read_disjoint_ranges_zb(extended_gcsfs, gcs_bucket_mocks) if mocks: assert mocks["create_mrd"].call_count == len(read_tasks) - assert mocks["downloader"].download_ranges.call_count == len(read_tasks) + assert mocks["downloader"].download_ranges.call_count >= len(read_tasks) assert mocks["downloader"].close.call_count == len(read_tasks) @@ -424,7 +432,7 @@ def test_multithreaded_read_overlapping_ranges_zb(extended_gcsfs, gcs_bucket_moc if mocks: assert mocks["create_mrd"].call_count == len(read_tasks) - assert mocks["downloader"].download_ranges.call_count == len(read_tasks) + assert mocks["downloader"].download_ranges.call_count >= len(read_tasks) assert mocks["downloader"].close.call_count == len(read_tasks) @@ -502,7 +510,7 @@ def test_multithreaded_read_chunk_boundary_zb(extended_gcsfs, gcs_bucket_mocks): if mocks: assert mocks["create_mrd"].call_count == len(read_tasks) - assert mocks["downloader"].download_ranges.call_count == len(read_tasks) + assert mocks["downloader"].download_ranges.call_count >= len(read_tasks) assert mocks["downloader"].close.call_count == len(read_tasks) @@ -545,7 +553,7 @@ def test_multithreaded_read_high_concurrency_zb(extended_gcsfs, gcs_bucket_mocks assert mocks["create_mrd"].call_count == _NUM_CONCURRENCY_THREADS assert ( mocks["downloader"].download_ranges.call_count - == _NUM_CONCURRENCY_THREADS + >= _NUM_CONCURRENCY_THREADS ) assert mocks["downloader"].close.call_count == _NUM_CONCURRENCY_THREADS @@ -629,7 +637,7 @@ def _run_task_and_store_result(task_idx, fs, path, offset, length): assert mocks["create_mrd"].call_count == _NUM_FAIL_SURVIVE_THREADS assert ( - mocks["downloader"].download_ranges.call_count == _NUM_FAIL_SURVIVE_THREADS + mocks["downloader"].download_ranges.call_count >= _NUM_FAIL_SURVIVE_THREADS ) assert mocks["downloader"].close.call_count == _NUM_FAIL_SURVIVE_THREADS @@ -1136,3 +1144,458 @@ async def mock_is_zonal(bucket): ) as mock_super_cp: await async_gcs._cp_file(source_path, dest_path) mock_super_cp.assert_awaited_once() + + +def test_read_block_zb(extended_gcsfs, gcs_bucket_mocks, subtests): + file_size = len( + json_data + ) # We need the file size to predict if readahead will trigger + + for param in read_block_params: + with subtests.test(id=param.id): + offset, length, delimiter, expected_data = param.values + path = file_path + + with gcs_bucket_mocks( + json_data, bucket_type_val=BucketType.ZONAL_HIERARCHICAL + ) as mocks: + result = extended_gcsfs.read_block(path, offset, length, delimiter) + + assert result == expected_data + + if mocks: + mocks["sync_lookup_bucket_type"].assert_called_once_with( + TEST_ZONAL_BUCKET + ) + + if expected_data: + call_args_list = mocks[ + "downloader" + ].download_ranges.await_args_list + assert call_args_list, "download_ranges was not called" + + # Aggregate all the ranges passed across concurrent calls + actual_ranges = [] + for call in call_args_list: + actual_ranges.extend(call[0][0]) + + if delimiter: + # fsspec dynamically calculates read block offsets when hunting + # for delimiters. We just assert that it requested ranges. + assert len(actual_ranges) >= 1 + else: + req_end = offset + length + if req_end >= file_size: + expected_chunks = 1 + else: + expected_chunks = 2 + + assert ( + len(actual_ranges) == expected_chunks + ), f"Expected {expected_chunks} chunks (Request + Readahead), got {len(actual_ranges)}" + assert actual_ranges[0][0] == offset + if len(actual_ranges) == 2: + assert actual_ranges[1][0] == offset + length + else: + mocks["downloader"].download_ranges.assert_not_called() + + +def test_mrd_stream_cleanup(extended_gcsfs, gcs_bucket_mocks): + """ + Tests that mrd stream is properly closed during file lifecycle. + """ + with gcs_bucket_mocks( + json_data, bucket_type_val=BucketType.ZONAL_HIERARCHICAL + ) as mocks: + if not extended_gcsfs.on_google: + + def close_side_effect(): + mocks["downloader"].is_stream_open = False + + mocks["downloader"].close.side_effect = close_side_effect + + with extended_gcsfs.open(file_path, "rb") as f: + # Triggering a read ensures the internal pool/mrd is used + f.read() + assert not f.closed + + assert f.closed + if mocks and not extended_gcsfs.on_google: + # Verify the downloader was properly shut down by the MRDPool context manager + mocks["downloader"].close.assert_awaited() + + +@pytest.mark.asyncio +async def test_concurrent_mrd_fetch_success(extended_gcsfs): + """Tests that _concurrent_mrd_fetch successfully downloads and stitches chunks.""" + # Add spec so isinstance() checks pass + mock_pool = mock.AsyncMock(spec=MRDPool) + mock_mrd = mock.AsyncMock(spec=AsyncMultiRangeDownloader) + + # Set up the context manager mock return value + mock_pool.get_mrd.return_value.__aenter__.return_value = mock_mrd + + async def fake_download(ranges): + for offset, length, buf in ranges: + buf.write(b"A" * length) + + mock_mrd.download_ranges.side_effect = fake_download + + result = await extended_gcsfs._concurrent_mrd_fetch( + offset=0, length=5 * 1024 * 1024, concurrency=4, mrd_or_pool=mock_pool + ) + + assert len(result) == 5 * 1024 * 1024 + assert result == b"A" * 5 * 1024 * 1024 + assert mock_mrd.download_ranges.call_count == 4 + + +@pytest.mark.asyncio +async def test_concurrent_mrd_fetch_exception_masking(extended_gcsfs): + """ + Tests that original exceptions in concurrent fetches are not masked by BufferErrors. + """ + # Add spec so isinstance() checks pass + mock_pool = mock.AsyncMock(spec=MRDPool) + mock_mrd = mock.AsyncMock(spec=AsyncMultiRangeDownloader) + + # Set up the context manager mock return value + mock_pool.get_mrd.return_value.__aenter__.return_value = mock_mrd + + call_count = 0 + + async def failing_download(ranges): + nonlocal call_count + call_count += 1 + if call_count == 2: + # Simulate a network drop on the second concurrent chunk + raise DataCorruption(None, "Simulated Network Drop") + for offset, length, buf in ranges: + buf.write(b"A" * length) + + mock_mrd.download_ranges.side_effect = failing_download + + with pytest.raises(DataCorruption, match="Simulated Network Drop"): + await extended_gcsfs._concurrent_mrd_fetch( + offset=0, length=5 * 1024 * 1024, concurrency=4, mrd_or_pool=mock_pool + ) + + +@pytest.mark.asyncio +async def test_get_mrd_from_pool_or_mrd_with_pool(): + """Tests yielding an MRD when an MRDPool is provided.""" + mock_mrd = mock.AsyncMock(spec=AsyncMultiRangeDownloader) + mock_pool = mock.AsyncMock(spec=MRDPool) + # Set up the context manager mock return value for get_mrd() + mock_pool.get_mrd.return_value.__aenter__.return_value = mock_mrd + + async with _get_mrd_from_pool_or_mrd(mock_pool) as mrd: + assert mrd is mock_mrd + + mock_pool.get_mrd.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_mrd_from_pool_or_mrd_with_mrd(): + """Tests yielding the MRD directly when a single AsyncMultiRangeDownloader is provided.""" + mock_mrd = mock.AsyncMock(spec=AsyncMultiRangeDownloader) + + async with _get_mrd_from_pool_or_mrd(mock_mrd) as mrd: + assert mrd is mock_mrd + + +@pytest.mark.asyncio +async def test_get_mrd_from_pool_or_mrd_invalid_type(): + """Tests that a TypeError is raised when an unsupported type is passed.""" + with pytest.raises( + TypeError, match="Expected MRDPool or AsyncMultiRangeDownloader" + ): + async with _get_mrd_from_pool_or_mrd("invalid_string_type") as _: + pass + + +@pytest.mark.asyncio +async def test_get_mrd_size_with_pool(): + """Tests extracting persisted_size from an MRDPool.""" + mock_mrd = mock.AsyncMock(spec=AsyncMultiRangeDownloader) + mock_mrd.persisted_size = 1024 + + mock_pool = mock.AsyncMock(spec=MRDPool) + mock_pool.get_mrd.return_value.__aenter__.return_value = mock_mrd + + size = await _get_mrd_size(mock_pool) + assert size == 1024 + + +@pytest.mark.asyncio +async def test_get_mrd_size_with_mrd(): + """Tests extracting persisted_size directly from an AsyncMultiRangeDownloader.""" + mock_mrd = mock.AsyncMock(spec=AsyncMultiRangeDownloader) + mock_mrd.persisted_size = 2048 + + size = await _get_mrd_size(mock_mrd) + assert size == 2048 + + +@pytest.mark.asyncio +async def test_fetch_range_split_out_of_bounds(extended_gcsfs): + """Tests that _fetch_range_split raises an error if requested chunks exceed file size.""" + with mock.patch.object( + extended_gcsfs, "_info", new_callable=mock.AsyncMock + ) as mock_info: + mock_info.return_value = {"size": 100} + + # Requesting 20 bytes starting at offset 90 exceeds the 100 byte file size + with pytest.raises(RuntimeError, match="Request not satisfiable"): + await extended_gcsfs._fetch_range_split( + "bucket/obj", start=90, chunk_lengths=[20], concurrency=1 + ) + + +@pytest.mark.asyncio +async def test_fetch_range_split_concurrent_success(extended_gcsfs): + """Tests MRDPool creation, cleanup, and concurrent _cat_file dispatching.""" + mock_pool = mock.AsyncMock() + + with contextlib.ExitStack() as stack: + stack.enter_context( + mock.patch( + "gcsfs.extended_gcsfs.zb_hns_utils.MRDPool", return_value=mock_pool + ) + ) + stack.enter_context( + mock.patch.object(extended_gcsfs, "_is_zonal_bucket", return_value=True) + ) + stack.enter_context( + mock.patch.object(extended_gcsfs, "_info", return_value={"size": 100}) + ) + + mock_cat = stack.enter_context( + mock.patch.object(extended_gcsfs, "_cat_file", new_callable=mock.AsyncMock) + ) + mock_cat.side_effect = [b"chunk1", b"chunk2"] + + # Fetch two chunks: 5 bytes and 15 bytes, starting at offset 10 + result = await extended_gcsfs._fetch_range_split( + "bucket/obj", start=10, chunk_lengths=[5, 15], concurrency=4 + ) + + assert result == [b"chunk1", b"chunk2"] + + # Pool should be initialized and closed safely + mock_pool.initialize.assert_awaited_once() + mock_pool.close.assert_awaited_once() + + # Verify _cat_file was called for each chunk with the correct bounds + assert mock_cat.call_count == 2 + + call_1_kwargs = mock_cat.call_args_list[0].kwargs + assert call_1_kwargs["start"] == 10 + assert call_1_kwargs["end"] == 15 + assert call_1_kwargs["concurrency"] == 1 # Ensures we don't nest thread pools + + call_2_kwargs = mock_cat.call_args_list[1].kwargs + assert call_2_kwargs["start"] == 15 + assert call_2_kwargs["end"] == 30 + assert call_2_kwargs["concurrency"] == 3 + + +@pytest.mark.asyncio +async def test_fetch_range_split_concurrent_exception(extended_gcsfs): + """Tests that exceptions in the concurrent tasks bubble up correctly.""" + mock_pool = mock.AsyncMock() + + with contextlib.ExitStack() as stack: + stack.enter_context( + mock.patch( + "gcsfs.extended_gcsfs.zb_hns_utils.MRDPool", return_value=mock_pool + ) + ) + stack.enter_context( + mock.patch.object(extended_gcsfs, "_is_zonal_bucket", return_value=True) + ) + stack.enter_context( + mock.patch.object(extended_gcsfs, "_info", return_value={"size": 100}) + ) + + mock_cat = stack.enter_context( + mock.patch.object(extended_gcsfs, "_cat_file", new_callable=mock.AsyncMock) + ) + + # Simulate a crash on one of the chunks + mock_cat.side_effect = [b"chunk1", DataCorruption(None, "Task failed")] + + with pytest.raises(DataCorruption, match="Task failed"): + await extended_gcsfs._fetch_range_split( + "bucket/obj", start=10, chunk_lengths=[5, 15], concurrency=4 + ) + + # Pool must still be closed even if gathering the tasks raises an exception + mock_pool.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_cat_file_zero_length_read(extended_gcsfs): + """Tests that _cat_file returns empty bytes and cleans up if length resolves to 0.""" + mock_pool = mock.AsyncMock() + + with contextlib.ExitStack() as stack: + stack.enter_context( + mock.patch( + "gcsfs.extended_gcsfs.zb_hns_utils.MRDPool", return_value=mock_pool + ) + ) + stack.enter_context( + mock.patch.object(extended_gcsfs, "_is_zonal_bucket", return_value=True) + ) + stack.enter_context( + mock.patch("gcsfs.extended_gcsfs._get_mrd_size", return_value=100) + ) + + stack.enter_context( + mock.patch.object( + extended_gcsfs, + "_process_limits_to_offset_and_length", + return_value=(50, 0), # Offset 50, Length 0 + ) + ) + mock_concurrent_fetch = stack.enter_context( + mock.patch.object( + extended_gcsfs, "_concurrent_mrd_fetch", new_callable=mock.AsyncMock + ) + ) + + result = await extended_gcsfs._cat_file("bucket/obj", start=50, end=50) + + assert result == b"" + + # It should exit early before fetching anything + mock_concurrent_fetch.assert_not_awaited() + + # Pool should still be initialized and immediately closed + mock_pool.initialize.assert_awaited_once() + mock_pool.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_cat_file_concurrency_threshold(extended_gcsfs): + """Tests that _cat_file passes concurrency=1 to the fetcher if under the chunk size threshold.""" + # Set a dummy threshold for the test + extended_gcsfs.MIN_CHUNK_SIZE_FOR_CONCURRENCY = 1000 + mock_pool = mock.AsyncMock() + + with contextlib.ExitStack() as stack: + stack.enter_context( + mock.patch( + "gcsfs.extended_gcsfs.zb_hns_utils.MRDPool", return_value=mock_pool + ) + ) + stack.enter_context( + mock.patch.object(extended_gcsfs, "_is_zonal_bucket", return_value=True) + ) + stack.enter_context( + mock.patch("gcsfs.extended_gcsfs._get_mrd_size", return_value=5000) + ) + + # Length 500 is less than threshold 1000 + stack.enter_context( + mock.patch.object( + extended_gcsfs, + "_process_limits_to_offset_and_length", + return_value=(0, 500), + ) + ) + mock_concurrent_fetch = stack.enter_context( + mock.patch.object( + extended_gcsfs, "_concurrent_mrd_fetch", new_callable=mock.AsyncMock + ) + ) + mock_concurrent_fetch.return_value = b"data" + + await extended_gcsfs._cat_file("bucket/obj", concurrency=4) + + # Because length (500) < MIN_CHUNK_SIZE_FOR_CONCURRENCY (1000), it should override concurrency to 1 + mock_concurrent_fetch.assert_awaited_once_with(0, 500, 1, mock_pool) + + +@pytest.mark.asyncio +async def test_concurrent_mrd_fetch_base_exception_cancellation(extended_gcsfs): + """Tests that pending tasks are cancelled if a BaseException occurs.""" + mock_pool = mock.AsyncMock(spec=MRDPool) + + # Create fake tasks to track cancellation + mock_task1 = mock.Mock(spec=asyncio.Task) + mock_task1.done.return_value = False # Pending, should be cancelled + + mock_task2 = mock.Mock(spec=asyncio.Task) + mock_task2.done.return_value = True # Finished, should NOT be cancelled + + with contextlib.ExitStack() as stack: + stack.enter_context( + mock.patch("asyncio.create_task", side_effect=[mock_task1, mock_task2]) + ) + mock_gather = stack.enter_context( + mock.patch("asyncio.gather", new_callable=mock.AsyncMock) + ) + # Force the BaseException path + mock_gather.side_effect = KeyboardInterrupt() + + with pytest.raises(KeyboardInterrupt): + await extended_gcsfs._concurrent_mrd_fetch( + offset=0, length=1024, concurrency=2, mrd_or_pool=mock_pool + ) + + # Assert exactly the logic: if not t.done(): t.cancel() + mock_task1.cancel.assert_called_once() + mock_task2.cancel.assert_not_called() + + +@pytest.mark.asyncio +async def test_concurrent_mrd_fetch_buffer_error_surfaced(extended_gcsfs): + """Tests that BufferError is surfaced if tasks succeed but buffers are underfilled.""" + mock_pool = mock.AsyncMock(spec=MRDPool) + mock_mrd = mock.AsyncMock(spec=AsyncMultiRangeDownloader) + mock_pool.get_mrd.return_value.__aenter__.return_value = mock_mrd + + async def underfilling_download(ranges): + for offset, length, buf in ranges: + # We intentionally write 1 byte LESS than requested. + # This causes no exception during the gather block (has_error = False), + # but triggers a BufferError when buf.close() is called. + buf.write(b"A" * (length - 1)) + + mock_mrd.download_ranges.side_effect = underfilling_download + + # We expect the BufferError to be raised up to the caller + with pytest.raises(BufferError, match="Buffer contains uninitialized data"): + await extended_gcsfs._concurrent_mrd_fetch( + offset=0, length=1024, concurrency=1, mrd_or_pool=mock_pool + ) + + +@pytest.mark.asyncio +async def test_cat_file_non_zonal_fallback(extended_gcsfs): + """Tests that _cat_file delegates to the parent class for non-zonal buckets.""" + with contextlib.ExitStack() as stack: + stack.enter_context( + mock.patch.object(extended_gcsfs, "_is_zonal_bucket", return_value=False) + ) + mock_super_cat = stack.enter_context( + mock.patch( + "gcsfs.core.GCSFileSystem._cat_file", new_callable=mock.AsyncMock + ) + ) + mock_super_cat.return_value = b"standard_bucket_data" + + # Call with arbitrary arguments + result = await extended_gcsfs._cat_file( + "standard_bucket/obj", start=10, end=20, concurrency=2, custom_arg="val" + ) + + # Assert data matches superclass return value + assert result == b"standard_bucket_data" + + # Assert the super method was given the exact arguments + mock_super_cat.assert_awaited_once_with( + "standard_bucket/obj", start=10, end=20, concurrency=2, custom_arg="val" + ) diff --git a/gcsfs/tests/test_extended_gcsfs_unit.py b/gcsfs/tests/test_extended_gcsfs_unit.py index ca588dff..2ba17288 100644 --- a/gcsfs/tests/test_extended_gcsfs_unit.py +++ b/gcsfs/tests/test_extended_gcsfs_unit.py @@ -201,7 +201,7 @@ def test_mrd_exception_handling(extended_gcsfs, gcs_bucket_mocks, exception_to_r with pytest.raises(exception_to_raise, match="Test exception raised"): extended_gcsfs.read_block(file_path, 0, 10) - mocks["downloader"].download_ranges.assert_called_once() + mocks["downloader"].download_ranges.call_count = 2 def test_mrd_created_once_for_zonal_file(extended_gcsfs, gcs_bucket_mocks): diff --git a/gcsfs/tests/test_zb_hns_utils.py b/gcsfs/tests/test_zb_hns_utils.py index 18d34d86..4733116f 100644 --- a/gcsfs/tests/test_zb_hns_utils.py +++ b/gcsfs/tests/test_zb_hns_utils.py @@ -1,3 +1,5 @@ +import concurrent.futures +import ctypes import logging from unittest import mock @@ -5,6 +7,7 @@ from google.api_core.exceptions import NotFound from gcsfs import zb_hns_utils +from gcsfs.zb_hns_utils import DirectMemmoveBuffer, MRDPool mock_grpc_client = mock.Mock() bucket_name = "test-bucket" @@ -256,3 +259,323 @@ async def test_download_ranges_validation_limit(): match="Invalid input - number of ranges cannot be more than 1000", ): await zb_hns_utils.download_ranges(ranges, mock_mrd) + + +@pytest.mark.asyncio +async def test_mrd_pool_close(): + gcsfs_mock = mock.Mock() + gcsfs_mock._get_grpc_client = mock.AsyncMock() + + mrd_instance_mock = mock.AsyncMock() + + with mock.patch( + "google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.create_mrd", + return_value=mrd_instance_mock, + ): + pool = MRDPool(gcsfs_mock, "bucket", "obj", "123", pool_size=1) + await pool.initialize() + + await pool.close() + mrd_instance_mock.close.assert_awaited_once() + assert len(pool._all_mrds) == 0 + + +@pytest.fixture +def mock_gcsfs(): + gcsfs_mock = mock.Mock() + gcsfs_mock._get_grpc_client = mock.AsyncMock() + return gcsfs_mock + + +@pytest.mark.asyncio +@mock.patch( + "google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.create_mrd", + new_callable=mock.AsyncMock, +) +async def test_mrd_pool_scaling(create_mrd_mock, mock_gcsfs): + mrd_instance_mock = mock.AsyncMock() + mrd_instance_mock.persisted_size = 1024 + create_mrd_mock.return_value = mrd_instance_mock + + pool = MRDPool(mock_gcsfs, "bucket", "obj", "123", pool_size=2) + + await pool.initialize() + assert pool.persisted_size == 1024 + assert pool._active_count == 1 + create_mrd_mock.assert_awaited_once() + + async with pool.get_mrd() as mrd1: + assert mrd1 == mrd_instance_mock + + # Since mrd1 is in use, getting another one should spawn a new MRD + async with pool.get_mrd() as _: + assert pool._active_count == 2 + assert create_mrd_mock.call_count == 2 + + # Both should have been returned to the free queue + assert pool._free_mrds.qsize() == 2 + + +@pytest.mark.asyncio +@mock.patch( + "google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.create_mrd", + new_callable=mock.AsyncMock, +) +async def test_mrd_pool_double_initialize(create_mrd_mock, mock_gcsfs): + pool = MRDPool(mock_gcsfs, "bucket", "obj", "123", pool_size=2) + + await pool.initialize() + await pool.initialize() # Second call should be a no-op + + assert pool._active_count == 1 + create_mrd_mock.assert_awaited_once() + + +@pytest.mark.asyncio +@mock.patch( + "google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.create_mrd", + new_callable=mock.AsyncMock, +) +async def test_mrd_pool_get_mrd_creation_error(create_mrd_mock, mock_gcsfs): + # First creation succeeds during initialization + valid_mrd = mock.AsyncMock() + + # Second creation fails when pool tries to scale + create_mrd_mock.side_effect = [valid_mrd, Exception("Network Error")] + + pool = MRDPool(mock_gcsfs, "bucket", "obj", "123", pool_size=2) + await pool.initialize() + + # Consume the initialized MRD + async def consume_and_error(): + async with pool.get_mrd() as _: + # Try to get a second one, which forces a spawn that will fail + with pytest.raises(Exception, match="Network Error"): + async with pool.get_mrd() as _: + pass + + await consume_and_error() + + # Active count should remain 1 because the second creation failed and rolled back + assert pool._active_count == 1 + + +@pytest.mark.asyncio +@mock.patch( + "google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.create_mrd", + new_callable=mock.AsyncMock, +) +async def test_mrd_pool_close_with_exceptions(create_mrd_mock, mock_gcsfs): + bad_mrd_instance = mock.AsyncMock() + bad_mrd_instance.close.side_effect = RuntimeError("Close failed") + create_mrd_mock.return_value = bad_mrd_instance + + pool = MRDPool(mock_gcsfs, "bucket", "obj", "123", pool_size=1) + await pool.initialize() + + with pytest.raises(RuntimeError, match="Close failed"): + await pool.close() + + bad_mrd_instance.close.assert_awaited_once() + assert len(pool._all_mrds) == 0 + + +@mock.patch("gcsfs.zb_hns_utils.ctypes.memmove") +def test_direct_memmove_buffer_error_handling(mock_memmove): + size = 20 + buffer_array = (ctypes.c_char * size)() + start_address = ctypes.addressof(buffer_array) + end_address = start_address + size + + # Simulate an access violation or similar error during memory copy + mock_memmove.side_effect = MemoryError("Segfault simulated") + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2) + + # First write triggers the background error + future = buf.write(b"bad data") + + # Wait for the background thread to actually fail + with pytest.raises(MemoryError): + future.result() + + # Subsequent writes should raise the stored error immediately + with pytest.raises(MemoryError, match="Segfault simulated"): + buf.write(b"more data") + + # Close should also raise the stored error. + with pytest.raises(MemoryError, match="Segfault simulated"): + buf.close() + + executor.shutdown() + + +def test_direct_memmove_buffer(): + data1 = b"hello" + data2 = b"world" + + # Calculate exact size to prevent the new underflow check from failing + size = len(data1) + len(data2) + buffer_array = (ctypes.c_char * size)() + start_address = ctypes.addressof(buffer_array) + end_address = start_address + size + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2) + + future1 = buf.write(data1) + future2 = buf.write(data2) + + future1.result() + future2.result() + buf.close() + + result_bytes = ctypes.string_at(start_address, len(data1) + len(data2)) + assert result_bytes == b"helloworld" + + executor.shutdown() + + +def test_direct_memmove_buffer_overflow(): + """Tests that writing past the allocated end_address raises a BufferError.""" + size = 10 + buffer_array = (ctypes.c_char * size)() + start_address = ctypes.addressof(buffer_array) + end_address = start_address + size + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2) + + # Fill the buffer exactly to capacity + buf.write(b"1234567890") + + # Attempting to write even 1 more byte should trigger the overflow protection + with pytest.raises(BufferError, match="Attempted to write"): + buf.write(b"1") + + buf.close() + executor.shutdown() + + +def test_direct_memmove_buffer_underflow(): + """Tests that closing an incompletely filled buffer raises a BufferError.""" + size = 10 + buffer_array = (ctypes.c_char * size)() + start_address = ctypes.addressof(buffer_array) + end_address = start_address + size + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2) + + # Write fewer bytes than the expected capacity + buf.write(b"12345") + + # Closing should detect that current_offset (5) < expected size (10) + with pytest.raises(BufferError, match="Buffer contains uninitialized data"): + buf.close() + + executor.shutdown() + + +@pytest.mark.asyncio +async def test_mrd_pool_queue_filled_during_lock_wait(mock_gcsfs): + pool = MRDPool(mock_gcsfs, "bucket", "obj", "123", pool_size=1) + mrd_mock = mock.AsyncMock() + + # Simulate _create_mrd so we correctly populate _all_mrds + async def fake_create_mrd(): + pool._all_mrds.append(mrd_mock) + return mrd_mock + + with mock.patch.object(pool, "_create_mrd", side_effect=fake_create_mrd): + await pool.initialize() + + side_effects = [True] + [False] * 10 + with mock.patch.object(pool._free_mrds, "empty", side_effect=side_effects): + async with pool.get_mrd() as mrd: + assert mrd == mrd_mock + + # We should not have spawned a new MRD + assert pool._active_count == 1 + + +@pytest.mark.asyncio +async def test_mrd_pool_round_robin_multi_request(mock_gcsfs): + pool = MRDPool(mock_gcsfs, "bucket", "obj", "123", pool_size=2) + mrd1 = mock.AsyncMock() + mrd2 = mock.AsyncMock() + + mrd_mocks = [mrd1, mrd2] + + # Ensure our mock actually appends to _all_mrds so the round-robin + # logic sees that there are available active MRDs to share. + async def fake_create_mrd(): + mrd = mrd_mocks.pop(0) + pool._all_mrds.append(mrd) + return mrd + + # Enable the multi-request feature manually for this test + pool.mrd_supports_multi_request = True + + with mock.patch.object(pool, "_create_mrd", side_effect=fake_create_mrd): + await pool.initialize() + + # Keep both MRDs checked out to force the pool to its maximum size + # and keep the free queue empty. + async with pool.get_mrd() as active_mrd1: + async with pool.get_mrd() as active_mrd2: + assert active_mrd1 == mrd1 + assert active_mrd2 == mrd2 + assert pool._free_mrds.empty() + assert pool._active_count == 2 + assert pool._rr_index == 0 + + # Requesting a 3rd MRD should trigger the round-robin logic + async with pool.get_mrd() as shared_mrd1: + assert shared_mrd1 == mrd1 + assert pool._rr_index == 1 + + # Requesting a 4th MRD should continue the round-robin + async with pool.get_mrd() as shared_mrd2: + assert shared_mrd2 == mrd2 + assert pool._rr_index == 0 + + # Requesting a 5th MRD should wrap around back to the first + async with pool.get_mrd() as shared_mrd3: + assert shared_mrd3 == mrd1 + assert pool._rr_index == 1 + + +@mock.patch("gcsfs.zb_hns_utils.ctypes.memmove") +def test_direct_memmove_buffer_submit_failure(mock_memmove): + """ + Tests that if executor.submit fails synchronously (e.g., executor is closed), + the internal locks, semaphores, and events are properly reset, and close() + does not hang. + """ + size = 10 + buffer_array = (ctypes.c_char * size)() + start_address = ctypes.addressof(buffer_array) + end_address = start_address + size + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2) + + # Mock the submit method to simulate a closed executor throwing a RuntimeError + with mock.patch.object( + executor, "submit", side_effect=RuntimeError("Executor closed") + ): + # The write operation should raise the simulated RuntimeError + with pytest.raises(RuntimeError, match="Executor closed"): + buf.write(b"12345") + + # Verify that the internal tracking state was correctly rolled back + assert buf._pending_count == 0 + assert buf._done_event.is_set() + + # Calling close() should NOT hang. It should immediately raise the stored error. + with pytest.raises(RuntimeError, match="Executor closed"): + buf.close() + + executor.shutdown() diff --git a/gcsfs/tests/test_zonal_file.py b/gcsfs/tests/test_zonal_file.py index 454b484f..008cbad3 100644 --- a/gcsfs/tests/test_zonal_file.py +++ b/gcsfs/tests/test_zonal_file.py @@ -10,6 +10,7 @@ from gcsfs.tests.settings import TEST_ZONAL_BUCKET from gcsfs.tests.utils import tempdir, tmpfile +from gcsfs.zonal_file import ZonalFile test_data = b"hello world" @@ -25,6 +26,16 @@ ) +@pytest.fixture +def mock_gcsfs(): + fs = mock.Mock() + fs._split_path.return_value = ("test-bucket", "test-key", "123") + fs.split_path.return_value = ("test-bucket", "test-key", "123") + fs.info.return_value = {"size": 1000, "generation": "123", "name": "test-key"} + fs.loop = mock.Mock() + return fs + + @pytest.mark.parametrize( "setup_action, error_match", [ @@ -479,3 +490,186 @@ def test_pipe_overwrite_in_zonal_bucket(self, extended_gcsfs, file_path): extended_gcsfs.pipe(remote_path, overwrite_data, finalize_on_close=True) assert extended_gcsfs.cat(remote_path) == overwrite_data + + +def test_zonal_file_fetch_range_without_prefetch_engine(mock_gcsfs): + """Tests _fetch_range routing to the gcsfs underlying methods when no prefetch engine exists.""" + + # We need a custom fake sync function to actually execute the inner `_do_fetch` + # coroutine so we can assert the routing logic inside of it. + def fake_sync(loop, func, *args, **kwargs): + import asyncio + import inspect + + res = func(*args, **kwargs) + if inspect.iscoroutine(res): + return asyncio.run(res) + return res + + with mock.patch("gcsfs.zonal_file.asyn.sync", side_effect=fake_sync): + # We patch MRDPool.initialize specifically so ZonalFile.__init__ doesn't crash + # trying to hit a non-existent gRPC client during object creation. + with mock.patch( + "gcsfs.zb_hns_utils.MRDPool.initialize", new_callable=mock.AsyncMock + ): + zf = ZonalFile( + gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb" + ) + + zf._prefetch_engine = None # Ensure it's bypassed + + # Explicitly initialize these as AsyncMocks so they can be awaited safely + mock_gcsfs._fetch_range_split = mock.AsyncMock(return_value=[b"split_data"]) + mock_gcsfs._cat_file = mock.AsyncMock(return_value=b"cat_data") + + result = zf._fetch_range(start=10, chunk_lengths=[5]) + + assert result == [b"split_data"] + mock_gcsfs._fetch_range_split.assert_awaited_once_with( + zf.path, + concurrency=1, + start=10, + chunk_lengths=[5], + size=zf.size, + mrd=zf.mrd_pool, + ) + + result = zf._fetch_range(start=10, end=20) + + assert result == b"cat_data" + mock_gcsfs._cat_file.assert_awaited_once_with( + zf.path, start=10, end=20, concurrency=zf.pool_size, mrd=zf.mrd_pool + ) + + # Test catch of "not satisfiable" + mock_gcsfs._cat_file.side_effect = RuntimeError("not satisfiable") + result = zf._fetch_range(start=10, end=20) + assert result == b"" + + zf.closed = True + + +@mock.patch("gcsfs.zonal_file.asyn.sync") +@mock.patch("gcsfs.zb_hns_utils.MRDPool") +@pytest.mark.asyncio +async def test_zonal_file_async_fetch_range(mock_mrd_pool, mock_sync, mock_gcsfs): + """Tests the native coroutine called by the BackgroundPrefetcher.""" + zf = ZonalFile(gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb") + mock_gcsfs._concurrent_mrd_fetch = mock.AsyncMock(return_value=b"async data") + result = await zf._async_fetch_range(start_offset=0, total_size=100, split_factor=2) + assert result == b"async data" + mock_gcsfs._concurrent_mrd_fetch.assert_awaited_once_with(0, 100, 2, zf.mrd_pool) + zf.close() + + +@mock.patch("gcsfs.zonal_file.asyn.sync") +@mock.patch("gcsfs.zb_hns_utils.MRDPool") +def test_zonal_file_fetch_range_with_prefetch_engine( + mock_mrd_pool, mock_sync, mock_gcsfs +): + """Tests _fetch_range routing through the prefetch engine.""" + zf = ZonalFile(gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb") + mock_engine = mock.Mock() + zf._prefetch_engine = mock_engine + + mock_engine._fetch.return_value = b"all_data" + result = zf._fetch_range(start=0, end=10) + assert result == b"all_data" + mock_engine._fetch.assert_called_once_with(0, 10) + + mock_engine.reset_mock() + mock_engine._fetch.side_effect = [b"chunk1", b"chunk2"] + result = zf._fetch_range(start=0, chunk_lengths=[6, 6]) + assert result == [b"chunk1", b"chunk2"] + mock_engine._fetch.assert_has_calls([mock.call(0, 6), mock.call(6, 12)]) + + mock_engine.reset_mock() + mock_engine._fetch.side_effect = None + mock_engine._fetch.return_value = b"short" + + result = zf._fetch_range(start=0, chunk_lengths=[10]) + assert result == [b""] + zf.close() + + +@mock.patch("gcsfs.zonal_file.asyn.sync") +@mock.patch("gcsfs.zb_hns_utils.MRDPool") +def test_zonal_file_pool_size_initialization(mock_mrd_pool, mock_sync, mock_gcsfs): + """Tests that pool_size is correctly set based on kwargs and env vars.""" + zf1 = ZonalFile( + gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb", pool_size=10 + ) + assert zf1.pool_size == 10 + zf1.closed = True + + zf2 = ZonalFile( + gcsfs=mock_gcsfs, + path="gs://test-bucket/test-key", + mode="rb", + use_experimental_adaptive_prefetching=True, + ) + assert zf2.pool_size == 1 + assert zf2._prefetch_engine is not None + zf2.closed = True + + zf3 = ZonalFile( + gcsfs=mock_gcsfs, + path="gs://test-bucket/test-key", + mode="rb", + use_experimental_adaptive_prefetching=False, + ) + assert zf3.pool_size == 1 + assert zf3._prefetch_engine is None + zf3.closed = True + + +@mock.patch("gcsfs.zonal_file.asyn.sync") +@mock.patch("gcsfs.zb_hns_utils.MRDPool") +def test_zonal_file_fetch_range_mutually_exclusive( + mock_mrd_pool, mock_sync, mock_gcsfs +): + """Tests that providing both end and chunk_lengths raises a ValueError.""" + zf = ZonalFile(gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb") + with pytest.raises( + ValueError, match="mutually exclusive and cannot be used together" + ): + zf._fetch_range(start=0, end=10, chunk_lengths=[10]) + zf.close() + + +@mock.patch("gcsfs.zonal_file.asyn.sync") +@mock.patch("gcsfs.zb_hns_utils.MRDPool") +def test_zonal_file_close_cleans_up_new_pools(mock_mrd_pool, mock_sync, mock_gcsfs): + """Tests that close() properly tears down the prefetch engine and MRD pool using hasattr.""" + zf = ZonalFile(gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb") + mock_engine = mock.Mock() + zf._prefetch_engine = mock_engine + mock_pool = mock.Mock() + zf.mrd_pool = mock_pool + zf.close() + + mock_engine.close.assert_called_once() + expected_call = mock.call(mock_gcsfs.loop, mock_pool.close) + assert expected_call in mock_sync.call_args_list + + +@mock.patch("gcsfs.zonal_file.asyn.sync") +@mock.patch("gcsfs.zb_hns_utils.MRDPool") +def test_zonal_file_fetch_range_unhandled_runtime_error( + mock_mrd_pool, mock_sync, mock_gcsfs +): + """Tests that a RuntimeError not containing 'not satisfiable' is re-raised.""" + zf = ZonalFile(gcsfs=mock_gcsfs, path="gs://test-bucket/test-key", mode="rb") + mock_engine = mock.Mock() + zf._prefetch_engine = mock_engine + mock_engine._fetch.side_effect = RuntimeError( + "A completely different error occurred" + ) + + with pytest.raises(RuntimeError, match="A completely different error occurred"): + zf._fetch_range(start=0, end=10) + + with pytest.raises(RuntimeError, match="A completely different error occurred"): + zf._fetch_range(start=0, chunk_lengths=[10]) + + zf.close() diff --git a/gcsfs/zb_hns_utils.py b/gcsfs/zb_hns_utils.py index 30f1c2c1..1b85005d 100644 --- a/gcsfs/zb_hns_utils.py +++ b/gcsfs/zb_hns_utils.py @@ -1,5 +1,9 @@ +import asyncio +import contextlib +import ctypes import logging import os +import threading from io import BytesIO from google.api_core.exceptions import NotFound @@ -17,6 +21,15 @@ logger = logging.getLogger("gcsfs") +PyBytes_FromStringAndSize = ctypes.pythonapi.PyBytes_FromStringAndSize +PyBytes_FromStringAndSize.argtypes = (ctypes.c_void_p, ctypes.c_ssize_t) +PyBytes_FromStringAndSize.restype = ctypes.py_object + +PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString +PyBytes_AsString.argtypes = (ctypes.py_object,) +PyBytes_AsString.restype = ctypes.c_void_p + + async def init_mrd(grpc_client, bucket_name, object_name, generation=None): """ Creates the AsyncMultiRangeDownloader using an existing client. @@ -163,3 +176,273 @@ async def close_aaow(aaow, finalize_on_close=False): logger.warning( f"Error closing AsyncAppendableObjectWriter for {aaow.bucket_name}/{aaow.object_name}: {e}" ) + + +class DirectMemmoveBuffer: + """ + A buffer-like object that writes data directly to memory asynchronously. + + This class provides a `write` interface that queues `ctypes.memmove` operations + to a thread pool executor, limiting the maximum number of concurrent pending + writes using a semaphore. It is useful for high-performance data transfers + where memory copies need to be offloaded from the main thread. + """ + + def __init__(self, start_address, end_address, executor, max_pending=5): + """ + Initializes the DirectMemmoveBuffer. + + Args: + start_address (int): The starting memory address where data will be written. + end_address (int): The absolute ending memory address. Writes exceeding + this boundary will be rejected to prevent overflows. + executor (concurrent.futures.Executor): The thread pool executor to run the + memmove operations. The lifecycle of this executor is managed by the caller. + max_pending (int, optional): The maximum number of pending write operations + allowed in the queue. Defaults to 5. + """ + self.start_address = start_address + self.end_address = end_address + self.executor = executor + + # Volatile state variables. Must only be amended while holding self._lock. + self.current_offset = 0 + self._pending_count = 0 + self._error = None + + # Primitives: + # 1. semaphore: Provides backpressure by limiting the number of active tasks. + # 2. _lock: Protects mutations to the volatile state variables above. + # 3. _done_event: Signals when the queue of active background tasks reaches zero. + self.semaphore = threading.Semaphore(max_pending) + self._lock = threading.Lock() + self._done_event = threading.Event() + self._done_event.set() + + def _decrement_pending(self): + """Helper to cleanly release concurrency primitives after a task finishes.""" + self.semaphore.release() + with self._lock: + self._pending_count -= 1 + if self._pending_count == 0: + self._done_event.set() + + def write(self, data): + """ + Schedules a write operation to memory. + + Calculates the destination address based on the current offset, increments the offset, + and submits the memory move operation to the executor. Blocks if the number of + pending operations reaches `max_pending`. + + Args: + data: The data to be written to memory. Must support the buffer protocol. + + Returns: + concurrent.futures.Future: A future object representing the execution of the + memory move operation. + + Raises: + Exception: If any previous asynchronous write operation encountered an error. + BufferError: If the write exceeds the allocated memory boundaries. + """ + if self._error: + raise self._error + + size = len(data) + with self._lock: + dest = self.start_address + self.current_offset + if dest + size > self.end_address: + error_msg = ( + f"Attempted to write {size} bytes " + f"at offset {self.current_offset}. " + f"Max capacity is {self.end_address - self.start_address} bytes." + ) + raise BufferError(error_msg) + + self.current_offset += size + data_bytes = bytes(data) if not isinstance(data, bytes) else data + + self.semaphore.acquire() + with self._lock: + if self._pending_count == 0: + self._done_event.clear() + self._pending_count += 1 + + try: + return self.executor.submit(self._do_memmove, dest, data_bytes, size) + except BaseException as e: + self._error = e + self._decrement_pending() + raise e + + def _do_memmove(self, dest, data_bytes, size): + try: + ctypes.memmove(dest, data_bytes, size) + except Exception as e: + self._error = e + raise e + finally: + self._decrement_pending() + + def close(self): + """ + Waits for all pending write operations to complete and checks for errors. + Blocks the calling thread until the queue of memory operations is entirely + processed. + + Raises: + Exception: If any background write operation failed during execution. + BufferError: If the buffer was not filled to the expected capacity. + """ + self._done_event.wait() + if self._error: + raise self._error + + expected_size = self.end_address - self.start_address + if self.current_offset < expected_size: + error_msg = ( + f"Expected {expected_size} bytes, " + f"but only received {self.current_offset} bytes. " + f"Buffer contains uninitialized data." + ) + raise BufferError(error_msg) + + +class MRDPool: + """Manages a pool of AsyncMultiRangeDownloader objects with on-demand scaling.""" + + def __init__( + self, + gcsfs, + bucket_name, + object_name, + generation, + pool_size, + ): + """ + Initializes the MRDPool. + + Args: + gcsfs (gcsfs.GCSFileSystem): The GCS filesystem client used for the downloads. + bucket_name (str): The name of the GCS bucket. + object_name (str): The target object/blob name in the bucket. + generation (int or str): The specific generation of the GCS object to download. + pool_size (int): The maximum number of concurrent downloaders allowed in the pool. + """ + self.gcsfs = gcsfs + self.bucket_name = bucket_name + self.object_name = object_name + self.generation = generation + self.pool_size = pool_size + self._free_mrds = asyncio.Queue(maxsize=pool_size) + self._active_count = 0 + self._lock = asyncio.Lock() + self.persisted_size = None + self._initialized = False + self._closed = False + + self._all_mrds = [] + self._rr_index = 0 + self.mrd_supports_multi_request = ( + False # Change this to true once mrd supports concurrent reuqests. + ) + + async def _create_mrd(self): + await self.gcsfs._get_grpc_client() + mrd = await init_mrd( + self.gcsfs.grpc_client, self.bucket_name, self.object_name, self.generation + ) + self._all_mrds.append(mrd) + return mrd + + async def initialize(self): + """Initializes the MRDPool by creating the first downloader instance.""" + async with self._lock: + + if self._closed: + raise RuntimeError("Cannot initialize a closed MRDPool.") + + if not self._initialized and self._active_count == 0: + mrd = await self._create_mrd() + self.persisted_size = mrd.persisted_size + self._free_mrds.put_nowait(mrd) + self._active_count += 1 + + self._initialized = True + + @contextlib.asynccontextmanager + async def get_mrd(self): + """ + Dynamically provisions MRDs using an async context manager. + + If a downloader is available in the pool, it is yielded immediately. If the + pool is empty but hasn't reached `pool_size`, a new downloader is spawned + on demand. Automatically returns thec downloader to the free queue upon exit. + + Yields: + AsyncMultiRangeDownloader: An active downloader ready for requests. + + Raises: + Exception: Bubbles up any exceptions encountered during MRD creation. + """ + create_new = False + used_from_queue = False + mrd = None + + async with self._lock: + if self._closed: + raise RuntimeError("MRDPool is closed.") + + if self._free_mrds.empty(): + if self._active_count < self.pool_size: + self._active_count += 1 + create_new = True + elif self.mrd_supports_multi_request and self._all_mrds: + # Pool is full, queue is empty, and we are allowed to share a busy MRD. + # Get the mrd in round robin fasion. + mrd = self._all_mrds[self._rr_index] + self._rr_index = (self._rr_index + 1) % len(self._all_mrds) + + if create_new: + try: + mrd = await self._create_mrd() + except BaseException as e: + self._active_count -= 1 + raise e + elif mrd is None: + # We did not spawn a new one and we did not grab one via round-robin. + # This means we must wait for a free one from the queue. + mrd = await self._free_mrds.get() + used_from_queue = True + + try: + yield mrd + finally: + # Only return the MRD to the free queue if we were the ones who took it out + # or if we just spawned it. This prevents duplicate entries in the queue + # when multiple concurrent tasks share the same MRD via round-robin. + if (create_new or used_from_queue) and not self._closed: + self._free_mrds.put_nowait(mrd) + + async def close(self): + """ + Cleanly shut down all MRDs. + + Iterates through all instantiated downloaders and calls their close methods + """ + async with self._lock: + if self._closed: + return + + tasks = [] + for mrd in self._all_mrds: + tasks.append(mrd.close()) + try: + results = await asyncio.gather(*tasks, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + raise result + finally: + self._all_mrds.clear() + self._closed = True diff --git a/gcsfs/zonal_file.py b/gcsfs/zonal_file.py index 96f14cc3..34edf924 100644 --- a/gcsfs/zonal_file.py +++ b/gcsfs/zonal_file.py @@ -38,6 +38,7 @@ def __init__( fixed_key_metadata=None, generation=None, kms_key_name=None, + pool_size=zb_hns_utils.DEFAULT_CONCURRENCY, finalize_on_close=False, flush_interval_bytes=_DEFAULT_FLUSH_INTERVAL_BYTES, **kwargs, @@ -59,19 +60,21 @@ def __init__( bucket, key, generation = gcsfs._split_path(path) if not key: raise OSError("Attempt to open a bucket") - self.mrd = None self.aaow = None self.finalize_on_close = finalize_on_close self.finalized = False self.mode = mode self.flush_interval_bytes = flush_interval_bytes self.gcsfs = gcsfs + self.pool_size = pool_size object_size = None if "r" in self.mode: - self.mrd = asyn.sync( - self.gcsfs.loop, self._init_mrd, bucket, key, generation + self.mrd_pool = zb_hns_utils.MRDPool( + self.gcsfs, bucket, key, generation, self.pool_size ) - object_size = self.mrd.persisted_size + asyn.sync(self.gcsfs.loop, self.mrd_pool.initialize) + object_size = self.mrd_pool.persisted_size + if object_size is None: logger.warning( "AsyncMultiRangeDownloader (MRD) exists but has no 'persisted_size'. " @@ -151,34 +154,94 @@ def _ensure_aaow(self): self.flush_interval_bytes, ) - def _fetch_range(self, start=None, end=None, chunk_lengths=None): + def _fetch_range( + self, + start: int | None = None, + end: int | None = None, + chunk_lengths: list[int] | None = None, + ): """ Overrides the default _fetch_range to implement the gRPC read path. - See super() class for documentation. + Args: + start: The start offset for requested bytes (included). + end: The end offset for requested bytes (excluded). + chunk_lengths: A list of integers specifying the sizes of sequential chunks to read + starting from the start offset. This cannot be used at the same time as the end parameter. + + Returns: + A single bytes object if chunk_lengths is None, or a list of bytes objects corresponding + to the requested chunk sizes. If the range cannot be satisfied, it returns empty bytes + or a list with empty bytes. + + Raises: + ValueError: If both end and chunk_lengths are provided. + RuntimeError: If an underlying fetch operation fails for an unexpected reason. """ if end is not None and chunk_lengths is not None: raise ValueError( "The end and chunk_lengths arguments are mutually exclusive and cannot be used together." ) - try: + if self._prefetch_engine: + # This block is basically where caches and prefetch engines may overlap. + # We plan to remove this behaviour in future. + + try: + if chunk_lengths is None: + return self._prefetch_engine._fetch(start, end) + + # Fetch chunks sequentially through the prefetch engine + # Spawning concurrent task is worst here, because that would act as seek for prefetcher. + results = [] + current_offset = start if start is not None else 0 + for length in chunk_lengths: + data = self._prefetch_engine._fetch( + current_offset, current_offset + length + ) + results.append(data) + current_offset += length + if length != len(data): + raise RuntimeError("not satisfiable") + return results + except RuntimeError as e: + if "not satisfiable" in str(e): + return b"" if chunk_lengths is None else [b""] + raise + + # non-prefetch route + async def _do_fetch(): if chunk_lengths is not None: - return asyn.sync( - self.fs.loop, - self.gcsfs._fetch_range_split, + return await self.gcsfs._fetch_range_split( self.path, + concurrency=self.concurrency, start=start, chunk_lengths=chunk_lengths, size=self.size, - mrd=self.mrd, + mrd=self.mrd_pool, ) - return self.gcsfs.cat_file(self.path, start=start, end=end, mrd=self.mrd) + + return await self.gcsfs._cat_file( + self.path, + start=start, + end=end, + concurrency=self.concurrency, + mrd=self.mrd_pool, + ) + + try: + return asyn.sync(self.fs.loop, _do_fetch) except RuntimeError as e: if "not satisfiable" in str(e): return b"" if chunk_lengths is None else [b""] raise + async def _async_fetch_range(self, start_offset, total_size, split_factor=1): + """The native coroutine called by the BackgroundPrefetcher.""" + return await self.gcsfs._concurrent_mrd_fetch( + start_offset, total_size, split_factor, self.mrd_pool + ) + def write(self, data): """ Writes data using AsyncAppendableObjectWriter. @@ -311,10 +374,12 @@ def close(self): """ if self.closed: return + # super is closed before aaow since flush may need aaow super().close() - # Helper method safely handles mrd=None. - asyn.sync(self.gcsfs.loop, zb_hns_utils.close_mrd, self.mrd) + + if hasattr(self, "mrd_pool") and self.mrd_pool: + asyn.sync(self.gcsfs.loop, self.mrd_pool.close) # Only close aaow if the stream is open if self.aaow and self.aaow._is_stream_open: