Skip to content
290 changes: 217 additions & 73 deletions gcsfs/extended_gcsfs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
import contextlib
import logging
import os
import uuid
import weakref
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from glob import has_magic

Expand All @@ -16,11 +19,20 @@
AsyncAppendableObjectWriter,
)
from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient
from google.cloud.storage.asyncio.async_multi_range_downloader import (
AsyncMultiRangeDownloader,
)

from gcsfs import __version__ as version
from gcsfs import zb_hns_utils
from gcsfs.core import GCSFile, GCSFileSystem
from gcsfs.retry import DEFAULT_RETRY_CONFIG, get_storage_control_retry_config
from gcsfs.zb_hns_utils import (
DirectMemmoveBuffer,
MRDPool,
PyBytes_AsString,
PyBytes_FromStringAndSize,
)
from gcsfs.zonal_file import ZonalFile

logger = logging.getLogger("gcsfs")
Expand All @@ -44,6 +56,31 @@ class BucketType(Enum):
}


@contextlib.asynccontextmanager
async def _get_mrd_from_pool_or_mrd(mrd_or_pool):
"""
Helper function to yield an AsyncMultiRangeDownloader
whether a single instance or an MRDPool is provided.
"""
if isinstance(mrd_or_pool, MRDPool):
async with mrd_or_pool.get_mrd() as m:
yield m
elif isinstance(mrd_or_pool, AsyncMultiRangeDownloader):
yield mrd_or_pool
else:
raise TypeError(
f"Expected MRDPool or AsyncMultiRangeDownloader, got {type(mrd_or_pool)}"
)


async def _get_mrd_size(mrd_or_pool):
"""Helper to extract the persisted_size from either a pool or a single MRD."""
if mrd_or_pool is None:
return None
async with _get_mrd_from_pool_or_mrd(mrd_or_pool) as m:
return m.persisted_size


class ExtendedGcsFileSystem(GCSFileSystem):
"""
This class will be used when GCSFS_EXPERIMENTAL_ZB_HNS_SUPPORT env variable is set to true.
Expand Down Expand Up @@ -89,6 +126,10 @@ def __init__(self, *args, finalize_on_close=False, **kwargs):
if self.credentials.token == "anon":
self.credential = AnonymousCredentials()
self._storage_layout_cache = {}
self._memmove_executor = ThreadPoolExecutor(
Comment thread
googlyrahman marked this conversation as resolved.
max_workers=kwargs.get("memmove_max_workers", 8)
Comment thread
googlyrahman marked this conversation as resolved.
)
weakref.finalize(self, self._memmove_executor.shutdown)
Comment thread
googlyrahman marked this conversation as resolved.

@property
def _user_project(self):
Expand Down Expand Up @@ -285,73 +326,163 @@ async def _is_zonal_bucket(self, bucket):
return bucket_type == BucketType.ZONAL_HIERARCHICAL

async def _fetch_range_split(
self, path, start=None, chunk_lengths=None, mrd=None, size=None, **kwargs
self,
path,
start,
chunk_lengths,
concurrency,
mrd=None,
size=None,
**kwargs,
):
"""
Reading multiple reads in one large stream.
Reading multiple adjacent ranges concurrently.

