diff --git a/pkg/transport/http.go b/pkg/transport/http.go index 8b11b5a7c0..8deb3f7d08 100644 --- a/pkg/transport/http.go +++ b/pkg/transport/http.go @@ -252,6 +252,11 @@ func (t *HTTPTransport) Start(ctx context.Context) error { // paths so they reach the correct endpoint on the remote server. var remoteBasePath string + // remoteRawQuery holds the raw query string from the remote URL (e.g., + // "toolsets=core,alerting" from "https://mcp.example.com/mcp?toolsets=core,alerting"). + // This must be forwarded on every outbound request or it is silently dropped. + var remoteRawQuery string + if t.remoteURL != "" { // For remote MCP servers, construct target URI from remote URL remoteURL, err := url.Parse(t.remoteURL) @@ -267,9 +272,11 @@ func (t *HTTPTransport) Start(ctx context.Context) error { // The target URI only has scheme+host, so without this the remote path is lost. remoteBasePath = remoteURL.Path + remoteRawQuery = remoteURL.RawQuery + //nolint:gosec // G706: logging proxy port and remote URL from config slog.Debug("setting up transparent proxy to forward to remote URL", - "port", t.proxyPort, "target", targetURI, "base_path", remoteBasePath) + "port", t.proxyPort, "target", targetURI, "base_path", remoteBasePath, "raw_query", remoteRawQuery) } else { if t.containerName == "" { return transporterrors.ErrContainerNameNotSet @@ -311,6 +318,7 @@ func (t *HTTPTransport) Start(ctx context.Context) error { if remoteBasePath != "" { proxyOptions = append(proxyOptions, transparent.WithRemoteBasePath(remoteBasePath)) } + proxyOptions = append(proxyOptions, transparent.WithRemoteRawQuery(remoteRawQuery)) // Create the transparent proxy t.proxy = transparent.NewTransparentProxyWithOptions( diff --git a/pkg/transport/http_remote_query_test.go b/pkg/transport/http_remote_query_test.go new file mode 100644 index 0000000000..abeb203567 --- /dev/null +++ b/pkg/transport/http_remote_query_test.go @@ -0,0 +1,125 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package transport + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/transport/proxy/transparent" + "github.com/stacklok/toolhive/pkg/transport/types" +) + +// TestHTTPTransport_Start_RemoteURLQueryParams verifies that HTTPTransport.Start() +// correctly extracts the raw query from the remoteURL and wires it into the +// transparent proxy so every upstream request carries those query parameters. +func TestHTTPTransport_Start_RemoteURLQueryParams(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + remoteQuery string // query string appended to the remote registration URL + expectedQuery string // raw query the upstream server should receive + description string + }{ + { + name: "query params from registration URL are forwarded to upstream", + remoteQuery: "toolsets=core,alerting", + expectedQuery: "toolsets=core,alerting", + description: "Datadog case: toolset selection params must reach the upstream server", + }, + { + name: "multiple query params are all forwarded to upstream", + remoteQuery: "toolsets=core,alerting&version=2", + expectedQuery: "toolsets=core,alerting&version=2", + description: "Multiple params must all be forwarded, none dropped", + }, + { + name: "no query params — upstream receives empty query string", + remoteQuery: "", + expectedQuery: "", + description: "Without configured query params, upstream receives an empty query string", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var receivedQuery atomic.Value + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedQuery.Store(r.URL.RawQuery) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":"1","result":{"protocolVersion":"2024-11-05"}}`)) + })) + defer upstream.Close() + + remoteURL := upstream.URL + "/mcp" + if tt.remoteQuery != "" { + remoteURL += "?" + tt.remoteQuery + } + + // Use port 0 so the OS assigns a free port. + transport := NewHTTPTransport( + types.TransportTypeStreamableHTTP, + LocalhostIPv4, + 0, // proxyPort: OS-assigned + 0, // targetPort: unused for remote + nil, // deployer: nil for remote + false, // debug + "", // targetHost: unused for remote + nil, // authInfoHandler + nil, // prometheusHandler + nil, // prefixHandlers + "", // endpointPrefix + false, // trustProxyHeaders + ) + transport.SetRemoteURL(remoteURL) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + require.NoError(t, transport.Start(ctx)) + defer func() { + assert.NoError(t, transport.Stop(context.Background())) + }() + + // Retrieve the actual listening address from the underlying proxy. + tp, ok := transport.proxy.(*transparent.TransparentProxy) + require.True(t, ok, "proxy should be a TransparentProxy") + addr := tp.ListenerAddr() + require.NotEmpty(t, addr, "proxy should be listening") + + // POST to the clean proxy URL (no query params) so only the + // proxy-configured remoteRawQuery is the source of upstream query params. + proxyURL := fmt.Sprintf("http://%s/mcp", addr) + body := `{"jsonrpc":"2.0","method":"initialize","id":"1","params":{}}` + req, err := http.NewRequest(http.MethodPost, proxyURL, strings.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + actualQuery, _ := receivedQuery.Load().(string) + assert.Equal(t, tt.expectedQuery, actualQuery, + "%s: upstream server received wrong query string", tt.description) + }) + } +} diff --git a/pkg/transport/proxy/transparent/remote_path_test.go b/pkg/transport/proxy/transparent/remote_path_test.go index 2e513123f6..5fcac58c46 100644 --- a/pkg/transport/proxy/transparent/remote_path_test.go +++ b/pkg/transport/proxy/transparent/remote_path_test.go @@ -18,6 +18,123 @@ import ( "github.com/stretchr/testify/require" ) +// TestRemoteQueryForwarding verifies that the transparent proxy correctly +// forwards query parameters from the remote URL configuration to every +// outbound request. +// +// Scenario: remoteURL is https://mcp.datadoghq.com/mcp?toolsets=core,alerting +// Without this fix the query params are silently dropped and the remote +// server receives /mcp with no toolsets, returning only default tools. +func TestRemoteQueryForwarding(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + remoteRawQuery string // Query from registration URL + clientRawQuery string // Additional query from client request + expectedRawQuery string // Query that should arrive at the remote server + description string + }{ + { + name: "remote query only, no client query", + remoteRawQuery: "toolsets=core,alerting", + clientRawQuery: "", + expectedRawQuery: "toolsets=core,alerting", + description: "Datadog case: remote query params forwarded when client sends none", + }, + { + name: "remote query merged with client query", + remoteRawQuery: "toolsets=core,alerting", + clientRawQuery: "session=abc", + expectedRawQuery: "toolsets=core,alerting&session=abc", + description: "Remote params take precedence, client params appended", + }, + { + name: "no remote query, client query preserved", + remoteRawQuery: "", + clientRawQuery: "session=abc", + expectedRawQuery: "session=abc", + description: "Without remote query, client query passes through unchanged", + }, + { + name: "no remote query and no client query", + remoteRawQuery: "", + clientRawQuery: "", + expectedRawQuery: "", + description: "No query params in either direction", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var receivedQuery atomic.Value + + remoteServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedQuery.Store(r.URL.RawQuery) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":"1","result":{"protocolVersion":"2024-11-05"}}`)) + })) + defer remoteServer.Close() + + parsedRemote, err := url.Parse(remoteServer.URL) + require.NoError(t, err) + targetURI := (&url.URL{ + Scheme: parsedRemote.Scheme, + Host: parsedRemote.Host, + }).String() + + var opts []Option + if tt.remoteRawQuery != "" { + opts = append(opts, WithRemoteRawQuery(tt.remoteRawQuery)) + } + + proxy := NewTransparentProxyWithOptions( + "127.0.0.1", 0, targetURI, + nil, nil, nil, + false, true, "streamable-http", + nil, nil, + "", false, + nil, // middlewares + opts..., + ) + + ctx := context.Background() + err = proxy.Start(ctx) + require.NoError(t, err) + defer func() { + assert.NoError(t, proxy.Stop(context.Background())) + }() + + addr := proxy.ListenerAddr() + require.NotEmpty(t, addr) + + proxyURL := fmt.Sprintf("http://%s/mcp", addr) + if tt.clientRawQuery != "" { + proxyURL += "?" + tt.clientRawQuery + } + + body := `{"jsonrpc":"2.0","method":"initialize","id":"1","params":{}}` + req, err := http.NewRequest(http.MethodPost, proxyURL, strings.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + actualQuery, _ := receivedQuery.Load().(string) + assert.Equal(t, tt.expectedRawQuery, actualQuery, + "%s: remote server received wrong query string", tt.description) + }) + } +} + // TestRemotePathForwarding verifies that the transparent proxy correctly // forwards requests to the remote server's full path, not just the host. // diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index feb4864df1..e5e48b38dd 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -108,6 +108,12 @@ type TransparentProxy struct { // URI only contains the scheme and host. remoteBasePath string + // remoteRawQuery holds the raw query string from the remote URL (e.g., + // "toolsets=core,alerting" from "https://mcp.example.com/mcp?toolsets=core,alerting"). + // When set, it is merged into every outbound request so query parameters + // from the original registration URL are never silently dropped. + remoteRawQuery string + // Deprecated: trustProxyHeaders indicates whether to trust X-Forwarded-* headers (moved to SSEResponseProcessor) trustProxyHeaders bool @@ -178,6 +184,18 @@ func WithRemoteBasePath(basePath string) Option { } } +// WithRemoteRawQuery sets the raw query string from the remote URL. +// When set, these query parameters are merged into every outbound request, +// ensuring query parameters from the original registration URL are always forwarded. +// Ignores empty strings; default (no query forwarding) will be used. +func WithRemoteRawQuery(rawQuery string) Option { + return func(p *TransparentProxy) { + if rawQuery != "" { + p.remoteRawQuery = rawQuery + } + } +} + // withHealthCheckPingTimeout sets the health check ping timeout. // This is primarily useful for testing with shorter timeouts. // Ignores non-positive timeouts; default will be used. @@ -505,6 +523,21 @@ func (p *TransparentProxy) Start(ctx context.Context) error { pr.Out.URL.RawPath = "" } + // Merge query parameters from the remote URL into the outbound request. + // Remote params are prepended so they appear first; most HTTP servers + // adopt first-value-wins semantics for duplicate keys, ensuring operator + // configured values (e.g., toolsets=core,alerting) take precedence over + // any client-supplied params with the same key. + // Raw string concatenation is intentional: url.Values.Encode() would + // percent-encode characters like commas that some APIs expect as literals. + if p.remoteRawQuery != "" { + merged := p.remoteRawQuery + if pr.Out.URL.RawQuery != "" { + merged += "&" + pr.Out.URL.RawQuery + } + pr.Out.URL.RawQuery = merged + } + // Inject OpenTelemetry trace propagation headers for downstream tracing if pr.Out.Context() != nil { otel.GetTextMapPropagator().Inject(pr.Out.Context(), propagation.HeaderCarrier(pr.Out.Header)) @@ -605,6 +638,15 @@ func (p *TransparentProxy) Start(ctx context.Context) error { return nil } +// ListenerAddr returns the network address the proxy is listening on. +// Returns an empty string if the proxy has not been started. +func (p *TransparentProxy) ListenerAddr() string { + if p.listener == nil { + return "" + } + return p.listener.Addr().String() +} + // CloseListener closes the listener for the transparent proxy. func (p *TransparentProxy) CloseListener() error { if p.listener != nil { diff --git a/pkg/transport/url.go b/pkg/transport/url.go index 66bfb73306..da11145676 100644 --- a/pkg/transport/url.go +++ b/pkg/transport/url.go @@ -48,28 +48,7 @@ func GenerateMCPServerURL(transportType string, proxyMode string, host string, p // ---- Remote path case ---- if remoteURL != "" { - targetURL, err := url.Parse(remoteURL) - if err != nil { - slog.Error("failed to parse target URI", "error", err) - return "" - } - - // Use remote path as-is; treat "/" as empty - path := targetURL.EscapedPath() - if path == "/" { - path = "" - } - - if isSSE { - if path == "" { - path = ssecommon.HTTPSSEEndpoint - } - return fmt.Sprintf("%s%s#%s", base, path, url.PathEscape(containerName)) - } - if isStreamable { - return fmt.Sprintf("%s%s", base, path) - } - return "" + return generateRemoteMCPServerURL(base, containerName, remoteURL, isSSE, isStreamable) } // ---- Local path case (use constants as-is) ---- @@ -85,3 +64,37 @@ func GenerateMCPServerURL(transportType string, proxyMode string, host string, p return "" } + +// generateRemoteMCPServerURL builds the proxy URL for a remote MCP server, +// using only the path from the remote URL. +// +// Query parameters are intentionally excluded from the generated client URL. +// The transparent proxy forwards them on every outbound request via +// WithRemoteRawQuery, so including them here would cause duplication — +// the upstream would receive the same parameter twice (e.g. +// "toolsets=core&toolsets=core"). Clients connect to the clean proxy +// URL; the proxy transparently appends the configured query string. +func generateRemoteMCPServerURL(base, containerName, remoteURL string, isSSE, isStreamable bool) string { + targetURL, err := url.Parse(remoteURL) + if err != nil { + slog.Error("failed to parse target URI", "error", err) + return "" + } + + // Use remote path as-is; treat "/" as empty + path := targetURL.EscapedPath() + if path == "/" { + path = "" + } + + if isSSE { + if path == "" { + path = ssecommon.HTTPSSEEndpoint + } + return fmt.Sprintf("%s%s#%s", base, path, url.PathEscape(containerName)) + } + if isStreamable { + return fmt.Sprintf("%s%s", base, path) + } + return "" +} diff --git a/pkg/transport/url_test.go b/pkg/transport/url_test.go index 5cb65ed7aa..fbcb362490 100644 --- a/pkg/transport/url_test.go +++ b/pkg/transport/url_test.go @@ -165,6 +165,42 @@ func TestGenerateMCPServerURL(t *testing.T) { targetURI: "http://remote.com/api", expected: "http://localhost:12345/api#test-container", }, + { + // Query params are excluded from the client URL — the proxy forwards + // them transparently via WithRemoteRawQuery to avoid duplication. + name: "Streamable HTTP with query parameters in targetURI strips query from client URL", + transportType: types.TransportTypeStreamableHTTP.String(), + proxyMode: "", + host: "localhost", + port: 12345, + containerName: "test-container", + targetURI: "https://mcp.datadoghq.com/api/unstable/mcp?toolsets=core,alerting,apm", + expected: "http://localhost:12345/api/unstable/mcp", + }, + { + // Query params are excluded from the client URL — the proxy forwards + // them transparently via WithRemoteRawQuery to avoid duplication. + name: "SSE transport with query parameters in targetURI strips query from client URL", + transportType: types.TransportTypeSSE.String(), + proxyMode: "", + host: "localhost", + port: 12345, + containerName: "test-container", + targetURI: "https://mcp.example.com/sse?token=abc123", + expected: "http://localhost:12345/sse#test-container", + }, + { + // Query params are excluded from the client URL — the proxy forwards + // them transparently via WithRemoteRawQuery to avoid duplication. + name: "SSE transport with query parameters and no path in targetURI strips query from client URL", + transportType: types.TransportTypeSSE.String(), + proxyMode: "", + host: "localhost", + port: 12345, + containerName: "test-container", + targetURI: "https://mcp.example.com?token=abc123", + expected: "http://localhost:12345/sse#test-container", + }, } for _, tt := range tests { diff --git a/test/e2e/remote_mcp_query_params_test.go b/test/e2e/remote_mcp_query_params_test.go new file mode 100644 index 0000000000..ff157d1ec9 --- /dev/null +++ b/test/e2e/remote_mcp_query_params_test.go @@ -0,0 +1,79 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package e2e_test + +import ( + "encoding/json" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/stacklok/toolhive/test/e2e" +) + +var _ = Describe("Remote MCP server with URL query parameters", + Label("remote", "mcp", "e2e", "proxy"), Serial, func() { + var config *e2e.TestConfig + + BeforeEach(func() { + config = e2e.NewTestConfig() + + // Check if thv binary is available + err := e2e.CheckTHVBinaryAvailable(config) + Expect(err).ToNot(HaveOccurred(), "thv binary should be available") + }) + + Context("when registering a remote server URL with query parameters", func() { + var serverName string + + BeforeEach(func() { + serverName = e2e.GenerateUniqueServerName("remote-query-params-test") + }) + + AfterEach(func() { + if config.CleanupAfter { + err := e2e.StopAndRemoveMCPServer(config, serverName) + Expect(err).ToNot(HaveOccurred(), "Should be able to stop and remove server") + } + }) + + It("should not include URL query parameters in the generated proxy URL [Serial]", func() { + By("Starting a remote MCP server with query parameters in the URL") + // Use the standard remote test server with a query parameter appended. + // The server ignores unknown params; we verify ToolHive strips them + // from the client-facing proxy URL (the proxy forwards them transparently). + registrationURL := remoteServerURL + "?toolsets=query-test" + e2e.NewTHVCommand(config, "run", + "--name", serverName, + registrationURL).ExpectSuccess() + + By("Waiting for the server to be running") + err := e2e.WaitForMCPServer(config, serverName, 30*time.Second) + Expect(err).ToNot(HaveOccurred(), "Server should be running within 30 seconds") + + By("Verifying the proxy URL does not contain query parameters from the registration URL") + stdout, _ := e2e.NewTHVCommand(config, "list", "--format", "json").ExpectSuccess() + + var workloads []WorkloadInfo + err = json.Unmarshal([]byte(stdout), &workloads) + Expect(err).ToNot(HaveOccurred(), "Should be able to parse JSON output") + + var serverInfo *WorkloadInfo + for i := range workloads { + if workloads[i].Name == serverName { + serverInfo = &workloads[i] + break + } + } + + Expect(serverInfo).ToNot(BeNil(), "Server should appear in the list") + // The proxy URL must not include query params — the transparent proxy + // forwards them to the upstream on every request via WithRemoteRawQuery. + // Including them in the client URL would cause duplication at the upstream. + Expect(serverInfo.URL).NotTo(ContainSubstring("toolsets=query-test"), + "Proxy URL should not include query parameters — the proxy forwards them transparently") + }) + }) + })