Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 37 additions & 17 deletions fsspec/implementations/smb.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@
class SMBFileSystem(AbstractFileSystem):
"""Allow reading and writing to Windows and Samba network shares.

**Security considerations**: This class is based on the smbclient library,
which uses a connection cache implemented as a module-level global
dictionary variable that uses server and port as keys. Multiple instances
of this class that are created within the same process will (by library default)
share this cache, i.e. connections to the same server will reuse the
credentials of the previous connections -- which might be from different
users.
This class tries to prevent credential leakage when use_global_cache is set to False (the default)
by creating an instance-specific cache that is passed to the smbclient functions
via kwargs.
Please consider carefully whether you want to use SMBFileSystem in a
multiuser environment.

When using `fsspec.open()` for getting a file-like object the URI
should be specified as this format:
``smb://workgroup;user:password@server:port/share/folder/file.csv``.
Expand Down Expand Up @@ -73,6 +86,7 @@ def __init__(
register_session_retry_wait=1,
register_session_retry_factor=10,
auto_mkdir=False,
use_global_cache: bool = False,
**kwargs,
):
"""
Expand Down Expand Up @@ -147,6 +161,8 @@ def __init__(
)
self.register_session_retry_factor = register_session_retry_factor
self.auto_mkdir = auto_mkdir
# Initialize the per-instance connection cache: None means the global cache is used; a dict isolates this instance.
self._smb_conn_cache = None if use_global_cache else {}
self._connect()

