Skip to content

Commit 54ca1d5

Browse files
committed
Enhancement: Allow some protocol methods to bypass authentication
This is a requirement for the Docker MCP registry. Related to: docker/mcp-registry#164
1 parent 7dc2b23 commit 54ca1d5

File tree

2 files changed

+111
-25
lines changed

2 files changed

+111
-25
lines changed

cmd/mcp-http/main.go

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99
"os/signal"
1010
"regexp"
11+
"slices"
1112
"strings"
1213
"syscall"
1314
"time"
@@ -173,13 +174,41 @@ func tracerMiddleware(resources config.Resources, next http.Handler) http.Handle
173174
}
174175

175176
func authMiddleware(resources config.Resources, next http.Handler) http.Handler {
177+
whitelistEndpoints := map[string][]string{
178+
// health checks don't require authentication
179+
"/api/health": {http.MethodGet, http.MethodOptions},
180+
181+
// allow some protocol methods to bypass authentication
182+
//
183+
// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle
184+
// https://modelcontextprotocol.io/specification/2025-06-18/server/tools#listing-tools
185+
// https://modelcontextprotocol.io/specification/2025-06-18/server/resources#listing-resources
186+
// https://modelcontextprotocol.io/specification/2025-06-18/server/resources#resource-templates
187+
// https://modelcontextprotocol.io/specification/2025-06-18/server/prompts#listing-prompts
188+
"/": {http.MethodPost},
189+
"/tools/list": {http.MethodPost},
190+
"/resources/list": {http.MethodPost},
191+
"/resources/templates/list": {http.MethodPost},
192+
"/prompts/list": {http.MethodPost},
193+
}
194+
195+
whitelistPrefixEndpoints := map[string][]string{
196+
// OAuth2 endpoints cannot require authentication
197+
"/.well-known": {"GET", "OPTIONS"},
198+
}
199+
176200
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
177201
// some endpoints don't require auth
178-
if (r.URL.Path == "/api/health" || strings.HasPrefix(r.URL.Path, "/.well-known")) &&
179-
(r.Method == http.MethodGet || r.Method == http.MethodOptions) {
202+
if methods, ok := whitelistEndpoints[r.URL.Path]; ok && slices.Contains(methods, r.Method) {
180203
next.ServeHTTP(w, r)
181204
return
182205
}
206+
for prefix, methods := range whitelistPrefixEndpoints {
207+
if strings.HasPrefix(r.URL.Path, prefix) && slices.Contains(methods, r.Method) {
208+
next.ServeHTTP(w, r)
209+
return
210+
}
211+
}
183212

184213
requestLogger := resources.Logger().With(
185214
slog.String("method", r.Method),

cmd/mcp-stdio/main.go

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"log/slog"
1010
"os"
11+
"slices"
1112
"strings"
1213

1314
"github.com/mark3labs/mcp-go/mcp"
@@ -20,7 +21,23 @@ import (
2021
)
2122

2223
var (
23-
methods = methodsInput([]toolsets.Method{toolsets.MethodAll})
24+
methods = methodsInput([]toolsets.Method{toolsets.MethodAll})
25+
methodsWhitelist = []string{
26+
// allow some protocol methods to bypass authentication
27+
//
28+
// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle
29+
// https://modelcontextprotocol.io/specification/2025-06-18/server/tools#listing-tools
30+
// https://modelcontextprotocol.io/specification/2025-06-18/server/resources#listing-resources
31+
// https://modelcontextprotocol.io/specification/2025-06-18/server/resources#resource-templates
32+
// https://modelcontextprotocol.io/specification/2025-06-18/server/prompts#listing-prompts
33+
"initialize",
34+
"notifications/initialized",
35+
"logging/setLevel",
36+
"tools/list",
37+
"resources/list",
38+
"resources/templates/list",
39+
"prompts/list",
40+
}
2441
readOnly bool
2542
)
2643

@@ -34,33 +51,31 @@ func main() {
3451
flag.BoolVar(&readOnly, "read-only", false, "Restrict the server to read-only operations")
3552
flag.Parse()
3653

37-
if resources.Info.BearerToken == "" {
38-
mcpError(resources, errors.New("TW_MCP_BEARER_TOKEN environment variable is not set"), mcp.INVALID_PARAMS)
39-
exit(exitCodeSetupFailure)
40-
}
41-
4254
ctx := context.Background()
4355

44-
// detect the installation from the bearer token
45-
info, err := auth.GetBearerInfo(ctx, resources, resources.Info.BearerToken)
46-
if err != nil {
47-
mcpError(resources, fmt.Errorf("failed to authenticate: %s", err), mcp.INVALID_PARAMS)
48-
exit(exitCodeSetupFailure)
49-
}
56+
if resources.Info.BearerToken != "" {
57+
// detect the installation from the bearer token
58+
info, err := auth.GetBearerInfo(ctx, resources, resources.Info.BearerToken)
59+
if err != nil {
60+
mcpError(resources.Logger(), fmt.Errorf("failed to authenticate: %s", err), mcp.INVALID_PARAMS)
61+
exit(exitCodeSetupFailure)
62+
}
5063

51-
// inject customer URL in the context
52-
ctx = config.WithCustomerURL(ctx, info.URL)
53-
// inject bearer token in the context
54-
ctx = session.WithBearerTokenContext(ctx, session.NewBearerToken(resources.Info.BearerToken, info.URL))
64+
// inject customer URL in the context
65+
ctx = config.WithCustomerURL(ctx, info.URL)
66+
// inject bearer token in the context
67+
ctx = session.WithBearerTokenContext(ctx, session.NewBearerToken(resources.Info.BearerToken, info.URL))
68+
}
5569

5670
mcpServer, err := newMCPServer(resources)
5771
if err != nil {
58-
mcpError(resources, fmt.Errorf("failed to create MCP server: %s", err), mcp.INTERNAL_ERROR)
72+
mcpError(resources.Logger(), fmt.Errorf("failed to create MCP server: %s", err), mcp.INTERNAL_ERROR)
5973
exit(exitCodeSetupFailure)
6074
}
6175
mcpSTDIOServer := server.NewStdioServer(mcpServer)
62-
if err := mcpSTDIOServer.Listen(ctx, os.Stdin, os.Stdout); err != nil {
63-
mcpError(resources, fmt.Errorf("failed to serve: %s", err), mcp.INTERNAL_ERROR)
76+
stdinWrapper := newStdinWrapper(resources.Logger(), resources.Info.BearerToken != "", methodsWhitelist)
77+
if err := mcpSTDIOServer.Listen(ctx, stdinWrapper, os.Stdout); err != nil {
78+
mcpError(resources.Logger(), fmt.Errorf("failed to serve: %s", err), mcp.INTERNAL_ERROR)
6479
exit(exitCodeSetupFailure)
6580
}
6681
}
@@ -73,14 +88,16 @@ func newMCPServer(resources config.Resources) (*server.MCPServer, error) {
7388
return config.NewMCPServer(resources, group), nil
7489
}
7590

76-
func mcpError(resources config.Resources, err error, code int) {
91+
func mcpError(logger *slog.Logger, err error, code int) {
7792
mcpError := mcp.NewJSONRPCError(mcp.NewRequestId("startup"), code, err.Error(), nil)
78-
encoder := json.NewEncoder(os.Stdout)
79-
if err := encoder.Encode(mcpError); err != nil {
80-
resources.Logger().Error("failed to encode error",
93+
encoded, err := json.Marshal(mcpError)
94+
if err != nil {
95+
logger.Error("failed to encode error",
8196
slog.String("error", err.Error()),
8297
)
98+
return
8399
}
100+
fmt.Printf("%s\n", string(encoded))
84101
}
85102

86103
type methodsInput []toolsets.Method
@@ -110,6 +127,46 @@ func (t *methodsInput) Set(value string) error {
110127
return errs
111128
}
112129

130+
type stdinWrapper struct {
131+
logger *slog.Logger
132+
authenticated bool
133+
methodsWhitelist []string
134+
}
135+
136+
func newStdinWrapper(logger *slog.Logger, authenticated bool, methods []string) stdinWrapper {
137+
return stdinWrapper{
138+
logger: logger,
139+
authenticated: authenticated,
140+
methodsWhitelist: methods,
141+
}
142+
}
143+
144+
func (s stdinWrapper) Read(p []byte) (n int, err error) {
145+
if s.authenticated {
146+
return os.Stdin.Read(p)
147+
}
148+
buffer := make([]byte, len(p))
149+
n, err = os.Stdin.Read(buffer)
150+
if err != nil {
151+
return n, err
152+
}
153+
content := buffer[:n]
154+
if len(content) == 0 {
155+
return n, err
156+
}
157+
var baseMessage struct {
158+
Method string `json:"method"`
159+
}
160+
if err := json.Unmarshal(content, &baseMessage); err != nil {
161+
return 0, errors.New("parse error")
162+
}
163+
if !slices.Contains(s.methodsWhitelist, baseMessage.Method) {
164+
return 0, errors.New("not authenticated")
165+
}
166+
copy(p, buffer)
167+
return n, err
168+
}
169+
113170
type exitCode int
114171

115172
const (

0 commit comments

Comments
 (0)