Optimized for Zonal Buckets:
Leverages AsyncMultiRangeDownloader.download_ranges() to fetch all requested
'chunk_lengths' (chunks) concurrently in a single batch request, significantly
improving performance for ReadAheadV2.
Delegates concurrent fetching of individual chunks directly to `_cat_file`.
"""
file_size = size or await _get_mrd_size(mrd)
if file_size is None:
logger.warning(
f"AsyncMultiRangeDownloader (MRD) for {path} has no 'persisted_size'. "
Comment thread
martindurant marked this conversation as resolved.
"Falling back to _info() to get the file size."
)
file_size = (await self._info(path))["size"]

start_offset = start if start is not None else 0
if start_offset >= file_size or start_offset + sum(chunk_lengths) > file_size:
raise RuntimeError("Request not satisfiable.")

pool_created_here = False
bucket, object_name, generation = self.split_path(path)
mrd_created = False
try:
if mrd is None:
# Check before creating MRD
if not await self._is_zonal_bucket(bucket):
raise RuntimeError(
"Internal error, this method is only supported for zonal buckets!"
)

await self._get_grpc_client()
mrd = await zb_hns_utils.init_mrd(
self.grpc_client, bucket, object_name, generation
)
mrd_created = True
if mrd is None:
# If no mrd is provided, we create one with pool size equal to passed concurrency.
pool_size = min(len(chunk_lengths), concurrency)
mrd = zb_hns_utils.MRDPool(
self, bucket, object_name, generation, pool_size=pool_size
)
await mrd.initialize()
pool_created_here = True

file_size = size or mrd.persisted_size
if file_size is None:
logger.warning(
f"AsyncMultiRangeDownloader (MRD) for {path} has no 'persisted_size'. "
"Falling back to _info() to get the file size."
tasks = []
try:
current_offset = start_offset

cat_kwargs = kwargs.copy()

for length in chunk_lengths:
end_offset = current_offset + length
Comment thread
googlyrahman marked this conversation as resolved.
tasks.append(
asyncio.create_task(
self._cat_file(
path,
start=current_offset,
end=end_offset,
mrd=mrd,
# Distribute the concurrency budget proportionally.
# Since these outer tasks are already concurrent, this is typically 1.
# However, if a large chunk dominates the total size, it receives
# higher concurrency to prevent it from becoming a bottleneck.
concurrency=max(
1, length * concurrency // sum(chunk_lengths)
),
Comment thread
googlyrahman marked this conversation as resolved.
**cat_kwargs,
)
)
)
file_size = (await self._info(path))["size"]

if chunk_lengths:
start_offset = start if start is not None else 0
current_offset = start_offset
current_offset = end_offset

if start_offset >= file_size or (
chunk_lengths is not None
and start_offset + sum(chunk_lengths) > file_size
):
raise RuntimeError("Request not satisfiable.")

read_ranges = [] # To pass to MRD
results = await asyncio.gather(*tasks, return_exceptions=True)

for length in chunk_lengths:
read_ranges.append((current_offset, length))
current_offset += length
# Bubble up any exceptions encountered during concurrent fetching
for res in results:
if isinstance(res, Exception):
raise res

return results
except BaseException:
for t in tasks:
if not t.done():
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
finally:
if pool_created_here:
await mrd.close()
Comment thread
googlyrahman marked this conversation as resolved.

return await zb_hns_utils.download_ranges(read_ranges, mrd)
else:
end = kwargs.get("end")
offset, length = await self._process_limits_to_offset_and_length(
path, start, end, file_size
async def _concurrent_mrd_fetch(self, offset, length, concurrency, mrd_or_pool):
"""Helper to handle concurrent chunk downloads into a DirectMemmoveBuffer."""
concurrency = (
concurrency if length >= self.MIN_CHUNK_SIZE_FOR_CONCURRENCY else 1
Comment thread
martindurant marked this conversation as resolved.
)
result_bytes = PyBytes_FromStringAndSize(None, length)
buffer_ptr = PyBytes_AsString(result_bytes)

part_size = length // concurrency
Comment thread
googlyrahman marked this conversation as resolved.
tasks = []
buffers = []
loop = asyncio.get_running_loop()

# Track if the core download process failed
has_error = False

async def _download(o, s, b, mrd_or_pool):
async with _get_mrd_from_pool_or_mrd(mrd_or_pool) as m_client:
await m_client.download_ranges([(o, s, b)])

for i in range(concurrency):
part_offset = offset + (i * part_size)
actual_size = part_size if i < concurrency - 1 else length - (i * part_size)

part_address = buffer_ptr + (part_offset - offset)
buf = DirectMemmoveBuffer(
part_address,
part_address + actual_size,
self._memmove_executor,
)
buffers.append(buf)
tasks.append(
asyncio.create_task(
_download(part_offset, actual_size, buf, mrd_or_pool)
)
)

data = await zb_hns_utils.download_range(
offset=offset, length=length, mrd=mrd
)
return [data]
try:
results = await asyncio.gather(*tasks, return_exceptions=True)
for res in results:
if isinstance(res, Exception):
has_error = True
raise res
except BaseException:
Comment thread
martindurant marked this conversation as resolved.
has_error = True
for t in tasks:
if not t.done():
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
finally:
if mrd_created:
await zb_hns_utils.close_mrd(mrd)

async def _cat_file(self, path, start=None, end=None, mrd=None, **kwargs):
for buf in buffers:
try:
await loop.run_in_executor(None, buf.close)
except BufferError:
# If we are already handling a network/download exception,
# ignore the BufferError (which is just a symptom of the drop).
# If there's no download error, this means the buffer logic
# itself failed, so we must surface the error.
if not has_error:
raise

return result_bytes

async def _cat_file(
self,
path,
start=None,
end=None,
mrd=None,
concurrency=zb_hns_utils.DEFAULT_CONCURRENCY,
**kwargs,
):
"""Fetch a file's contents as bytes, with an optimized path for Zonal buckets.