@property
Expand Down Expand Up @@ -184,6 +200,7 @@ def _connect(self):
port=self._port,
encrypt=self.encrypt,
connection_timeout=self.timeout,
connection_cache=self._smb_conn_cache,
)
return
except (
Expand Down Expand Up @@ -230,23 +247,23 @@ def _get_kwargs_from_urls(path):
def mkdir(self, path, create_parents=True, **kwargs):
wpath = _as_unc_path(self.host, path)
if create_parents:
smbclient.makedirs(wpath, exist_ok=False, port=self._port, **kwargs)
smbclient.makedirs(wpath, exist_ok=False, port=self._port, connection_cache=self._smb_conn_cache, **kwargs)
else:
smbclient.mkdir(wpath, port=self._port, **kwargs)
smbclient.mkdir(wpath, port=self._port, connection_cache=self._smb_conn_cache, **kwargs)

def makedirs(self, path, exist_ok=False):
if _share_has_path(path):
wpath = _as_unc_path(self.host, path)
smbclient.makedirs(wpath, exist_ok=exist_ok, port=self._port)
smbclient.makedirs(wpath, exist_ok=exist_ok, port=self._port, connection_cache=self._smb_conn_cache)

def rmdir(self, path):
if _share_has_path(path):
wpath = _as_unc_path(self.host, path)
smbclient.rmdir(wpath, port=self._port)
smbclient.rmdir(wpath, port=self._port, connection_cache=self._smb_conn_cache)

def info(self, path, **kwargs):
wpath = _as_unc_path(self.host, path)
stats = smbclient.stat(wpath, port=self._port, **kwargs)
stats = smbclient.stat(wpath, port=self._port, connection_cache=self._smb_conn_cache, **kwargs)
if S_ISDIR(stats.st_mode):
stype = "directory"
elif S_ISLNK(stats.st_mode):
Expand All @@ -267,18 +284,18 @@ def info(self, path, **kwargs):
def created(self, path):
"""Return the created timestamp of a file as a datetime.datetime"""
wpath = _as_unc_path(self.host, path)
stats = smbclient.stat(wpath, port=self._port)
stats = smbclient.stat(wpath, port=self._port, connection_cache=self._smb_conn_cache)
return datetime.datetime.fromtimestamp(stats.st_ctime, tz=datetime.timezone.utc)

def modified(self, path):
"""Return the modified timestamp of a file as a datetime.datetime"""
wpath = _as_unc_path(self.host, path)
stats = smbclient.stat(wpath, port=self._port)
stats = smbclient.stat(wpath, port=self._port, connection_cache=self._smb_conn_cache)
return datetime.datetime.fromtimestamp(stats.st_mtime, tz=datetime.timezone.utc)

def ls(self, path, detail=True, **kwargs):
unc = _as_unc_path(self.host, path)
listed = smbclient.listdir(unc, port=self._port, **kwargs)
listed = smbclient.listdir(unc, port=self._port, connection_cache=self._smb_conn_cache, **kwargs)
dirs = ["/".join([path.rstrip("/"), p]) for p in listed]
if detail:
dirs = [self.info(d) for d in dirs]
Expand Down Expand Up @@ -311,14 +328,15 @@ def _open(
if "w" in mode and autocommit is False:
temp = _as_temp_path(self.host, path, self.temppath)
return SMBFileOpener(
wpath, temp, mode, port=self._port, block_size=bls, **kwargs
wpath, temp, mode, port=self._port, block_size=bls, connection_cache=self._smb_conn_cache, **kwargs
)
return smbclient.open_file(
wpath,
mode,
buffering=bls,
share_access=share_access,
port=self._port,
connection_cache=self._smb_conn_cache,
**kwargs,
)

Expand All @@ -328,21 +346,21 @@ def copy(self, path1, path2, **kwargs):
wpath2 = _as_unc_path(self.host, path2)
if self.auto_mkdir:
self.makedirs(self._parent(path2), exist_ok=True)
smbclient.copyfile(wpath1, wpath2, port=self._port, **kwargs)
smbclient.copyfile(wpath1, wpath2, port=self._port, connection_cache=self._smb_conn_cache, **kwargs)

def _rm(self, path):
if _share_has_path(path):
wpath = _as_unc_path(self.host, path)
stats = smbclient.stat(wpath, port=self._port)
stats = smbclient.stat(wpath, port=self._port, connection_cache=self._smb_conn_cache)
if S_ISDIR(stats.st_mode):
smbclient.rmdir(wpath, port=self._port)
smbclient.rmdir(wpath, port=self._port, connection_cache=self._smb_conn_cache)
else:
smbclient.remove(wpath, port=self._port)
smbclient.remove(wpath, port=self._port, connection_cache=self._smb_conn_cache)

def mv(self, path1, path2, recursive=None, maxdepth=None, **kwargs):
wpath1 = _as_unc_path(self.host, path1)
wpath2 = _as_unc_path(self.host, path2)
smbclient.rename(wpath1, wpath2, port=self._port, **kwargs)
smbclient.rename(wpath1, wpath2, port=self._port, connection_cache=self._smb_conn_cache, **kwargs)


def _as_unc_path(host, path):
Expand All @@ -368,7 +386,8 @@ def _share_has_path(path):
class SMBFileOpener:
"""writes to remote temporary file, move on commit"""

def __init__(self, path, temp, mode, port=445, block_size=-1, **kwargs):
def __init__(self, path, temp, mode, port=445, block_size=-1, connection_cache=None, **kwargs):
self._smb_conn_cache = connection_cache
self.path = path
self.temp = temp
self.mode = mode
Expand All @@ -386,17 +405,18 @@ def _open(self):
self.mode,
port=self.port,
buffering=self.block_size,
connection_cache=self._smb_conn_cache,
**self.kwargs,
)

def commit(self):
"""Move temp file to definitive on success."""
# TODO: use transaction support in SMB protocol
smbclient.replace(self.temp, self.path, port=self.port)
smbclient.replace(self.temp, self.path, port=self.port, connection_cache=self._smb_conn_cache)

def discard(self):
"""Remove the temp file on failure."""
smbclient.remove(self.temp, port=self.port)
smbclient.remove(self.temp, port=self.port, connection_cache=self._smb_conn_cache)

def __fspath__(self):
return self.path
Expand Down
12 changes: 11 additions & 1 deletion fsspec/implementations/tests/test_smb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Test SMBFileSystem class using a docker container
"""

import copy
import logging
import os
import shlex
Expand Down Expand Up @@ -94,13 +95,22 @@ def test_simple(smb_params):
fsmb.rm(adir, recursive=True)
assert not fsmb.exists(adir)

# test with a second SMB filesystem object created without the password
smb_params_wopass = dict(**smb_params)
del smb_params_wopass["password"]
fsmb_wopass = fsspec.get_filesystem_class("smb")(**smb_params_wopass)
fsmb_wopass.mkdirs("/home/adir/justanotherdir/")


@pytest.mark.flaky(max_runs=3, rerun_filter=delay_rerun)
def test_auto_mkdir(smb_params):
adir = "/home/adir"
adir2 = "/home/adir/otherdir/"
afile = "/home/adir/otherdir/afile"
fsmb = fsspec.get_filesystem_class("smb")(**smb_params, auto_mkdir=True)

smb_params_wopass = dict(**smb_params)
del smb_params_wopass["password"]
fsmb = fsspec.get_filesystem_class("smb")(**smb_params_wopass, auto_mkdir=True)
fsmb.touch(afile)
assert fsmb.exists(adir)
assert fsmb.exists(adir2)
Expand Down
Loading