Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions docker/transport/sshconn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@


class SSHSocket(socket.socket):
def __init__(self, host):
def __init__(self, host, socket_path=None):
super().__init__(
socket.AF_INET, socket.SOCK_STREAM)
self.host = host
self.port = None
self.user = None
self.socket_path = socket_path
if ':' in self.host:
self.host, self.port = self.host.split(':')
if '@' in self.host:
Expand All @@ -39,7 +40,12 @@ def connect(self, **kwargs):
if self.port:
args = args + ['-p', self.port]

args = args + ['--', self.host, 'docker system dial-stdio']
if self.socket_path and self.socket_path != '':
docker_cmd = f"docker --host unix://{self.socket_path} system dial-stdio"
else:
docker_cmd = "docker system dial-stdio"

args = args + ['--', self.host, docker_cmd]

preexec_func = None
if not constants.IS_WINDOWS_PLATFORM:
Expand Down Expand Up @@ -96,21 +102,27 @@ def close(self):


class SSHConnection(urllib3.connection.HTTPConnection):
def __init__(self, ssh_transport=None, timeout=60, host=None):
def __init__(self, ssh_transport=None, timeout=60, host=None, socket_path=None):
super().__init__(
'localhost', timeout=timeout
)
self.ssh_transport = ssh_transport
self.timeout = timeout
self.ssh_host = host
self.ssh_socket_path = socket_path

def connect(self):
if self.ssh_socket_path and self.ssh_socket_path != '':
remote_cmd = f"docker --host unix://{self.ssh_socket_path} system dial-stdio"
else:
remote_cmd = "docker system dial-stdio"

if self.ssh_transport:
sock = self.ssh_transport.open_session()
sock.settimeout(self.timeout)
sock.exec_command('docker system dial-stdio')
sock.exec_command(remote_cmd)
else:
sock = SSHSocket(self.ssh_host)
sock = SSHSocket(self.ssh_host, socket_path=self.ssh_socket_path)
sock.settimeout(self.timeout)
sock.connect()

Expand All @@ -120,7 +132,7 @@ def connect(self):
class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
scheme = 'ssh'

def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None):
def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None, socket_path=None):
super().__init__(
'localhost', timeout=timeout, maxsize=maxsize
)
Expand All @@ -129,9 +141,15 @@ def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None):
if ssh_client:
self.ssh_transport = ssh_client.get_transport()
self.ssh_host = host
self.ssh_socket_path = socket_path

def _new_conn(self):
return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host)
return SSHConnection(
self.ssh_transport,
self.timeout,
self.ssh_host,
self.ssh_socket_path
)

# When re-using connections, urllib3 calls fileno() on our
# SSH channel instance, quickly overloading our fd limit. To avoid this,
Expand Down Expand Up @@ -171,9 +189,9 @@ def __init__(self, base_url, timeout=60,
self._create_paramiko_client(base_url)
self._connect()

self.ssh_host = base_url
if base_url.startswith('ssh://'):
self.ssh_host = base_url[len('ssh://'):]
parsed = urllib.parse.urlparse(base_url)
self.ssh_socket_path = parsed.path
self.ssh_host = parsed.netloc

self.timeout = timeout
self.max_pool_size = max_pool_size
Expand Down Expand Up @@ -223,7 +241,8 @@ def get_connection(self, url, proxies=None):
ssh_client=self.ssh_client,
timeout=self.timeout,
maxsize=self.max_pool_size,
host=self.ssh_host
host=self.ssh_host,
socket_path=self.ssh_socket_path
)
with self.pools.lock:
pool = self.pools.get(url)
Expand All @@ -238,7 +257,8 @@ def get_connection(self, url, proxies=None):
ssh_client=self.ssh_client,
timeout=self.timeout,
maxsize=self.max_pool_size,
host=self.ssh_host
host=self.ssh_host,
socket_path=self.ssh_socket_path
)
self.pools[url] = pool

Expand Down
14 changes: 9 additions & 5 deletions docker/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,16 @@ def parse_host(addr, is_win32=False, tls=False):
f'Invalid bind address format: {addr}'
)

if parsed_url.path and proto == 'ssh':
raise errors.DockerException(
f'Invalid bind address format: no path allowed for this protocol: {addr}'
)
path = parsed_url.path
if proto == 'ssh':
# Support "ssh://user@host/<path>" where <path> is an absolute path to
# the docker daemon unix socket.
if path != '':
if not path.startswith('/'):
raise errors.DockerException(
f'Invalid bind address format: invalid ssh socket path: {addr}'
)
else:
path = parsed_url.path
if proto == 'unix' and parsed_url.hostname is not None:
# For legacy reasons, we consider unix://path
# to be valid and equivalent to unix:///path
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/sshadapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ def test_ssh_hostname_prefix_trim():
base_url="ssh://user@hostname:1234", shell_out=True)
assert conn.ssh_host == "user@hostname:1234"

@staticmethod
def test_ssh_hostname_trim_with_socket_path():
conn = docker.transport.SSHHTTPAdapter(
base_url="ssh://user@hostname:1234/var/run/docker-1.sock",
shell_out=True
)
assert conn.ssh_host == "user@hostname:1234"
assert conn.ssh_socket_path == "/var/run/docker-1.sock"

@staticmethod
def test_ssh_parse_url():
c = SSHSocket(host="user@hostname:1234")
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ def test_parse_host(self):
'tcp://',
'udp://127.0.0.1',
'udp://127.0.0.1:2375',
'ssh://:22/path',
'tcp://netloc:3333/path?q=1',
'unix:///sock/path#fragment',
'https://netloc:3333/path;params',
Expand Down Expand Up @@ -312,6 +311,9 @@ def test_parse_host(self):
'ssh://': 'ssh://127.0.0.1:22',
'ssh://user@localhost:22': 'ssh://user@localhost:22',
'ssh://user@remote': 'ssh://user@remote:22',
'ssh://user@remote/var/run/docker.sock': (
'ssh://user@remote:22/var/run/docker.sock'
),
}

for host in invalid_hosts:
Expand Down