This method overrides the parent `_cat_file` to read objects in Zonal buckets using gRPC.
Expand All @@ -360,47 +491,60 @@ async def _cat_file(self, path, start=None, end=None, mrd=None, **kwargs):
path (str): The full GCS path to the file (e.g., "bucket/object").
start (int, optional): The starting byte position to read from.
end (int, optional): The ending byte position to read to.
mrd (AsyncMultiRangeDownloader, optional): An existing multi-range
downloader instance. If not provided, a new one will be created for Zonal buckets.
mrd (AsyncMultiRangeDownloader, MRDPool, optional): An existing multi-range
downloader instance or a pool of MRD. If not provided, a new one will be created for Zonal buckets.
concurrency (int, optional): The max number of concurrent request to fetch the data.

Returns:
bytes: The content of the file or file range.
"""
try:
mrd_created = False
pool_created_here = False

# A new MRD is required when read is done directly by the
# GCSFilesystem class without creating a GCSFile object first.
if mrd is None:
bucket, object_name, generation = self.split_path(path)
# A new MRDPool is required when read is done directly by the
# GCSFilesystem class without creating a GCSFile object first.
if mrd is None:
bucket, object_name, generation = self.split_path(path)
if not await self._is_zonal_bucket(bucket):
# Fall back to default implementation if not a zonal bucket
if not await self._is_zonal_bucket(bucket):
return await super()._cat_file(path, start=start, end=end, **kwargs)

await self._get_grpc_client()
mrd = await zb_hns_utils.init_mrd(
self.grpc_client, bucket, object_name, generation
return await super()._cat_file(
path, start=start, end=end, concurrency=concurrency, **kwargs
)
mrd_created = True

file_size = mrd.persisted_size
# Instantiate an MRDPool locally for this call
mrd = zb_hns_utils.MRDPool(
self, bucket, object_name, generation, pool_size=concurrency
)
await mrd.initialize()
pool_created_here = True

try:
file_size = await _get_mrd_size(mrd)
if file_size is None:
logger.warning(
f"AsyncMultiRangeDownloader (MRD) for {path} has no 'persisted_size'. "
"Falling back to _info() to get the file size. "
"This may result in incorrect behavior for unfinalized objects."
)
file_size = (await self._info(path))["size"]

offset, length = await self._process_limits_to_offset_and_length(
path, start, end, file_size
)

return await zb_hns_utils.download_range(
offset=offset, length=length, mrd=mrd
if length == 0:
return b""

return await self._concurrent_mrd_fetch(
offset,
length,
concurrency if length >= self.MIN_CHUNK_SIZE_FOR_CONCURRENCY else 1,
mrd,
)

finally:
# Explicit cleanup if we created the MRD
if mrd_created:
await zb_hns_utils.close_mrd(mrd)
# If we created a temporary pool specifically for this _cat_file call, clean it up
if pool_created_here:
await mrd.close()

async def _is_bucket_hns_enabled(self, bucket):
"""Checks if a bucket has Hierarchical Namespace enabled."""
Expand Down
Loading
Loading