Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion pkg/transport/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ func (t *HTTPTransport) Start(ctx context.Context) error {
// paths so they reach the correct endpoint on the remote server.
var remoteBasePath string

// remoteRawQuery holds the raw query string from the remote URL (e.g.,
// "toolsets=core,alerting" from "https://mcp.example.com/mcp?toolsets=core,alerting").
// This must be forwarded on every outbound request or it is silently dropped.
var remoteRawQuery string

if t.remoteURL != "" {
// For remote MCP servers, construct target URI from remote URL
remoteURL, err := url.Parse(t.remoteURL)
Expand All @@ -267,9 +272,11 @@ func (t *HTTPTransport) Start(ctx context.Context) error {
// The target URI only has scheme+host, so without this the remote path is lost.
remoteBasePath = remoteURL.Path

remoteRawQuery = remoteURL.RawQuery

//nolint:gosec // G706: logging proxy port and remote URL from config
slog.Debug("setting up transparent proxy to forward to remote URL",
"port", t.proxyPort, "target", targetURI, "base_path", remoteBasePath)
"port", t.proxyPort, "target", targetURI, "base_path", remoteBasePath, "raw_query", remoteRawQuery)
} else {
if t.containerName == "" {
return transporterrors.ErrContainerNameNotSet
Expand Down Expand Up @@ -311,6 +318,7 @@ func (t *HTTPTransport) Start(ctx context.Context) error {
if remoteBasePath != "" {
proxyOptions = append(proxyOptions, transparent.WithRemoteBasePath(remoteBasePath))
}
proxyOptions = append(proxyOptions, transparent.WithRemoteRawQuery(remoteRawQuery))

// Create the transparent proxy
t.proxy = transparent.NewTransparentProxyWithOptions(
Expand Down
125 changes: 125 additions & 0 deletions pkg/transport/http_remote_query_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package transport

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/stacklok/toolhive/pkg/transport/proxy/transparent"
"github.com/stacklok/toolhive/pkg/transport/types"
)

// TestHTTPTransport_Start_RemoteURLQueryParams verifies that HTTPTransport.Start()
// correctly extracts the raw query from the remoteURL and wires it into the
// transparent proxy so every upstream request carries those query parameters.
func TestHTTPTransport_Start_RemoteURLQueryParams(t *testing.T) {
t.Parallel()

tests := []struct {
name string
remoteQuery string // query string appended to the remote registration URL
expectedQuery string // raw query the upstream server should receive
description string
}{
{
name: "query params from registration URL are forwarded to upstream",
remoteQuery: "toolsets=core,alerting",
expectedQuery: "toolsets=core,alerting",
description: "Datadog case: toolset selection params must reach the upstream server",
},
{
name: "multiple query params are all forwarded to upstream",
remoteQuery: "toolsets=core,alerting&version=2",
expectedQuery: "toolsets=core,alerting&version=2",
description: "Multiple params must all be forwarded, none dropped",
},
{
name: "no query params — upstream receives empty query string",
remoteQuery: "",
expectedQuery: "",
description: "Without configured query params, upstream receives an empty query string",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

var receivedQuery atomic.Value

upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedQuery.Store(r.URL.RawQuery)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":"1","result":{"protocolVersion":"2024-11-05"}}`))
}))
defer upstream.Close()

remoteURL := upstream.URL + "/mcp"
if tt.remoteQuery != "" {
remoteURL += "?" + tt.remoteQuery
}

// Use port 0 so the OS assigns a free port.
transport := NewHTTPTransport(
types.TransportTypeStreamableHTTP,
LocalhostIPv4,
0, // proxyPort: OS-assigned
0, // targetPort: unused for remote
nil, // deployer: nil for remote
false, // debug
"", // targetHost: unused for remote
nil, // authInfoHandler
nil, // prometheusHandler
nil, // prefixHandlers
"", // endpointPrefix
false, // trustProxyHeaders
)
transport.SetRemoteURL(remoteURL)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

require.NoError(t, transport.Start(ctx))
defer func() {
assert.NoError(t, transport.Stop(context.Background()))
}()

// Retrieve the actual listening address from the underlying proxy.
tp, ok := transport.proxy.(*transparent.TransparentProxy)
require.True(t, ok, "proxy should be a TransparentProxy")
addr := tp.ListenerAddr()
require.NotEmpty(t, addr, "proxy should be listening")

// POST to the clean proxy URL (no query params) so only the
// proxy-configured remoteRawQuery is the source of upstream query params.
proxyURL := fmt.Sprintf("http://%s/mcp", addr)
body := `{"jsonrpc":"2.0","method":"initialize","id":"1","params":{}}`
req, err := http.NewRequest(http.MethodPost, proxyURL, strings.NewReader(body))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")

client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusOK, resp.StatusCode)

actualQuery, _ := receivedQuery.Load().(string)
assert.Equal(t, tt.expectedQuery, actualQuery,
"%s: upstream server received wrong query string", tt.description)
})
}
}
117 changes: 117 additions & 0 deletions pkg/transport/proxy/transparent/remote_path_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,123 @@ import (
"github.com/stretchr/testify/require"
)

