diff --git a/mcp_bridge/__main__.py b/mcp_bridge/__main__.py index 4a90eec..7555f0b 100644 --- a/mcp_bridge/__main__.py +++ b/mcp_bridge/__main__.py @@ -24,6 +24,7 @@ logger = logging.getLogger("mcp_bridge") # Paths that don't require auth PUBLIC_PATHS = { "/.well-known/oauth-authorization-server", + "/authorize", "/token", "/api/health", "/api/ingest", # Local-only, not exposed via NPM diff --git a/mcp_bridge/auth.py b/mcp_bridge/auth.py index 648708f..928fe5f 100644 --- a/mcp_bridge/auth.py +++ b/mcp_bridge/auth.py @@ -1,22 +1,26 @@ -"""OAuth 2.0 client credentials auth for MCP server.""" +"""OAuth 2.0 authorization code + PKCE auth for MCP server.""" +import base64 import hashlib import json import logging import secrets import time +from urllib.parse import urlencode, urlparse, parse_qs from starlette.requests import Request -from starlette.responses import JSONResponse +from starlette.responses import JSONResponse, RedirectResponse from starlette.routing import Route from .config import load_credentials logger = logging.getLogger(__name__) -# In-memory token store: token_hash -> expiry timestamp -_active_tokens: dict[str, float] = {} +# In-memory stores +_active_tokens: dict[str, float] = {} # token_hash -> expiry +_auth_codes: dict[str, dict] = {} # code -> {client_id, redirect_uri, code_challenge, expires} TOKEN_LIFETIME = 3600 # 1 hour +CODE_LIFETIME = 300 # 5 minutes def _get_oauth_credentials() -> tuple[str, str]: @@ -29,35 +33,105 @@ def _get_oauth_credentials() -> tuple[str, str]: return client_id, client_secret -def _hash_token(token: str) -> str: - return hashlib.sha256(token.encode()).hexdigest() +def _hash(value: str) -> str: + return hashlib.sha256(value.encode()).hexdigest() def _cleanup_expired(): - """Remove expired tokens.""" + """Remove expired tokens and codes.""" now = time.time() - expired = [h for h, exp in _active_tokens.items() if exp < now] - for h in expired: - del _active_tokens[h] + for store in (_active_tokens, _auth_codes): + expired = [k for k, v in store.items() + if (v if isinstance(v, float) else v.get("expires", 0)) < now] + for k in expired: + del store[k] + + +def _verify_pkce(code_verifier: str, code_challenge: str) -> bool: + """Verify PKCE S256 challenge.""" + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + return secrets.compare_digest(computed, code_challenge) def validate_bearer_token(token: str) -> bool: """Check if a bearer token is valid and not expired.""" _cleanup_expired() - token_hash = _hash_token(token) + token_hash = _hash(token) return token_hash in _active_tokens and _active_tokens[token_hash] > time.time() +async def authorize_endpoint(request: Request): + """OAuth 2.0 Authorization endpoint. + + GET /authorize?response_type=code&client_id=...&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=... + + Single-user setup: auto-approves if client_id matches, redirects with code. + """ + params = dict(request.query_params) + + response_type = params.get("response_type") + client_id = params.get("client_id", "") + redirect_uri = params.get("redirect_uri", "") + code_challenge = params.get("code_challenge", "") + code_challenge_method = params.get("code_challenge_method", "") + state = params.get("state", "") + + if response_type != "code": + return JSONResponse( + {"error": "unsupported_response_type"}, + status_code=400, + ) + + try: + expected_id, _ = _get_oauth_credentials() + except RuntimeError: + return JSONResponse({"error": "server_error"}, status_code=500) + + if not secrets.compare_digest(client_id, expected_id): + logger.warning(f"OAuth authorize: invalid client_id from {request.client.host}") + return JSONResponse( + {"error": "invalid_client", "error_description": "Unknown client_id"}, + status_code=401, + ) + + if code_challenge_method and code_challenge_method != "S256": + return JSONResponse( + {"error": "invalid_request", "error_description": "Only S256 code_challenge_method supported"}, + status_code=400, + ) + + # Auto-approve: generate authorization code + code = secrets.token_urlsafe(32) + _auth_codes[code] = { + "client_id": client_id, + "redirect_uri": redirect_uri, + "code_challenge": code_challenge, + "expires": time.time() + CODE_LIFETIME, + } + _cleanup_expired() + + logger.info(f"OAuth authorization code issued for {client_id[:12]}... -> {redirect_uri}") + + # Redirect back with code + redirect_params = {"code": code} + if state: + redirect_params["state"] = state + separator = "&" if "?" in redirect_uri else "?" + return RedirectResponse( + url=f"{redirect_uri}{separator}{urlencode(redirect_params)}", + status_code=302, + ) + + async def token_endpoint(request: Request) -> JSONResponse: - """OAuth 2.0 token endpoint (client_credentials grant). + """OAuth 2.0 token endpoint. - POST /token - Content-Type: application/x-www-form-urlencoded - - grant_type=client_credentials&client_id=...&client_secret=... + Supports: + - grant_type=authorization_code (with PKCE) + - grant_type=client_credentials (direct) """ try: - # Accept both form-encoded and JSON content_type = request.headers.get("content-type", "") if "application/json" in content_type: data = await request.json() @@ -71,38 +145,92 @@ async def token_endpoint(request: Request) -> JSONResponse: ) grant_type = data.get("grant_type") - if grant_type != "client_credentials": + + if grant_type == "authorization_code": + return await _handle_auth_code_grant(data, request) + elif grant_type == "client_credentials": + return await _handle_client_credentials_grant(data, request) + else: return JSONResponse( - {"error": "unsupported_grant_type", "error_description": "Only client_credentials is supported"}, + {"error": "unsupported_grant_type", + "error_description": "Supported: authorization_code, client_credentials"}, status_code=400, ) + +async def _handle_auth_code_grant(data: dict, request: Request) -> JSONResponse: + """Exchange authorization code + PKCE verifier for access token.""" + code = data.get("code", "") + client_id = data.get("client_id", "") + code_verifier = data.get("code_verifier", "") + redirect_uri = data.get("redirect_uri", "") + + _cleanup_expired() + + if code not in _auth_codes: + return JSONResponse( + {"error": "invalid_grant", "error_description": "Invalid or expired authorization code"}, + status_code=400, + ) + + code_data = _auth_codes.pop(code) + + # Verify client_id matches + if not secrets.compare_digest(client_id, code_data["client_id"]): + return JSONResponse({"error": "invalid_client"}, status_code=401) + + # Verify redirect_uri matches + if redirect_uri and redirect_uri != code_data["redirect_uri"]: + return JSONResponse( + {"error": "invalid_grant", "error_description": "redirect_uri mismatch"}, + status_code=400, + ) + + # Verify PKCE + if code_data["code_challenge"]: + if not code_verifier: + return JSONResponse( + {"error": "invalid_request", "error_description": "code_verifier required"}, + status_code=400, + ) + if not _verify_pkce(code_verifier, code_data["code_challenge"]): + return JSONResponse( + {"error": "invalid_grant", "error_description": "PKCE verification failed"}, + status_code=400, + ) + + # Issue access token + access_token = secrets.token_urlsafe(48) + _active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME + + logger.info(f"OAuth token issued via auth code to {request.client.host}") + return JSONResponse({ + "access_token": access_token, + "token_type": "Bearer", + "expires_in": TOKEN_LIFETIME, + }) + + +async def _handle_client_credentials_grant(data: dict, request: Request) -> JSONResponse: + """Direct client_id + client_secret exchange for access token.""" client_id = data.get("client_id", "") client_secret = data.get("client_secret", "") try: expected_id, expected_secret = _get_oauth_credentials() except RuntimeError: - logger.error("OAuth credentials not configured") - return JSONResponse( - {"error": "server_error", "error_description": "Auth not configured"}, - status_code=500, - ) + return JSONResponse({"error": "server_error"}, status_code=500) if not secrets.compare_digest(client_id, expected_id) or \ not secrets.compare_digest(client_secret, expected_secret): - logger.warning(f"OAuth auth failed from {request.client.host}") - return JSONResponse( - {"error": "invalid_client", "error_description": "Invalid client credentials"}, - status_code=401, - ) + logger.warning(f"OAuth client_credentials auth failed from {request.client.host}") + return JSONResponse({"error": "invalid_client"}, status_code=401) - # Issue token access_token = secrets.token_urlsafe(48) - _active_tokens[_hash_token(access_token)] = time.time() + TOKEN_LIFETIME + _active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME _cleanup_expired() - logger.info(f"OAuth token issued to {request.client.host}") + logger.info(f"OAuth token issued via client_credentials to {request.client.host}") return JSONResponse({ "access_token": access_token, "token_type": "Bearer", @@ -111,18 +239,16 @@ async def token_endpoint(request: Request) -> JSONResponse: async def oauth_metadata(request: Request) -> JSONResponse: - """OAuth 2.0 Authorization Server Metadata (RFC 8414). - - GET /.well-known/oauth-authorization-server - """ - # Build base URL from request + """OAuth 2.0 Authorization Server Metadata (RFC 8414).""" base = str(request.base_url).rstrip("/") return JSONResponse({ "issuer": base, + "authorization_endpoint": f"{base}/authorize", "token_endpoint": f"{base}/token", - "token_endpoint_auth_methods_supported": ["client_secret_post"], - "grant_types_supported": ["client_credentials"], - "response_types_supported": [], + "token_endpoint_auth_methods_supported": ["client_secret_post", "none"], + "grant_types_supported": ["authorization_code", "client_credentials"], + "response_types_supported": ["code"], + "code_challenge_methods_supported": ["S256"], "scopes_supported": ["mcp"], }) @@ -130,5 +256,6 @@ async def oauth_metadata(request: Request) -> JSONResponse: # Routes to add to the app auth_routes = [ Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]), + Route("/authorize", authorize_endpoint, methods=["GET"]), Route("/token", token_endpoint, methods=["POST"]), ]