diff --git a/docker/transport/sshconn.py b/docker/transport/sshconn.py index 1870668010..5e73472309 100644 --- a/docker/transport/sshconn.py +++ b/docker/transport/sshconn.py @@ -18,12 +18,13 @@ class SSHSocket(socket.socket): - def __init__(self, host): + def __init__(self, host, socket_path=None): super().__init__( socket.AF_INET, socket.SOCK_STREAM) self.host = host self.port = None self.user = None + self.socket_path = socket_path if ':' in self.host: self.host, self.port = self.host.split(':') if '@' in self.host: @@ -39,7 +40,12 @@ def connect(self, **kwargs): if self.port: args = args + ['-p', self.port] - args = args + ['--', self.host, 'docker system dial-stdio'] + if self.socket_path and self.socket_path != '': + docker_cmd = f"docker --host unix://{self.socket_path} system dial-stdio" + else: + docker_cmd = "docker system dial-stdio" + + args = args + ['--', self.host, docker_cmd] preexec_func = None if not constants.IS_WINDOWS_PLATFORM: @@ -96,21 +102,27 @@ def close(self): class SSHConnection(urllib3.connection.HTTPConnection): - def __init__(self, ssh_transport=None, timeout=60, host=None): + def __init__(self, ssh_transport=None, timeout=60, host=None, socket_path=None): super().__init__( 'localhost', timeout=timeout ) self.ssh_transport = ssh_transport self.timeout = timeout self.ssh_host = host + self.ssh_socket_path = socket_path def connect(self): + if self.ssh_socket_path and self.ssh_socket_path != '': + remote_cmd = f"docker --host unix://{self.ssh_socket_path} system dial-stdio" + else: + remote_cmd = "docker system dial-stdio" + if self.ssh_transport: sock = self.ssh_transport.open_session() sock.settimeout(self.timeout) - sock.exec_command('docker system dial-stdio') + sock.exec_command(remote_cmd) else: - sock = SSHSocket(self.ssh_host) + sock = SSHSocket(self.ssh_host, socket_path=self.ssh_socket_path) sock.settimeout(self.timeout) sock.connect() @@ -120,7 +132,7 @@ def connect(self): class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): scheme = 'ssh' - def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None): + def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None, socket_path=None): super().__init__( 'localhost', timeout=timeout, maxsize=maxsize ) @@ -129,9 +141,15 @@ def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None): if ssh_client: self.ssh_transport = ssh_client.get_transport() self.ssh_host = host + self.ssh_socket_path = socket_path def _new_conn(self): - return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host) + return SSHConnection( + self.ssh_transport, + self.timeout, + self.ssh_host, + self.ssh_socket_path + ) # When re-using connections, urllib3 calls fileno() on our # SSH channel instance, quickly overloading our fd limit. To avoid this, @@ -171,9 +189,9 @@ def __init__(self, base_url, timeout=60, self._create_paramiko_client(base_url) self._connect() - self.ssh_host = base_url - if base_url.startswith('ssh://'): - self.ssh_host = base_url[len('ssh://'):] + parsed = urllib.parse.urlparse(base_url) + self.ssh_socket_path = parsed.path + self.ssh_host = parsed.netloc self.timeout = timeout self.max_pool_size = max_pool_size @@ -223,7 +241,8 @@ def get_connection(self, url, proxies=None): ssh_client=self.ssh_client, timeout=self.timeout, maxsize=self.max_pool_size, - host=self.ssh_host + host=self.ssh_host, + socket_path=self.ssh_socket_path ) with self.pools.lock: pool = self.pools.get(url) @@ -238,7 +257,8 @@ def get_connection(self, url, proxies=None): ssh_client=self.ssh_client, timeout=self.timeout, maxsize=self.max_pool_size, - host=self.ssh_host + host=self.ssh_host, + socket_path=self.ssh_socket_path ) self.pools[url] = pool diff --git a/docker/utils/utils.py b/docker/utils/utils.py index f36a3afb89..4353597107 100644 --- a/docker/utils/utils.py +++ b/docker/utils/utils.py @@ -277,12 +277,16 @@ def parse_host(addr, is_win32=False, tls=False): f'Invalid bind address format: {addr}' ) - if parsed_url.path and proto == 'ssh': - raise errors.DockerException( - f'Invalid bind address format: no path allowed for this protocol: {addr}' - ) + path = parsed_url.path + if proto == 'ssh': + # Support "ssh://user@host/" where is an absolute path to + # the docker daemon unix socket. + if path != '': + if not path.startswith('/'): + raise errors.DockerException( + f'Invalid bind address format: invalid ssh socket path: {addr}' + ) else: - path = parsed_url.path if proto == 'unix' and parsed_url.hostname is not None: # For legacy reasons, we consider unix://path # to be valid and equivalent to unix:///path diff --git a/tests/unit/sshadapter_test.py b/tests/unit/sshadapter_test.py index 8736662101..fb43629273 100644 --- a/tests/unit/sshadapter_test.py +++ b/tests/unit/sshadapter_test.py @@ -11,6 +11,15 @@ def test_ssh_hostname_prefix_trim(): base_url="ssh://user@hostname:1234", shell_out=True) assert conn.ssh_host == "user@hostname:1234" + @staticmethod + def test_ssh_hostname_trim_with_socket_path(): + conn = docker.transport.SSHHTTPAdapter( + base_url="ssh://user@hostname:1234/var/run/docker-1.sock", + shell_out=True + ) + assert conn.ssh_host == "user@hostname:1234" + assert conn.ssh_socket_path == "/var/run/docker-1.sock" + @staticmethod def test_ssh_parse_url(): c = SSHSocket(host="user@hostname:1234") diff --git a/tests/unit/utils_test.py b/tests/unit/utils_test.py index 21da0b58e8..e50ab04869 100644 --- a/tests/unit/utils_test.py +++ b/tests/unit/utils_test.py @@ -280,7 +280,6 @@ def test_parse_host(self): 'tcp://', 'udp://127.0.0.1', 'udp://127.0.0.1:2375', - 'ssh://:22/path', 'tcp://netloc:3333/path?q=1', 'unix:///sock/path#fragment', 'https://netloc:3333/path;params', @@ -312,6 +311,9 @@ def test_parse_host(self): 'ssh://': 'ssh://127.0.0.1:22', 'ssh://user@localhost:22': 'ssh://user@localhost:22', 'ssh://user@remote': 'ssh://user@remote:22', + 'ssh://user@remote/var/run/docker.sock': ( + 'ssh://user@remote:22/var/run/docker.sock' + ), } for host in invalid_hosts: