Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/thv/app/run_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type RunFlags struct {
ProxyPort int
TargetPort int
TargetHost string
Publish []string

// Server configuration
Name string
Expand Down Expand Up @@ -154,6 +155,8 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) {
"target-host",
transport.LocalhostIPv4,
"Host to forward traffic to (only applicable to SSE or Streamable HTTP transport)")
cmd.Flags().StringArrayVarP(&config.Publish, "publish", "p", []string{},
"Publish a container's port(s) to the host (format: hostPort:containerPort)")
cmd.Flags().StringVar(
&config.PermissionProfile,
"permission-profile",
Expand Down Expand Up @@ -606,6 +609,7 @@ func buildRunnerConfig(
LoadGlobal: runFlags.IgnoreGlobally,
PrintOverlays: runFlags.PrintOverlays,
}),
runner.WithPublish(runFlags.Publish),
}

// Load tools override configuration
Expand Down
1 change: 1 addition & 0 deletions docs/cli/thv_run.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions docs/server/docs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions docs/server/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions docs/server/swagger.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 17 additions & 9 deletions pkg/container/docker/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1619,7 +1619,7 @@ func generatePortBindings(labels map[string]string,
portBindings map[string][]runtime.PortBinding) (map[string][]runtime.PortBinding, int, error) {
var hostPort int
// check if we need to map to a random port of not
if _, ok := labels["toolhive-auxiliary"]; ok && labels["toolhive-auxiliary"] == "true" {
if _, ok := labels[ToolhiveAuxiliaryWorkloadLabel]; ok && labels[ToolhiveAuxiliaryWorkloadLabel] == LabelValueTrue {
// find first port
var err error
for _, bindings := range portBindings {
Expand All @@ -1633,17 +1633,25 @@ func generatePortBindings(labels map[string]string,
}
}
} else {
// bind to a random host port
hostPort = networking.FindAvailable()
if hostPort == 0 {
return nil, 0, fmt.Errorf("could not find an available port")
}

// first port binding needs to map to the host port
// For consistency, we only use FindAvailable for the primary port if it's not already set
for key, bindings := range portBindings {
if len(bindings) > 0 {
bindings[0].HostPort = fmt.Sprintf("%d", hostPort)
portBindings[key] = bindings
hostPortStr := bindings[0].HostPort
if hostPortStr == "" || hostPortStr == "0" {
hostPort = networking.FindAvailable()
if hostPort == 0 {
return nil, 0, fmt.Errorf("could not find an available port")
}
bindings[0].HostPort = fmt.Sprintf("%d", hostPort)
portBindings[key] = bindings
} else {
var err error
hostPort, err = strconv.Atoi(hostPortStr)
if err != nil {
return nil, 0, fmt.Errorf("failed to convert host port %s to int: %w", hostPortStr, err)
}
}
Comment thread
jerm-dro marked this conversation as resolved.
break
}
}
Expand Down
37 changes: 37 additions & 0 deletions pkg/container/docker/client_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,43 @@ func TestGeneratePortBindings_NonAuxiliaryAssignsRandomPortAndMutatesFirstBindin
assert.Equal(t, 1, countMatches, "expected exactly one first binding to be updated to hostPort=%s", expected)
}

func TestGeneratePortBindings_NonAuxiliaryKeepsExplicitHostPort(t *testing.T) {
t.Parallel()

labels := map[string]string{} // not auxiliary
in := map[string][]runtime.PortBinding{
"8080/tcp": {
{HostIP: "", HostPort: "9090"},
},
}
out, hostPort, err := generatePortBindings(labels, in)
require.NoError(t, err)
require.Equal(t, 9090, hostPort)

require.Contains(t, out, "8080/tcp")
require.Len(t, out["8080/tcp"], 1)
assert.Equal(t, "9090", out["8080/tcp"][0].HostPort)
}

func TestGeneratePortBindings_NonAuxiliaryAssignsRandomPortForZero(t *testing.T) {
t.Parallel()

labels := map[string]string{} // not auxiliary
in := map[string][]runtime.PortBinding{
"8080/tcp": {
{HostIP: "", HostPort: "0"},
},
}
out, hostPort, err := generatePortBindings(labels, in)
require.NoError(t, err)
require.NotZero(t, hostPort)

require.Contains(t, out, "8080/tcp")
require.Len(t, out["8080/tcp"], 1)
assert.NotEqual(t, "0", out["8080/tcp"][0].HostPort)
assert.Equal(t, fmt.Sprintf("%d", hostPort), out["8080/tcp"][0].HostPort)
Comment thread
jerm-dro marked this conversation as resolved.
}

func TestAddEgressEnvVars_SetsAll(t *testing.T) {
t.Parallel()

Expand Down
44 changes: 44 additions & 0 deletions pkg/networking/port.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"log/slog"
"math/big"
"net"
"strconv"
"strings"

gopsutilnet "github.com/shirou/gopsutil/v4/net"
)
Expand Down Expand Up @@ -201,3 +203,45 @@ func GetProcessOnPort(port int) (int, error) {
}
return 0, nil
}

// ParsePortSpec parses a port specification string in the format "hostPort:containerPort" or just "containerPort".
// Returns the host port string and container port integer.
// If only a container port is provided, a random available host port is selected.
func ParsePortSpec(portSpec string) (string, int, error) {
slog.Debug("Parsing port spec", "spec", portSpec)
// Check if it's in host:container format
if strings.Contains(portSpec, ":") {
parts := strings.Split(portSpec, ":")
if len(parts) != 2 {
return "", 0, fmt.Errorf("invalid port specification: %s (expected 'hostPort:containerPort')", portSpec)
}

hostPortStr := parts[0]
containerPortStr := parts[1]

// Verify host port is a valid integer (or empty string if we supported random host port with :, but here we expect explicit)
if _, err := strconv.Atoi(hostPortStr); err != nil {
return "", 0, fmt.Errorf("invalid host port in spec '%s': %w", portSpec, err)
}

containerPort, err := strconv.Atoi(containerPortStr)
if err != nil {
return "", 0, fmt.Errorf("invalid container port in spec '%s': %w", portSpec, err)
}

return hostPortStr, containerPort, nil
}

// Try parsing as just container port
containerPort, err := strconv.Atoi(portSpec)
if err == nil {
// Find a random available host port
hostPort := FindAvailable()
if hostPort == 0 {
return "", 0, fmt.Errorf("could not find an available port for container port %d", containerPort)
}
return fmt.Sprintf("%d", hostPort), containerPort, nil
}

return "", 0, fmt.Errorf("invalid port specification: %s (expected 'hostPort:containerPort' or 'containerPort')", portSpec)
}
76 changes: 69 additions & 7 deletions pkg/networking/port_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,12 @@ func TestValidateCallbackPort(t *testing.T) {
err := networking.ValidateCallbackPort(tt.port, tt.clientID)

if tt.wantError {
if err == nil {
t.Errorf("ValidateCallbackPort() expected error but got nil")
} else if tt.errorMsg != "" && err.Error() != tt.errorMsg {
t.Errorf("ValidateCallbackPort() error = %v, want %v", err.Error(), tt.errorMsg)
require.Error(t, err)
if tt.errorMsg != "" {
require.EqualError(t, err, tt.errorMsg)
}
} else {
if err != nil {
t.Errorf("ValidateCallbackPort() unexpected error = %v", err)
}
require.NoError(t, err)
}
})
}
Expand Down Expand Up @@ -134,3 +131,68 @@ func TestGetProcessOnPort_PortInUse(t *testing.T) {
require.NoError(t, err)
assert.NotZero(t, pid, "port is in use, GetProcessOnPort should return the process PID")
}

func TestParsePortSpec(t *testing.T) {
t.Parallel()

tests := []struct {
name string
portSpec string
expectedHostPort string
expectedContainer int
wantError bool
}{
{
name: "host:container",
portSpec: "8003:8001",
expectedHostPort: "8003",
expectedContainer: 8001,
wantError: false,
},
{
name: "container only",
portSpec: "8001",
expectedHostPort: "", // Random
expectedContainer: 8001,
wantError: false,
},
{
name: "invalid format",
portSpec: "invalid",
expectedHostPort: "",
expectedContainer: 0,
wantError: true,
},
{
name: "invalid host port",
portSpec: "abc:8001",
expectedHostPort: "",
expectedContainer: 0,
wantError: true,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

hostPort, containerPort, err := networking.ParsePortSpec(tt.portSpec)

if tt.wantError {
require.Error(t, err, "ParsePortSpec(%s) expected error", tt.portSpec)
return
}

require.NoError(t, err, "ParsePortSpec(%s) unexpected error", tt.portSpec)

if tt.expectedHostPort != "" {
require.Equal(t, tt.expectedHostPort, hostPort, "ParsePortSpec(%s) unexpected host port", tt.portSpec)
} else {
require.NotEmpty(t, hostPort, "ParsePortSpec(%s) hostPort is empty, want random port", tt.portSpec)
}

require.Equal(t, tt.expectedContainer, containerPort, "ParsePortSpec(%s) unexpected container port", tt.portSpec)
})
}
}
3 changes: 3 additions & 0 deletions pkg/runner/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ type RunConfig struct {
// TargetHost is the host to forward traffic to (only applicable to SSE transport)
TargetHost string `json:"target_host,omitempty" yaml:"target_host,omitempty"`

// Publish lists ports to publish to the host in format "hostPort:containerPort"
Publish []string `json:"publish,omitempty" yaml:"publish,omitempty"`

// PermissionProfileNameOrPath is the name or path of the permission profile
PermissionProfileNameOrPath string `json:"permission_profile_name_or_path,omitempty" yaml:"permission_profile_name_or_path,omitempty"` //nolint:lll

Expand Down
8 changes: 8 additions & 0 deletions pkg/runner/config_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,14 @@ func WithTargetHost(targetHost string) RunConfigBuilderOption {
}
}

// WithPublish sets the published ports
func WithPublish(publish []string) RunConfigBuilderOption {
return func(b *runConfigBuilder) error {
b.config.Publish = publish
return nil
}
}

// WithDebug sets debug mode
func WithDebug(debug bool) RunConfigBuilderOption {
return func(b *runConfigBuilder) error {
Expand Down
1 change: 1 addition & 0 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ func (r *Runner) Run(ctx context.Context) error {
r.Config.Host,
r.Config.TargetPort,
r.Config.TargetHost,
r.Config.Publish,
scalingConfig,
)
if err != nil {
Expand Down
30 changes: 29 additions & 1 deletion pkg/runtime/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/stacklok/toolhive-core/permissions"
rt "github.com/stacklok/toolhive/pkg/container/runtime"
"github.com/stacklok/toolhive/pkg/ignore"
"github.com/stacklok/toolhive/pkg/networking"
"github.com/stacklok/toolhive/pkg/transport/types"
)

Expand Down Expand Up @@ -50,6 +51,7 @@ func Setup(
host string,
targetPort int,
targetHost string,
publishedPorts []string,
scalingConfig *rt.ScalingConfig,
) (*SetupResult, error) {
// Add transport-specific environment variables
Expand All @@ -74,6 +76,26 @@ func Setup(
containerOptions := rt.NewDeployWorkloadOptions()
containerOptions.K8sPodTemplatePatch = k8sPodTemplatePatch
containerOptions.IgnoreConfig = ignoreConfig

// Process published ports
for _, portSpec := range publishedPorts {
hostPort, containerPort, err := networking.ParsePortSpec(portSpec)
if err != nil {
return nil, fmt.Errorf("failed to parse published port '%s': %w", portSpec, err)
}

// Add to exposed ports
containerPortStr := fmt.Sprintf("%d/tcp", containerPort)
containerOptions.ExposedPorts[containerPortStr] = struct{}{}

// Add to port bindings
// Check if we already have bindings for this port
bindings := containerOptions.PortBindings[containerPortStr]
bindings = append(bindings, rt.PortBinding{
HostPort: hostPort,
})
containerOptions.PortBindings[containerPortStr] = bindings
}
containerOptions.ScalingConfig = scalingConfig

if transportType == types.TransportTypeStdio {
Expand All @@ -92,7 +114,13 @@ func Setup(
}

// Set the port bindings
containerOptions.PortBindings[containerPortStr] = portBindings
// Note: if the user explicitly publishes the target port using --publish,
// we append the default transport binding to the list of bindings for that port.
if _, ok := containerOptions.PortBindings[containerPortStr]; ok {
containerOptions.PortBindings[containerPortStr] = append(containerOptions.PortBindings[containerPortStr], portBindings...)
} else {
containerOptions.PortBindings[containerPortStr] = portBindings
}
}

// Create the container
Expand Down
Loading