feat: switch OAuth to authorization code + PKCE flow

Claude Desktop uses authorization code flow, not client credentials.
Added /authorize endpoint that auto-approves (single-user setup) and
redirects with code. Token endpoint now supports both grant types.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Mikkel Georgsen 2026-03-30 10:44:05 +00:00
parent 205b978b89
commit 15e3582787
2 changed files with 168 additions and 40 deletions

View file

@ -24,6 +24,7 @@ logger = logging.getLogger("mcp_bridge")
# Paths that don't require auth # Paths that don't require auth
PUBLIC_PATHS = { PUBLIC_PATHS = {
"/.well-known/oauth-authorization-server", "/.well-known/oauth-authorization-server",
"/authorize",
"/token", "/token",
"/api/health", "/api/health",
"/api/ingest", # Local-only, not exposed via NPM "/api/ingest", # Local-only, not exposed via NPM

View file

@ -1,22 +1,26 @@
"""OAuth 2.0 client credentials auth for MCP server.""" """OAuth 2.0 authorization code + PKCE auth for MCP server."""
import base64
import hashlib import hashlib
import json import json
import logging import logging
import secrets import secrets
import time import time
from urllib.parse import urlencode, urlparse, parse_qs
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse, RedirectResponse
from starlette.routing import Route from starlette.routing import Route
from .config import load_credentials from .config import load_credentials
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# In-memory token store: token_hash -> expiry timestamp # In-memory stores
_active_tokens: dict[str, float] = {} _active_tokens: dict[str, float] = {} # token_hash -> expiry
_auth_codes: dict[str, dict] = {} # code -> {client_id, redirect_uri, code_challenge, expires}
TOKEN_LIFETIME = 3600 # 1 hour TOKEN_LIFETIME = 3600 # 1 hour
CODE_LIFETIME = 300 # 5 minutes
def _get_oauth_credentials() -> tuple[str, str]: def _get_oauth_credentials() -> tuple[str, str]:
@ -29,35 +33,105 @@ def _get_oauth_credentials() -> tuple[str, str]:
return client_id, client_secret return client_id, client_secret
def _hash_token(token: str) -> str: def _hash(value: str) -> str:
return hashlib.sha256(token.encode()).hexdigest() return hashlib.sha256(value.encode()).hexdigest()
def _cleanup_expired(): def _cleanup_expired():
"""Remove expired tokens.""" """Remove expired tokens and codes."""
now = time.time() now = time.time()
expired = [h for h, exp in _active_tokens.items() if exp < now] for store in (_active_tokens, _auth_codes):
for h in expired: expired = [k for k, v in store.items()
del _active_tokens[h] if (v if isinstance(v, float) else v.get("expires", 0)) < now]
for k in expired:
del store[k]
def _verify_pkce(code_verifier: str, code_challenge: str) -> bool:
"""Verify PKCE S256 challenge."""
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
return secrets.compare_digest(computed, code_challenge)
def validate_bearer_token(token: str) -> bool: def validate_bearer_token(token: str) -> bool:
"""Check if a bearer token is valid and not expired.""" """Check if a bearer token is valid and not expired."""
_cleanup_expired() _cleanup_expired()
token_hash = _hash_token(token) token_hash = _hash(token)
return token_hash in _active_tokens and _active_tokens[token_hash] > time.time() return token_hash in _active_tokens and _active_tokens[token_hash] > time.time()
async def authorize_endpoint(request: Request):
"""OAuth 2.0 Authorization endpoint.
GET /authorize?response_type=code&client_id=...&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...
Single-user setup: auto-approves if client_id matches, redirects with code.
"""
params = dict(request.query_params)
response_type = params.get("response_type")
client_id = params.get("client_id", "")
redirect_uri = params.get("redirect_uri", "")
code_challenge = params.get("code_challenge", "")
code_challenge_method = params.get("code_challenge_method", "")
state = params.get("state", "")
if response_type != "code":
return JSONResponse(
{"error": "unsupported_response_type"},
status_code=400,
)
try:
expected_id, _ = _get_oauth_credentials()
except RuntimeError:
return JSONResponse({"error": "server_error"}, status_code=500)
if not secrets.compare_digest(client_id, expected_id):
logger.warning(f"OAuth authorize: invalid client_id from {request.client.host}")
return JSONResponse(
{"error": "invalid_client", "error_description": "Unknown client_id"},
status_code=401,
)
if code_challenge_method and code_challenge_method != "S256":
return JSONResponse(
{"error": "invalid_request", "error_description": "Only S256 code_challenge_method supported"},
status_code=400,
)
# Auto-approve: generate authorization code
code = secrets.token_urlsafe(32)
_auth_codes[code] = {
"client_id": client_id,
"redirect_uri": redirect_uri,
"code_challenge": code_challenge,
"expires": time.time() + CODE_LIFETIME,
}
_cleanup_expired()
logger.info(f"OAuth authorization code issued for {client_id[:12]}... -> {redirect_uri}")
# Redirect back with code
redirect_params = {"code": code}
if state:
redirect_params["state"] = state
separator = "&" if "?" in redirect_uri else "?"
return RedirectResponse(
url=f"{redirect_uri}{separator}{urlencode(redirect_params)}",
status_code=302,
)
async def token_endpoint(request: Request) -> JSONResponse: async def token_endpoint(request: Request) -> JSONResponse:
"""OAuth 2.0 token endpoint (client_credentials grant). """OAuth 2.0 token endpoint.
POST /token Supports:
Content-Type: application/x-www-form-urlencoded - grant_type=authorization_code (with PKCE)
- grant_type=client_credentials (direct)
grant_type=client_credentials&client_id=...&client_secret=...
""" """
try: try:
# Accept both form-encoded and JSON
content_type = request.headers.get("content-type", "") content_type = request.headers.get("content-type", "")
if "application/json" in content_type: if "application/json" in content_type:
data = await request.json() data = await request.json()
@ -71,38 +145,92 @@ async def token_endpoint(request: Request) -> JSONResponse:
) )
grant_type = data.get("grant_type") grant_type = data.get("grant_type")
if grant_type != "client_credentials":
if grant_type == "authorization_code":
return await _handle_auth_code_grant(data, request)
elif grant_type == "client_credentials":
return await _handle_client_credentials_grant(data, request)
else:
return JSONResponse( return JSONResponse(
{"error": "unsupported_grant_type", "error_description": "Only client_credentials is supported"}, {"error": "unsupported_grant_type",
"error_description": "Supported: authorization_code, client_credentials"},
status_code=400, status_code=400,
) )
async def _handle_auth_code_grant(data: dict, request: Request) -> JSONResponse:
"""Exchange authorization code + PKCE verifier for access token."""
code = data.get("code", "")
client_id = data.get("client_id", "")
code_verifier = data.get("code_verifier", "")
redirect_uri = data.get("redirect_uri", "")
_cleanup_expired()
if code not in _auth_codes:
return JSONResponse(
{"error": "invalid_grant", "error_description": "Invalid or expired authorization code"},
status_code=400,
)
code_data = _auth_codes.pop(code)
# Verify client_id matches
if not secrets.compare_digest(client_id, code_data["client_id"]):
return JSONResponse({"error": "invalid_client"}, status_code=401)
# Verify redirect_uri matches
if redirect_uri and redirect_uri != code_data["redirect_uri"]:
return JSONResponse(
{"error": "invalid_grant", "error_description": "redirect_uri mismatch"},
status_code=400,
)
# Verify PKCE
if code_data["code_challenge"]:
if not code_verifier:
return JSONResponse(
{"error": "invalid_request", "error_description": "code_verifier required"},
status_code=400,
)
if not _verify_pkce(code_verifier, code_data["code_challenge"]):
return JSONResponse(
{"error": "invalid_grant", "error_description": "PKCE verification failed"},
status_code=400,
)
# Issue access token
access_token = secrets.token_urlsafe(48)
_active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME
logger.info(f"OAuth token issued via auth code to {request.client.host}")
return JSONResponse({
"access_token": access_token,
"token_type": "Bearer",
"expires_in": TOKEN_LIFETIME,
})
async def _handle_client_credentials_grant(data: dict, request: Request) -> JSONResponse:
"""Direct client_id + client_secret exchange for access token."""
client_id = data.get("client_id", "") client_id = data.get("client_id", "")
client_secret = data.get("client_secret", "") client_secret = data.get("client_secret", "")
try: try:
expected_id, expected_secret = _get_oauth_credentials() expected_id, expected_secret = _get_oauth_credentials()
except RuntimeError: except RuntimeError:
logger.error("OAuth credentials not configured") return JSONResponse({"error": "server_error"}, status_code=500)
return JSONResponse(
{"error": "server_error", "error_description": "Auth not configured"},
status_code=500,
)
if not secrets.compare_digest(client_id, expected_id) or \ if not secrets.compare_digest(client_id, expected_id) or \
not secrets.compare_digest(client_secret, expected_secret): not secrets.compare_digest(client_secret, expected_secret):
logger.warning(f"OAuth auth failed from {request.client.host}") logger.warning(f"OAuth client_credentials auth failed from {request.client.host}")
return JSONResponse( return JSONResponse({"error": "invalid_client"}, status_code=401)
{"error": "invalid_client", "error_description": "Invalid client credentials"},
status_code=401,
)
# Issue token
access_token = secrets.token_urlsafe(48) access_token = secrets.token_urlsafe(48)
_active_tokens[_hash_token(access_token)] = time.time() + TOKEN_LIFETIME _active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME
_cleanup_expired() _cleanup_expired()
logger.info(f"OAuth token issued to {request.client.host}") logger.info(f"OAuth token issued via client_credentials to {request.client.host}")
return JSONResponse({ return JSONResponse({
"access_token": access_token, "access_token": access_token,
"token_type": "Bearer", "token_type": "Bearer",
@ -111,18 +239,16 @@ async def token_endpoint(request: Request) -> JSONResponse:
async def oauth_metadata(request: Request) -> JSONResponse: async def oauth_metadata(request: Request) -> JSONResponse:
"""OAuth 2.0 Authorization Server Metadata (RFC 8414). """OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
GET /.well-known/oauth-authorization-server
"""
# Build base URL from request
base = str(request.base_url).rstrip("/") base = str(request.base_url).rstrip("/")
return JSONResponse({ return JSONResponse({
"issuer": base, "issuer": base,
"authorization_endpoint": f"{base}/authorize",
"token_endpoint": f"{base}/token", "token_endpoint": f"{base}/token",
"token_endpoint_auth_methods_supported": ["client_secret_post"], "token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
"grant_types_supported": ["client_credentials"], "grant_types_supported": ["authorization_code", "client_credentials"],
"response_types_supported": [], "response_types_supported": ["code"],
"code_challenge_methods_supported": ["S256"],
"scopes_supported": ["mcp"], "scopes_supported": ["mcp"],
}) })
@ -130,5 +256,6 @@ async def oauth_metadata(request: Request) -> JSONResponse:
# Routes to add to the app # Routes to add to the app
auth_routes = [ auth_routes = [
Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]), Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]),
Route("/authorize", authorize_endpoint, methods=["GET"]),
Route("/token", token_endpoint, methods=["POST"]), Route("/token", token_endpoint, methods=["POST"]),
] ]