"""OAuth 2.0 authorization code + PKCE auth for MCP server.""" import base64 import hashlib import json import logging import secrets import time from urllib.parse import urlencode, urlparse, parse_qs from starlette.requests import Request from starlette.responses import JSONResponse, RedirectResponse from starlette.routing import Route from .config import load_credentials logger = logging.getLogger(__name__) # In-memory stores _active_tokens: dict[str, float] = {} # token_hash -> expiry _auth_codes: dict[str, dict] = {} # code -> {client_id, redirect_uri, code_challenge, expires} TOKEN_LIFETIME = 3600 # 1 hour CODE_LIFETIME = 300 # 5 minutes def _get_oauth_credentials() -> tuple[str, str]: """Load OAuth client_id and client_secret from credentials file.""" creds = load_credentials() client_id = creds.get("OAUTH_CLIENT_ID", "") client_secret = creds.get("OAUTH_CLIENT_SECRET", "") if not client_id or not client_secret: raise RuntimeError("OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET must be set in credentials") return client_id, client_secret def _hash(value: str) -> str: return hashlib.sha256(value.encode()).hexdigest() def _cleanup_expired(): """Remove expired tokens and codes.""" now = time.time() for store in (_active_tokens, _auth_codes): expired = [k for k, v in store.items() if (v if isinstance(v, float) else v.get("expires", 0)) < now] for k in expired: del store[k] def _verify_pkce(code_verifier: str, code_challenge: str) -> bool: """Verify PKCE S256 challenge.""" digest = hashlib.sha256(code_verifier.encode("ascii")).digest() computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") return secrets.compare_digest(computed, code_challenge) def validate_bearer_token(token: str) -> bool: """Check if a bearer token is valid and not expired.""" _cleanup_expired() token_hash = _hash(token) return token_hash in _active_tokens and _active_tokens[token_hash] > time.time() async def authorize_endpoint(request: Request): """OAuth 2.0 Authorization endpoint. GET /authorize?response_type=code&client_id=...&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=... Single-user setup: auto-approves if client_id matches, redirects with code. """ params = dict(request.query_params) response_type = params.get("response_type") client_id = params.get("client_id", "") redirect_uri = params.get("redirect_uri", "") code_challenge = params.get("code_challenge", "") code_challenge_method = params.get("code_challenge_method", "") state = params.get("state", "") if response_type != "code": return JSONResponse( {"error": "unsupported_response_type"}, status_code=400, ) try: expected_id, _ = _get_oauth_credentials() except RuntimeError: return JSONResponse({"error": "server_error"}, status_code=500) if not secrets.compare_digest(client_id, expected_id): logger.warning(f"OAuth authorize: invalid client_id from {request.client.host}") return JSONResponse( {"error": "invalid_client", "error_description": "Unknown client_id"}, status_code=401, ) if code_challenge_method and code_challenge_method != "S256": return JSONResponse( {"error": "invalid_request", "error_description": "Only S256 code_challenge_method supported"}, status_code=400, ) # Auto-approve: generate authorization code code = secrets.token_urlsafe(32) _auth_codes[code] = { "client_id": client_id, "redirect_uri": redirect_uri, "code_challenge": code_challenge, "expires": time.time() + CODE_LIFETIME, } _cleanup_expired() logger.info(f"OAuth authorization code issued for {client_id[:12]}... -> {redirect_uri}") # Redirect back with code redirect_params = {"code": code} if state: redirect_params["state"] = state separator = "&" if "?" in redirect_uri else "?" return RedirectResponse( url=f"{redirect_uri}{separator}{urlencode(redirect_params)}", status_code=302, ) async def token_endpoint(request: Request) -> JSONResponse: """OAuth 2.0 token endpoint. Supports: - grant_type=authorization_code (with PKCE) - grant_type=client_credentials (direct) """ # Parse request data from any source data = {} # Always include query params data.update(dict(request.query_params)) # Parse body for POST if request.method == "POST": try: raw_body = await request.body() logger.info(f"TOKEN POST body ({len(raw_body)} bytes): {raw_body[:500]}") content_type = request.headers.get("content-type", "") if "application/json" in content_type and raw_body: data.update(json.loads(raw_body)) elif raw_body: from urllib.parse import parse_qs parsed = parse_qs(raw_body.decode(), keep_blank_values=True) data.update({k: v[0] if len(v) == 1 else v for k, v in parsed.items()}) except Exception as e: logger.warning(f"TOKEN body parse error: {e}") logger.info(f"TOKEN {request.method} data_keys={list(data.keys())} " f"grant_type={data.get('grant_type', '(missing)')}") # Bare GET with no params = probe/ping, return 200 with metadata if request.method == "GET" and not data: return JSONResponse({ "grant_types_supported": ["authorization_code", "client_credentials"], "token_endpoint_auth_methods_supported": ["client_secret_post", "none"], "code_challenge_methods_supported": ["S256"], }) grant_type = data.get("grant_type", "") logger.debug(f"Token request: method={request.method} grant_type={grant_type} params={list(data.keys())}") # If there's a code param but no grant_type, assume authorization_code if not grant_type and "code" in data: grant_type = "authorization_code" if grant_type == "authorization_code": return await _handle_auth_code_grant(data, request) elif grant_type == "client_credentials": return await _handle_client_credentials_grant(data, request) else: logger.warning(f"Token request with unsupported grant_type={grant_type!r}, data keys={list(data.keys())}") return JSONResponse( {"error": "unsupported_grant_type", "error_description": f"Supported: authorization_code, client_credentials. Got: {grant_type!r}"}, status_code=400, ) async def _handle_auth_code_grant(data: dict, request: Request) -> JSONResponse: """Exchange authorization code + PKCE verifier for access token.""" code = data.get("code", "") client_id = data.get("client_id", "") code_verifier = data.get("code_verifier", "") redirect_uri = data.get("redirect_uri", "") _cleanup_expired() if code not in _auth_codes: return JSONResponse( {"error": "invalid_grant", "error_description": "Invalid or expired authorization code"}, status_code=400, ) code_data = _auth_codes.pop(code) # Verify client_id matches if not secrets.compare_digest(client_id, code_data["client_id"]): return JSONResponse({"error": "invalid_client"}, status_code=401) # Verify redirect_uri matches if redirect_uri and redirect_uri != code_data["redirect_uri"]: return JSONResponse( {"error": "invalid_grant", "error_description": "redirect_uri mismatch"}, status_code=400, ) # Verify PKCE if code_data["code_challenge"]: if not code_verifier: return JSONResponse( {"error": "invalid_request", "error_description": "code_verifier required"}, status_code=400, ) if not _verify_pkce(code_verifier, code_data["code_challenge"]): return JSONResponse( {"error": "invalid_grant", "error_description": "PKCE verification failed"}, status_code=400, ) # Issue access token access_token = secrets.token_urlsafe(48) _active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME logger.info(f"OAuth token issued via auth code to {request.client.host}") return JSONResponse({ "access_token": access_token, "token_type": "Bearer", "expires_in": TOKEN_LIFETIME, }) async def _handle_client_credentials_grant(data: dict, request: Request) -> JSONResponse: """Direct client_id + client_secret exchange for access token.""" client_id = data.get("client_id", "") client_secret = data.get("client_secret", "") try: expected_id, expected_secret = _get_oauth_credentials() except RuntimeError: return JSONResponse({"error": "server_error"}, status_code=500) if not secrets.compare_digest(client_id, expected_id) or \ not secrets.compare_digest(client_secret, expected_secret): logger.warning(f"OAuth client_credentials auth failed from {request.client.host}") return JSONResponse({"error": "invalid_client"}, status_code=401) access_token = secrets.token_urlsafe(48) _active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME _cleanup_expired() logger.info(f"OAuth token issued via client_credentials to {request.client.host}") return JSONResponse({ "access_token": access_token, "token_type": "Bearer", "expires_in": TOKEN_LIFETIME, }) async def oauth_metadata(request: Request) -> JSONResponse: """OAuth 2.0 Authorization Server Metadata (RFC 8414).""" base = str(request.base_url).rstrip("/") return JSONResponse({ "issuer": base, "authorization_endpoint": f"{base}/authorize", "token_endpoint": f"{base}/token", "registration_endpoint": f"{base}/register", "token_endpoint_auth_methods_supported": ["none"], "grant_types_supported": ["authorization_code"], "response_types_supported": ["code"], "code_challenge_methods_supported": ["S256"], "scopes_supported": ["mcp"], }) async def protected_resource_metadata(request: Request) -> JSONResponse: """OAuth 2.0 Protected Resource Metadata (RFC 9728).""" base = str(request.base_url).rstrip("/") return JSONResponse({ "resource": base, "authorization_servers": [base], "scopes_supported": ["mcp"], "bearer_methods_supported": ["header"], }) # Routes to add to the app async def register_endpoint(request: Request) -> JSONResponse: """OAuth 2.0 Dynamic Client Registration (stub). Claude may try to register dynamically. Since we use pre-configured credentials, just return the existing client info. """ try: client_id, _ = _get_oauth_credentials() except RuntimeError: return JSONResponse({"error": "server_error"}, status_code=500) base = str(request.base_url).rstrip("/") logger.info(f"REGISTER endpoint called: method={request.method}") if request.method == "POST": body = await request.body() logger.info(f"REGISTER body: {body[:500]}") return JSONResponse({ "client_id": client_id, "client_name": "claude-desktop", "redirect_uris": ["https://claude.ai/api/mcp/auth_callback"], "grant_types": ["authorization_code"], "response_types": ["code"], "token_endpoint_auth_method": "none", }) auth_routes = [ Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]), Route("/.well-known/oauth-protected-resource", protected_resource_metadata, methods=["GET"]), Route("/.well-known/oauth-protected-resource/mcp", protected_resource_metadata, methods=["GET"]), Route("/authorize", authorize_endpoint, methods=["GET"]), Route("/token", token_endpoint, methods=["GET", "POST"]), Route("/register", register_endpoint, methods=["GET", "POST"]), ]