Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions api/PclusterApiHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import json
import os
import re
import shlex
import time

import boto3
Expand Down Expand Up @@ -273,7 +274,7 @@ def ssm_command(region, instance_id, user, run_command):
else:
ssm = boto3.client("ssm")

command = f"runuser -l {user} -c '{run_command}'"
command = f"runuser -l {shlex.quote(user)} -c {shlex.quote(run_command)}"

ssm_resp = ssm.send_command(
InstanceIds=[instance_id],
Expand Down Expand Up @@ -370,7 +371,7 @@ def sacct():
body = request.json

price_guess = None
sacct_args = " ".join(f"--{k} {v}" for k, v in body.items())
sacct_args = " ".join(f"--{shlex.quote(str(k))} {shlex.quote(str(v))}" for k, v in body.items())
sacct_args += " --allusers" if "user" not in body else ""

if "jobs" not in body:
Expand Down Expand Up @@ -411,7 +412,7 @@ def scontrol_job():
return {"message": "You must specify a job id."}, 400

job_data = (
ssm_command(request.args.get("region"), instance_id, user, f"scontrol show job {job_id} -o").strip().split(" ")
ssm_command(request.args.get("region"), instance_id, user, f"scontrol show job {shlex.quote(job_id)} -o").strip().split(" ")
)
if isinstance(job_data, tuple):
return job_data
Expand Down Expand Up @@ -439,7 +440,7 @@ def cancel_job():
user = request.args.get("user", "ec2-user")
instance_id = request.args.get("instance_id")
job_id = request.args.get("job_id")
ssm_command(request.args.get("region"), instance_id, user, f"scancel {job_id}")
ssm_command(request.args.get("region"), instance_id, user, f"scancel {shlex.quote(job_id)}")
return {"status": "success"}


Expand All @@ -456,7 +457,8 @@ def get_dcv_session():
else:
ssm = boto3.client("ssm")

command = f"runuser -l {user} -c '{dcv_command} {session_directory}'"
inner_command = f"{dcv_command} {shlex.quote(session_directory)}"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not shlexing the dcv_command as well?

Copy link
Copy Markdown
Contributor Author

@hehe7318 hehe7318 Apr 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See

dcv_command = "/opt/parallelcluster/scripts/pcluster_dcv_connect.sh"

It's hardcoded, not user input.

command = f"runuser -l {shlex.quote(user)} -c {shlex.quote(inner_command)}"

ssm_resp = ssm.send_command(
InstanceIds=[instance_id],
Expand Down
58 changes: 58 additions & 0 deletions api/tests/test_command_injection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import logging
import shlex
import pytest
from unittest import mock

from api.PclusterApiHandler import ssm_command
from api.pcm_globals import _logger_ctxvar


@pytest.fixture(autouse=True)
def bind_logger():
"""ssm_command writes to a request-scoped logger proxy; bind a stdlib
logger to the contextvar so calls don't fail outside a Flask request."""
token = _logger_ctxvar.set(logging.getLogger("test"))
yield
_logger_ctxvar.reset(token)


@pytest.fixture
def mock_ssm_send(mocker):
mock_client = mock.MagicMock()
mock_client.send_command.return_value = {"Command": {"CommandId": "cmd-id"}}
mock_client.get_command_invocation.return_value = {"Status": "Success"}

mocker.patch("api.PclusterApiHandler.boto3.client", return_value=mock_client)
mocker.patch("api.PclusterApiHandler.time.sleep")
mocker.patch(
"api.PclusterApiHandler.read_and_delete_ssm_output_from_cloudwatch",
return_value="",
)
return mock_client


def _sent_command(mock_client):
return mock_client.send_command.call_args.kwargs["Parameters"]["commands"][0]


def test_malicious_user_stays_a_single_argument(mock_ssm_send):
"""A user string with a single quote must not break out of the -l argument."""
malicious = "ec2-user';touch /tmp/pwned;'"
ssm_command("us-east-1", "i-1234", malicious, "sacct")

tokens = shlex.split(_sent_command(mock_ssm_send))
# Expected shell parse: ['runuser', '-l', <malicious literal>, '-c', 'sacct']
# If shlex.quote were missing, the shell would see extra tokens / commands
# from the injected ';touch /tmp/pwned;' portion.
assert tokens[0:2] == ["runuser", "-l"]
assert tokens[2] == malicious
assert tokens[3:] == ["-c", "sacct"]


def test_malicious_run_command_stays_a_single_argument(mock_ssm_send):
"""A run_command with shell metacharacters must be passed as one -c argument."""
malicious = "sacct';rm -rf /;'"
ssm_command("us-east-1", "i-1234", "ec2-user", malicious)

tokens = shlex.split(_sent_command(mock_ssm_send))
assert tokens == ["runuser", "-l", "ec2-user", "-c", malicious]
2 changes: 1 addition & 1 deletion api/tests/validation/test_api_custom_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ def test_size_not_exceeding_failing():
),
])
def test_is_safe_path(path: str, expected_result: bool):
assert is_safe_path(path) == expected_result
assert is_safe_path(path) == expected_result
Loading