// TestRemoteQueryForwarding verifies that the transparent proxy correctly
// forwards query parameters from the remote URL configuration to every
// outbound request.
//
// Scenario: remoteURL is https://mcp.datadoghq.com/mcp?toolsets=core,alerting
// Without this fix the query params are silently dropped and the remote
// server receives /mcp with no toolsets, returning only default tools.
func TestRemoteQueryForwarding(t *testing.T) {
t.Parallel()

tests := []struct {
name string
remoteRawQuery string // Query from registration URL
clientRawQuery string // Additional query from client request
expectedRawQuery string // Query that should arrive at the remote server
description string
}{
{
name: "remote query only, no client query",
remoteRawQuery: "toolsets=core,alerting",
clientRawQuery: "",
expectedRawQuery: "toolsets=core,alerting",
description: "Datadog case: remote query params forwarded when client sends none",
},
{
name: "remote query merged with client query",
remoteRawQuery: "toolsets=core,alerting",
clientRawQuery: "session=abc",
expectedRawQuery: "toolsets=core,alerting&session=abc",
description: "Remote params take precedence, client params appended",
},
{
name: "no remote query, client query preserved",
remoteRawQuery: "",
clientRawQuery: "session=abc",
expectedRawQuery: "session=abc",
description: "Without remote query, client query passes through unchanged",
},
{
name: "no remote query and no client query",
remoteRawQuery: "",
clientRawQuery: "",
expectedRawQuery: "",
description: "No query params in either direction",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

var receivedQuery atomic.Value

remoteServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedQuery.Store(r.URL.RawQuery)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":"1","result":{"protocolVersion":"2024-11-05"}}`))
}))
defer remoteServer.Close()

parsedRemote, err := url.Parse(remoteServer.URL)
require.NoError(t, err)
targetURI := (&url.URL{
Scheme: parsedRemote.Scheme,
Host: parsedRemote.Host,
}).String()

var opts []Option
if tt.remoteRawQuery != "" {
opts = append(opts, WithRemoteRawQuery(tt.remoteRawQuery))
}

proxy := NewTransparentProxyWithOptions(
"127.0.0.1", 0, targetURI,
nil, nil, nil,
false, true, "streamable-http",
nil, nil,
"", false,
nil, // middlewares
opts...,
)

ctx := context.Background()
err = proxy.Start(ctx)
require.NoError(t, err)
defer func() {
assert.NoError(t, proxy.Stop(context.Background()))
}()

addr := proxy.ListenerAddr()
require.NotEmpty(t, addr)

proxyURL := fmt.Sprintf("http://%s/mcp", addr)
if tt.clientRawQuery != "" {
proxyURL += "?" + tt.clientRawQuery
}

body := `{"jsonrpc":"2.0","method":"initialize","id":"1","params":{}}`
req, err := http.NewRequest(http.MethodPost, proxyURL, strings.NewReader(body))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")

client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusOK, resp.StatusCode)

actualQuery, _ := receivedQuery.Load().(string)
assert.Equal(t, tt.expectedRawQuery, actualQuery,
"%s: remote server received wrong query string", tt.description)
})
}
}

// TestRemotePathForwarding verifies that the transparent proxy correctly
// forwards requests to the remote server's full path, not just the host.
//
Expand Down
42 changes: 42 additions & 0 deletions pkg/transport/proxy/transparent/transparent_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ type TransparentProxy struct {
// URI only contains the scheme and host.
remoteBasePath string

// remoteRawQuery holds the raw query string from the remote URL (e.g.,
// "toolsets=core,alerting" from "https://mcp.example.com/mcp?toolsets=core,alerting").
// When set, it is merged into every outbound request so query parameters
// from the original registration URL are never silently dropped.
remoteRawQuery string

// Deprecated: trustProxyHeaders indicates whether to trust X-Forwarded-* headers (moved to SSEResponseProcessor)
trustProxyHeaders bool

Expand Down Expand Up @@ -178,6 +184,18 @@ func WithRemoteBasePath(basePath string) Option {
}
}

// WithRemoteRawQuery sets the raw query string from the remote URL.
// When set, these query parameters are merged into every outbound request,
// ensuring query parameters from the original registration URL are always forwarded.
// Ignores empty strings; default (no query forwarding) will be used.
func WithRemoteRawQuery(rawQuery string) Option {
return func(p *TransparentProxy) {
if rawQuery != "" {
p.remoteRawQuery = rawQuery
}
}
}

// withHealthCheckPingTimeout sets the health check ping timeout.
// This is primarily useful for testing with shorter timeouts.
// Ignores non-positive timeouts; default will be used.
Expand Down Expand Up @@ -505,6 +523,21 @@ func (p *TransparentProxy) Start(ctx context.Context) error {
pr.Out.URL.RawPath = ""
}

// Merge query parameters from the remote URL into the outbound request.
// Remote params are prepended so they appear first; most HTTP servers
// adopt first-value-wins semantics for duplicate keys, ensuring operator
// configured values (e.g., toolsets=core,alerting) take precedence over
// any client-supplied params with the same key.
// Raw string concatenation is intentional: url.Values.Encode() would
// percent-encode characters like commas that some APIs expect as literals.
if p.remoteRawQuery != "" {
merged := p.remoteRawQuery
if pr.Out.URL.RawQuery != "" {
merged += "&" + pr.Out.URL.RawQuery
}
pr.Out.URL.RawQuery = merged
}

// Inject OpenTelemetry trace propagation headers for downstream tracing
if pr.Out.Context() != nil {
otel.GetTextMapPropagator().Inject(pr.Out.Context(), propagation.HeaderCarrier(pr.Out.Header))
Expand Down Expand Up @@ -605,6 +638,15 @@ func (p *TransparentProxy) Start(ctx context.Context) error {
return nil
}

// ListenerAddr returns the network address the proxy is listening on.
// Returns an empty string if the proxy has not been started.
func (p *TransparentProxy) ListenerAddr() string {
if p.listener == nil {
return ""
}
return p.listener.Addr().String()
}

// CloseListener closes the listener for the transparent proxy.
func (p *TransparentProxy) CloseListener() error {
if p.listener != nil {
Expand Down
Loading
Loading