diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index b9d8af64f..7dd5cb3db 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -100,6 +100,7 @@ var ( Version: version, Host: viper.GetString("host"), Port: viper.GetInt("port"), + BaseURL: viper.GetString("base-url"), ExportTranslations: viper.GetBool("export-translations"), EnableCommandLogging: viper.GetBool("enable-command-logging"), LogFilePath: viper.GetString("log-file"), @@ -132,6 +133,7 @@ func init() { rootCmd.PersistentFlags().Bool("insider-mode", false, "Enable insider features") rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") rootCmd.PersistentFlags().Int("port", 8082, "HTTP server port") + rootCmd.PersistentFlags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) @@ -148,6 +150,7 @@ func init() { _ = viper.BindPFlag("insider-mode", rootCmd.PersistentFlags().Lookup("insider-mode")) _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) _ = viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port")) + _ = viper.BindPFlag("base-url", rootCmd.PersistentFlags().Lookup("base-url")) // Add subcommands rootCmd.AddCommand(stdioCmd) diff --git a/pkg/http/handler.go b/pkg/http/handler.go index bee065196..671ebe3a0 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -9,6 +9,7 @@ import ( "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/headers" "github.com/github/github-mcp-server/pkg/http/middleware" + "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/go-chi/chi/v5" @@ -26,11 +27,13 @@ type HTTPMcpHandler struct { t translations.TranslationHelperFunc githubMcpServerFactory GitHubMCPServerFactoryFunc inventoryFactoryFunc InventoryFactoryFunc + oauthCfg *oauth.Config } type HTTPMcpHandlerOptions struct { GitHubMcpServerFactory GitHubMCPServerFactoryFunc InventoryFactory InventoryFactoryFunc + OAuthConfig *oauth.Config } type HTTPMcpHandlerOption func(*HTTPMcpHandlerOptions) @@ -47,6 +50,12 @@ func WithInventoryFactory(f InventoryFactoryFunc) HTTPMcpHandlerOption { } } +func WithOAuthConfig(cfg *oauth.Config) HTTPMcpHandlerOption { + return func(o *HTTPMcpHandlerOptions) { + o.OAuthConfig = cfg + } +} + func NewHTTPMcpHandler( ctx context.Context, cfg *HTTPServerConfig, @@ -77,6 +86,7 @@ func NewHTTPMcpHandler( t: t, githubMcpServerFactory: githubMcpServerFactory, inventoryFactoryFunc: inventoryFactory, + oauthCfg: opts.OAuthConfig, } } @@ -134,7 +144,7 @@ func (h *HTTPMcpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { Stateless: true, }) - middleware.ExtractUserToken()(mcpHandler).ServeHTTP(w, r) + middleware.ExtractUserToken(h.oauthCfg)(mcpHandler).ServeHTTP(w, r) } func DefaultGitHubMCPServerFactory(r *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) { diff --git a/pkg/http/headers/headers.go b/pkg/http/headers/headers.go index b73104c34..1e0d3be47 100644 --- a/pkg/http/headers/headers.go +++ b/pkg/http/headers/headers.go @@ -21,6 +21,15 @@ const ( // RealIPHeader is a standard HTTP Header used to indicate the real IP address of the client. RealIPHeader = "X-Real-IP" + // ForwardedHostHeader is a standard HTTP Header for preserving the original Host header when proxying. + ForwardedHostHeader = "X-Forwarded-Host" + // ForwardedProtoHeader is a standard HTTP Header for preserving the original protocol when proxying. + ForwardedProtoHeader = "X-Forwarded-Proto" + + // OriginalPathHeader is set to preserve the original request path + // before the /mcp prefix was stripped during proxying. + OriginalPathHeader = "X-GitHub-Original-Path" + // RequestHmacHeader is used to authenticate requests to the Raw API. RequestHmacHeader = "Request-Hmac" diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index c2e5c6382..93b93279e 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -10,6 +10,7 @@ import ( ghcontext "github.com/github/github-mcp-server/pkg/context" httpheaders "github.com/github/github-mcp-server/pkg/http/headers" "github.com/github/github-mcp-server/pkg/http/mark" + "github.com/github/github-mcp-server/pkg/http/oauth" ) type authType int @@ -40,14 +41,14 @@ var supportedThirdPartyTokenPrefixes = []string{ // were 40 characters long and only contained the characters a-f and 0-9. var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`) -func ExtractUserToken() func(next http.Handler) http.Handler { +func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, token, err := parseAuthorizationHeader(r) if err != nil { // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec if errors.Is(err, errMissingAuthorizationHeader) { - // sendAuthChallenge(w, r, cfg, obsv) + sendAuthChallenge(w, r, oauthCfg) return } // For other auth errors (bad format, unsupported), return 400 @@ -63,6 +64,15 @@ func ExtractUserToken() func(next http.Handler) http.Handler { }) } } + +// sendAuthChallenge sends a 401 Unauthorized response with WWW-Authenticate header +// containing the OAuth protected resource metadata URL as per RFC 6750 and MCP spec. +func sendAuthChallenge(w http.ResponseWriter, r *http.Request, oauthCfg *oauth.Config) { + resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, "mcp") + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata=%q`, resourceMetadataURL)) + http.Error(w, "Unauthorized", http.StatusUnauthorized) +} + func parseAuthorizationHeader(req *http.Request) (authType authType, token string, _ error) { authHeader := req.Header.Get(httpheaders.AuthorizationHeader) if authHeader == "" { diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go new file mode 100644 index 000000000..f24db6786 --- /dev/null +++ b/pkg/http/oauth/oauth.go @@ -0,0 +1,244 @@ +// Package oauth provides OAuth 2.0 Protected Resource Metadata (RFC 9728) support +// for the GitHub MCP Server HTTP mode. +package oauth + +import ( + "bytes" + _ "embed" + "fmt" + "html" + "net/http" + "net/url" + "strings" + "text/template" + + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/go-chi/chi/v5" +) + +const ( + // OAuthProtectedResourcePrefix is the well-known path prefix for OAuth protected resource metadata. + OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource" + + // DefaultAuthorizationServer is GitHub's OAuth authorization server. + DefaultAuthorizationServer = "https://github.com/login/oauth" +) + +//go:embed protected_resource.json.tmpl +var protectedResourceTemplate []byte + +// SupportedScopes lists all OAuth scopes that may be required by MCP tools. +var SupportedScopes = []string{ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace", +} + +// Config holds the OAuth configuration for the MCP server. +type Config struct { + // BaseURL is the publicly accessible URL where this server is hosted. + // This is used to construct the OAuth resource URL. + BaseURL string + + // AuthorizationServer is the OAuth authorization server URL. + // Defaults to GitHub's OAuth server if not specified. + AuthorizationServer string + + // ResourcePath is the resource path suffix (e.g., "/mcp"). + // If empty, defaults to "/" + ResourcePath string +} + +// ProtectedResourceData contains the data needed to build an OAuth protected resource response. +type ProtectedResourceData struct { + ResourceURL string + AuthorizationServer string +} + +// AuthHandler handles OAuth-related HTTP endpoints. +type AuthHandler struct { + cfg *Config + protectedResourceTemplate *template.Template +} + +// NewAuthHandler creates a new OAuth auth handler. +func NewAuthHandler(cfg *Config) (*AuthHandler, error) { + if cfg == nil { + cfg = &Config{} + } + + // Default authorization server to GitHub + if cfg.AuthorizationServer == "" { + cfg.AuthorizationServer = DefaultAuthorizationServer + } + + tmpl, err := template.New("protected-resource").Parse(string(protectedResourceTemplate)) + if err != nil { + return nil, fmt.Errorf("failed to parse protected resource template: %w", err) + } + + return &AuthHandler{ + cfg: cfg, + protectedResourceTemplate: tmpl, + }, nil +} + +// routePatterns defines the route patterns for OAuth protected resource metadata. +var routePatterns = []string{ + "", // Root: /.well-known/oauth-protected-resource + "/readonly", // Read-only mode + "/x/{toolset}", + "/x/{toolset}/readonly", +} + +// RegisterRoutes registers the OAuth protected resource metadata routes. +func (h *AuthHandler) RegisterRoutes(r chi.Router) { + for _, pattern := range routePatterns { + for _, route := range h.routesForPattern(pattern) { + path := OAuthProtectedResourcePrefix + route + r.Get(path, h.handleProtectedResource) + r.Options(path, h.handleProtectedResource) // CORS support + } + } +} + +// routesForPattern generates route variants for a given pattern. +// GitHub strips the /mcp prefix before forwarding, so we register both variants: +// - With /mcp prefix: for direct access or when GitHub doesn't strip +// - Without /mcp prefix: for when GitHub has stripped the prefix +func (h *AuthHandler) routesForPattern(pattern string) []string { + return []string{ + pattern, + "/mcp" + pattern, + pattern + "/", + "/mcp" + pattern + "/", + } +} + +// handleProtectedResource handles requests for OAuth protected resource metadata. +func (h *AuthHandler) handleProtectedResource(w http.ResponseWriter, r *http.Request) { + // Extract the resource path from the URL + resourcePath := strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix) + if resourcePath == "" || resourcePath == "/" { + resourcePath = "/" + } else { + resourcePath = strings.TrimPrefix(resourcePath, "/") + } + + data, err := h.GetProtectedResourceData(r, html.EscapeString(resourcePath)) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var buf bytes.Buffer + if err := h.protectedResourceTemplate.Execute(&buf, data); err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Set CORS headers + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(buf.Bytes()) +} + +// GetEffectiveResourcePath returns the resource path for OAuth protected resource URLs. +// It checks for the X-GitHub-Original-Path header set by GitHub, which contains +// the exact path the client requested before the /mcp prefix was stripped. +// If the header is not present, it falls back to +// restoring the /mcp prefix. +func GetEffectiveResourcePath(r *http.Request) string { + // Check for the original path header from GitHub (preferred method) + if originalPath := r.Header.Get(headers.OriginalPathHeader); originalPath != "" { + return originalPath + } + + // Fallback: GitHub strips /mcp prefix, so we need to restore it for the external URL + if r.URL.Path == "/" { + return "/mcp" + } + return "/mcp" + r.URL.Path +} + +// GetProtectedResourceData builds the OAuth protected resource data for a request. +func (h *AuthHandler) GetProtectedResourceData(r *http.Request, resourcePath string) (*ProtectedResourceData, error) { + host, scheme := GetEffectiveHostAndScheme(r, h.cfg) + + // Build the base URL + baseURL := fmt.Sprintf("%s://%s", scheme, host) + if h.cfg.BaseURL != "" { + baseURL = strings.TrimSuffix(h.cfg.BaseURL, "/") + } + + // Build the resource URL using url.JoinPath for proper path handling + var resourceURL string + var err error + if resourcePath == "/" { + resourceURL = baseURL + "/" + } else { + resourceURL, err = url.JoinPath(baseURL, resourcePath) + if err != nil { + return nil, fmt.Errorf("failed to build resource URL: %w", err) + } + } + + return &ProtectedResourceData{ + ResourceURL: resourceURL, + AuthorizationServer: h.cfg.AuthorizationServer, + }, nil +} + +// GetEffectiveHostAndScheme returns the effective host and scheme for a request. +// It checks X-Forwarded-Host and X-Forwarded-Proto headers first (set by proxies), +// then falls back to the request's Host and TLS state. +func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { //nolint:revive // parameters are required by http.oauth.BuildResourceMetadataURL signature + // Check for forwarded headers first (typically set by reverse proxies) + if forwardedHost := r.Header.Get(headers.ForwardedHostHeader); forwardedHost != "" { + host = forwardedHost + } else { + host = r.Host + } + + // Determine scheme + switch { + case r.Header.Get(headers.ForwardedProtoHeader) != "": + scheme = strings.ToLower(r.Header.Get(headers.ForwardedProtoHeader)) + case r.TLS != nil: + scheme = "https" + default: + // Default to HTTPS in production scenarios + scheme = "https" + } + + return host, scheme +} + +// BuildResourceMetadataURL constructs the full URL to the OAuth protected resource metadata endpoint. +func BuildResourceMetadataURL(r *http.Request, cfg *Config, resourcePath string) string { + host, scheme := GetEffectiveHostAndScheme(r, cfg) + + if cfg != nil && cfg.BaseURL != "" { + baseURL := strings.TrimSuffix(cfg.BaseURL, "/") + return baseURL + OAuthProtectedResourcePrefix + "/" + strings.TrimPrefix(resourcePath, "/") + } + + path := OAuthProtectedResourcePrefix + if resourcePath != "" && resourcePath != "/" { + path = path + "/" + strings.TrimPrefix(resourcePath, "/") + } + + return fmt.Sprintf("%s://%s%s", scheme, host, path) +} diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go new file mode 100644 index 000000000..035f5c35b --- /dev/null +++ b/pkg/http/oauth/oauth_test.go @@ -0,0 +1,677 @@ +package oauth + +import ( + "crypto/tls" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAuthHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + expectedAuthServer string + expectedResourcePath string + }{ + { + name: "nil config uses defaults", + cfg: nil, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "", + }, + { + name: "empty config uses defaults", + cfg: &Config{}, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "", + }, + { + name: "custom authorization server", + cfg: &Config{ + AuthorizationServer: "https://custom.example.com/oauth", + }, + expectedAuthServer: "https://custom.example.com/oauth", + expectedResourcePath: "", + }, + { + name: "custom base URL and resource path", + cfg: &Config{ + BaseURL: "https://example.com", + ResourcePath: "/mcp", + }, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "/mcp", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + require.NotNil(t, handler) + + assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer) + assert.NotNil(t, handler.protectedResourceTemplate) + }) + } +} + +func TestGetEffectiveHostAndScheme(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupRequest func() *http.Request + cfg *Config + expectedHost string + expectedScheme string + }{ + { + name: "basic request without forwarding headers", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", // defaults to https + }, + { + name: "request with X-Forwarded-Host header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + return req + }, + cfg: &Config{}, + expectedHost: "public.example.com", + expectedScheme: "https", + }, + { + name: "request with X-Forwarded-Proto header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.Header.Set(headers.ForwardedProtoHeader, "http") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "http", + }, + { + name: "request with both forwarding headers", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + req.Header.Set(headers.ForwardedProtoHeader, "https") + return req + }, + cfg: &Config{}, + expectedHost: "public.example.com", + expectedScheme: "https", + }, + { + name: "request with TLS", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.TLS = &tls.ConnectionState{} + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", + }, + { + name: "X-Forwarded-Proto takes precedence over TLS", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.TLS = &tls.ConnectionState{} + req.Header.Set(headers.ForwardedProtoHeader, "http") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "http", + }, + { + name: "scheme is lowercased", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.Header.Set(headers.ForwardedProtoHeader, "HTTPS") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + host, scheme := GetEffectiveHostAndScheme(req, tc.cfg) + + assert.Equal(t, tc.expectedHost, host) + assert.Equal(t, tc.expectedScheme, scheme) + }) + } +} + +func TestGetEffectiveResourcePath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupRequest func() *http.Request + expectedPath string + }{ + { + name: "root path without original path header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + return req + }, + expectedPath: "/mcp", + }, + { + name: "non-root path without original path header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/readonly", nil) + return req + }, + expectedPath: "/mcp/readonly", + }, + { + name: "with X-GitHub-Original-Path header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/readonly", nil) + req.Header.Set(headers.OriginalPathHeader, "/mcp/x/repos/readonly") + return req + }, + expectedPath: "/mcp/x/repos/readonly", + }, + { + name: "original path header takes precedence", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/something-else", nil) + req.Header.Set(headers.OriginalPathHeader, "/mcp/custom/path") + return req + }, + expectedPath: "/mcp/custom/path", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + path := GetEffectiveResourcePath(req) + + assert.Equal(t, tc.expectedPath, path) + }) + } +} + +func TestGetProtectedResourceData(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + setupRequest func() *http.Request + resourcePath string + expectedResourceURL string + expectedAuthServer string + expectError bool + }{ + { + name: "basic request with root resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/", + expectedResourceURL: "https://api.example.com/", + expectedAuthServer: DefaultAuthorizationServer, + }, + { + name: "basic request with custom resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedResourceURL: "https://api.example.com/mcp", + expectedAuthServer: DefaultAuthorizationServer, + }, + { + name: "with custom base URL", + cfg: &Config{ + BaseURL: "https://custom.example.com", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedResourceURL: "https://custom.example.com/mcp", + expectedAuthServer: DefaultAuthorizationServer, + }, + { + name: "with custom authorization server", + cfg: &Config{ + AuthorizationServer: "https://auth.example.com/oauth", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedResourceURL: "https://api.example.com/mcp", + expectedAuthServer: "https://auth.example.com/oauth", + }, + { + name: "base URL with trailing slash is trimmed", + cfg: &Config{ + BaseURL: "https://custom.example.com/", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedResourceURL: "https://custom.example.com/mcp", + expectedAuthServer: DefaultAuthorizationServer, + }, + { + name: "nested resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp/x/repos", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp/x/repos", + expectedResourceURL: "https://api.example.com/mcp/x/repos", + expectedAuthServer: DefaultAuthorizationServer, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + + req := tc.setupRequest() + data, err := handler.GetProtectedResourceData(req, tc.resourcePath) + + if tc.expectError { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tc.expectedResourceURL, data.ResourceURL) + assert.Equal(t, tc.expectedAuthServer, data.AuthorizationServer) + }) + } +} + +func TestBuildResourceMetadataURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + setupRequest func() *http.Request + resourcePath string + expectedURL string + }{ + { + name: "root path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/", + expectedURL: "https://api.example.com/.well-known/oauth-protected-resource", + }, + { + name: "with custom resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedURL: "https://api.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "with base URL config", + cfg: &Config{ + BaseURL: "https://custom.example.com", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedURL: "https://custom.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "with forwarded headers", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + req.Header.Set(headers.ForwardedProtoHeader, "https") + return req + }, + resourcePath: "/mcp", + expectedURL: "https://public.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "nil config uses request host", + cfg: nil, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "", + expectedURL: "https://api.example.com/.well-known/oauth-protected-resource", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + url := BuildResourceMetadataURL(req, tc.cfg, tc.resourcePath) + + assert.Equal(t, tc.expectedURL, url) + }) + } +} + +func TestHandleProtectedResource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + path string + host string + method string + expectedStatusCode int + expectedScopes []string + validateResponse func(t *testing.T, body map[string]any) + }{ + { + name: "GET request returns protected resource metadata", + cfg: &Config{}, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + expectedScopes: SupportedScopes, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Equal(t, "GitHub MCP Server", body["resource_name"]) + assert.Contains(t, body["resource"], "api.example.com") + + authServers, ok := body["authorization_servers"].([]any) + require.True(t, ok) + require.Len(t, authServers, 1) + assert.Equal(t, DefaultAuthorizationServer, authServers[0]) + }, + }, + { + name: "OPTIONS request for CORS", + cfg: &Config{}, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodOptions, + expectedStatusCode: http.StatusOK, + }, + { + name: "path with /mcp suffix", + cfg: &Config{}, + path: OAuthProtectedResourcePrefix + "/mcp", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Contains(t, body["resource"], "/mcp") + }, + }, + { + name: "path with /readonly suffix", + cfg: &Config{}, + path: OAuthProtectedResourcePrefix + "/readonly", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Contains(t, body["resource"], "/readonly") + }, + }, + { + name: "custom authorization server in response", + cfg: &Config{ + AuthorizationServer: "https://custom.auth.example.com/oauth", + }, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + authServers, ok := body["authorization_servers"].([]any) + require.True(t, ok) + require.Len(t, authServers, 1) + assert.Equal(t, "https://custom.auth.example.com/oauth", authServers[0]) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + req := httptest.NewRequest(tc.method, tc.path, nil) + req.Host = tc.host + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatusCode, rec.Code) + + // Check CORS headers + assert.Equal(t, "*", rec.Header().Get("Access-Control-Allow-Origin")) + assert.Contains(t, rec.Header().Get("Access-Control-Allow-Methods"), "GET") + assert.Contains(t, rec.Header().Get("Access-Control-Allow-Methods"), "OPTIONS") + + if tc.method == http.MethodGet && tc.validateResponse != nil { + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + var body map[string]any + err := json.Unmarshal(rec.Body.Bytes(), &body) + require.NoError(t, err) + + tc.validateResponse(t, body) + + // Verify scopes if expected + if tc.expectedScopes != nil { + scopes, ok := body["scopes_supported"].([]any) + require.True(t, ok) + assert.Len(t, scopes, len(tc.expectedScopes)) + } + } + }) + } +} + +func TestRegisterRoutes(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(&Config{}) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + // List of expected routes that should be registered + expectedRoutes := []string{ + OAuthProtectedResourcePrefix, + OAuthProtectedResourcePrefix + "/", + OAuthProtectedResourcePrefix + "/mcp", + OAuthProtectedResourcePrefix + "/mcp/", + OAuthProtectedResourcePrefix + "/readonly", + OAuthProtectedResourcePrefix + "/readonly/", + OAuthProtectedResourcePrefix + "/mcp/readonly", + OAuthProtectedResourcePrefix + "/mcp/readonly/", + OAuthProtectedResourcePrefix + "/x/repos", + OAuthProtectedResourcePrefix + "/mcp/x/repos", + } + + for _, route := range expectedRoutes { + t.Run("route:"+route, func(t *testing.T) { + // Test GET + req := httptest.NewRequest(http.MethodGet, route, nil) + req.Host = "api.example.com" + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "GET %s should return 200", route) + + // Test OPTIONS (CORS) + req = httptest.NewRequest(http.MethodOptions, route, nil) + req.Host = "api.example.com" + rec = httptest.NewRecorder() + router.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "OPTIONS %s should return 200", route) + }) + } +} + +func TestSupportedScopes(t *testing.T) { + t.Parallel() + + // Verify all expected scopes are present + expectedScopes := []string{ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace", + } + + assert.Equal(t, expectedScopes, SupportedScopes) +} + +func TestProtectedResourceResponseFormat(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(&Config{}) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + req := httptest.NewRequest(http.MethodGet, OAuthProtectedResourcePrefix, nil) + req.Host = "api.example.com" + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var response map[string]any + err = json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Verify all required RFC 9728 fields are present + assert.Contains(t, response, "resource") + assert.Contains(t, response, "authorization_servers") + assert.Contains(t, response, "bearer_methods_supported") + assert.Contains(t, response, "scopes_supported") + + // Verify resource name (optional but we include it) + assert.Contains(t, response, "resource_name") + assert.Equal(t, "GitHub MCP Server", response["resource_name"]) + + // Verify bearer_methods_supported contains "header" + bearerMethods, ok := response["bearer_methods_supported"].([]any) + require.True(t, ok) + assert.Contains(t, bearerMethods, "header") + + // Verify authorization_servers is an array with GitHub OAuth + authServers, ok := response["authorization_servers"].([]any) + require.True(t, ok) + assert.Len(t, authServers, 1) + assert.Equal(t, DefaultAuthorizationServer, authServers[0]) +} + +func TestOAuthProtectedResourcePrefix(t *testing.T) { + t.Parallel() + + // RFC 9728 specifies this well-known path + assert.Equal(t, "/.well-known/oauth-protected-resource", OAuthProtectedResourcePrefix) +} + +func TestDefaultAuthorizationServer(t *testing.T) { + t.Parallel() + + assert.Equal(t, "https://github.com/login/oauth", DefaultAuthorizationServer) +} diff --git a/pkg/http/oauth/protected_resource.json.tmpl b/pkg/http/oauth/protected_resource.json.tmpl new file mode 100644 index 000000000..7a9257404 --- /dev/null +++ b/pkg/http/oauth/protected_resource.json.tmpl @@ -0,0 +1,20 @@ +{ + "resource_name": "GitHub MCP Server", + "resource": "{{.ResourceURL}}", + "authorization_servers": ["{{.AuthorizationServer}}"], + "bearer_methods_supported": ["header"], + "scopes_supported": [ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace" + ] +} diff --git a/pkg/http/server.go b/pkg/http/server.go index c14ae9eee..2ff942d80 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -12,6 +12,7 @@ import ( "time" "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" @@ -28,6 +29,10 @@ type HTTPServerConfig struct { // Port to listen on (default: 8082) Port int + // BaseURL is the publicly accessible URL of this server for OAuth resource metadata. + // If not set, the server will derive the URL from incoming request headers. + BaseURL string + // ExportTranslations indicates if we should export translations // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#i18n--overriding-descriptions ExportTranslations bool @@ -98,6 +103,17 @@ func RunHTTPServer(cfg HTTPServerConfig) error { r := chi.NewRouter() + // Register OAuth protected resource metadata endpoints + oauthCfg := &oauth.Config{ + BaseURL: cfg.BaseURL, + } + oauthHandler, err := oauth.NewAuthHandler(oauthCfg) + if err != nil { + return fmt.Errorf("failed to create OAuth handler: %w", err) + } + oauthHandler.RegisterRoutes(r) + logger.Info("OAuth protected resource endpoints registered", "baseURL", cfg.BaseURL) + handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger) handler.RegisterRoutes(r)