From 1dff4630fe9ed85dffdb4e70e2e1349fb5ef1e49 Mon Sep 17 00:00:00 2001 From: Mikkel Georgsen Date: Mon, 30 Mar 2026 11:32:01 +0000 Subject: [PATCH] fix: use FastMCP's InMemoryOAuthProvider instead of custom implementation Replaced hand-rolled OAuth with FastMCP's battle-tested InMemoryOAuthProvider. Handles DCR, PKCE, token exchange, refresh tokens, and revocation out of the box. Co-Authored-By: Claude Opus 4.6 (1M context) --- mcp_bridge/auth.py | 329 --------------------------------------- mcp_bridge/mcp_server.py | 163 +------------------ 2 files changed, 5 insertions(+), 487 deletions(-) delete mode 100644 mcp_bridge/auth.py diff --git a/mcp_bridge/auth.py b/mcp_bridge/auth.py deleted file mode 100644 index 1342ebb..0000000 --- a/mcp_bridge/auth.py +++ /dev/null @@ -1,329 +0,0 @@ -"""OAuth 2.0 authorization code + PKCE auth for MCP server.""" - -import base64 -import hashlib -import json -import logging -import secrets -import time -from urllib.parse import urlencode, urlparse, parse_qs - -from starlette.requests import Request -from starlette.responses import JSONResponse, RedirectResponse -from starlette.routing import Route - -from .config import load_credentials - -logger = logging.getLogger(__name__) - -# In-memory stores -_active_tokens: dict[str, float] = {} # token_hash -> expiry -_auth_codes: dict[str, dict] = {} # code -> {client_id, redirect_uri, code_challenge, expires} -TOKEN_LIFETIME = 3600 # 1 hour -CODE_LIFETIME = 300 # 5 minutes - - -def _get_oauth_credentials() -> tuple[str, str]: - """Load OAuth client_id and client_secret from credentials file.""" - creds = load_credentials() - client_id = creds.get("OAUTH_CLIENT_ID", "") - client_secret = creds.get("OAUTH_CLIENT_SECRET", "") - if not client_id or not client_secret: - raise RuntimeError("OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET must be set in credentials") - return client_id, client_secret - - -def _hash(value: str) -> str: - return hashlib.sha256(value.encode()).hexdigest() - - -def _cleanup_expired(): - """Remove expired tokens and codes.""" - now = time.time() - for store in (_active_tokens, _auth_codes): - expired = [k for k, v in store.items() - if (v if isinstance(v, float) else v.get("expires", 0)) < now] - for k in expired: - del store[k] - - -def _verify_pkce(code_verifier: str, code_challenge: str) -> bool: - """Verify PKCE S256 challenge.""" - digest = hashlib.sha256(code_verifier.encode("ascii")).digest() - computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") - return secrets.compare_digest(computed, code_challenge) - - -def validate_bearer_token(token: str) -> bool: - """Check if a bearer token is valid and not expired.""" - _cleanup_expired() - token_hash = _hash(token) - return token_hash in _active_tokens and _active_tokens[token_hash] > time.time() - - -async def authorize_endpoint(request: Request): - """OAuth 2.0 Authorization endpoint. - - GET /authorize?response_type=code&client_id=...&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=... - - Single-user setup: auto-approves if client_id matches, redirects with code. - """ - params = dict(request.query_params) - - response_type = params.get("response_type") - client_id = params.get("client_id", "") - redirect_uri = params.get("redirect_uri", "") - code_challenge = params.get("code_challenge", "") - code_challenge_method = params.get("code_challenge_method", "") - state = params.get("state", "") - - if response_type != "code": - return JSONResponse( - {"error": "unsupported_response_type"}, - status_code=400, - ) - - try: - expected_id, _ = _get_oauth_credentials() - except RuntimeError: - return JSONResponse({"error": "server_error"}, status_code=500) - - if not secrets.compare_digest(client_id, expected_id): - logger.warning(f"OAuth authorize: invalid client_id from {request.client.host}") - return JSONResponse( - {"error": "invalid_client", "error_description": "Unknown client_id"}, - status_code=401, - ) - - if code_challenge_method and code_challenge_method != "S256": - return JSONResponse( - {"error": "invalid_request", "error_description": "Only S256 code_challenge_method supported"}, - status_code=400, - ) - - # Auto-approve: generate authorization code - code = secrets.token_urlsafe(32) - _auth_codes[code] = { - "client_id": client_id, - "redirect_uri": redirect_uri, - "code_challenge": code_challenge, - "expires": time.time() + CODE_LIFETIME, - } - _cleanup_expired() - - logger.info(f"OAuth authorization code issued for {client_id[:12]}... -> {redirect_uri}") - - # Redirect back with code - redirect_params = {"code": code} - if state: - redirect_params["state"] = state - separator = "&" if "?" in redirect_uri else "?" - return RedirectResponse( - url=f"{redirect_uri}{separator}{urlencode(redirect_params)}", - status_code=302, - ) - - -async def token_endpoint(request: Request) -> JSONResponse: - """OAuth 2.0 token endpoint. - - Supports: - - grant_type=authorization_code (with PKCE) - - grant_type=client_credentials (direct) - """ - # Parse request data from any source - data = {} - - # Always include query params - data.update(dict(request.query_params)) - - # Parse body for POST - if request.method == "POST": - try: - raw_body = await request.body() - logger.info(f"TOKEN POST body ({len(raw_body)} bytes): {raw_body[:500]}") - content_type = request.headers.get("content-type", "") - if "application/json" in content_type and raw_body: - data.update(json.loads(raw_body)) - elif raw_body: - from urllib.parse import parse_qs - parsed = parse_qs(raw_body.decode(), keep_blank_values=True) - data.update({k: v[0] if len(v) == 1 else v for k, v in parsed.items()}) - except Exception as e: - logger.warning(f"TOKEN body parse error: {e}") - - logger.info(f"TOKEN {request.method} data_keys={list(data.keys())} " - f"grant_type={data.get('grant_type', '(missing)')}") - - # Bare GET with no params = probe/ping, return 200 with metadata - if request.method == "GET" and not data: - return JSONResponse({ - "grant_types_supported": ["authorization_code", "client_credentials"], - "token_endpoint_auth_methods_supported": ["client_secret_post", "none"], - "code_challenge_methods_supported": ["S256"], - }) - - grant_type = data.get("grant_type", "") - - logger.debug(f"Token request: method={request.method} grant_type={grant_type} params={list(data.keys())}") - - # If there's a code param but no grant_type, assume authorization_code - if not grant_type and "code" in data: - grant_type = "authorization_code" - - if grant_type == "authorization_code": - return await _handle_auth_code_grant(data, request) - elif grant_type == "client_credentials": - return await _handle_client_credentials_grant(data, request) - else: - logger.warning(f"Token request with unsupported grant_type={grant_type!r}, data keys={list(data.keys())}") - return JSONResponse( - {"error": "unsupported_grant_type", - "error_description": f"Supported: authorization_code, client_credentials. Got: {grant_type!r}"}, - status_code=400, - ) - - -async def _handle_auth_code_grant(data: dict, request: Request) -> JSONResponse: - """Exchange authorization code + PKCE verifier for access token.""" - code = data.get("code", "") - client_id = data.get("client_id", "") - code_verifier = data.get("code_verifier", "") - redirect_uri = data.get("redirect_uri", "") - - _cleanup_expired() - - if code not in _auth_codes: - return JSONResponse( - {"error": "invalid_grant", "error_description": "Invalid or expired authorization code"}, - status_code=400, - ) - - code_data = _auth_codes.pop(code) - - # Verify client_id matches - if not secrets.compare_digest(client_id, code_data["client_id"]): - return JSONResponse({"error": "invalid_client"}, status_code=401) - - # Verify redirect_uri matches - if redirect_uri and redirect_uri != code_data["redirect_uri"]: - return JSONResponse( - {"error": "invalid_grant", "error_description": "redirect_uri mismatch"}, - status_code=400, - ) - - # Verify PKCE - if code_data["code_challenge"]: - if not code_verifier: - return JSONResponse( - {"error": "invalid_request", "error_description": "code_verifier required"}, - status_code=400, - ) - if not _verify_pkce(code_verifier, code_data["code_challenge"]): - return JSONResponse( - {"error": "invalid_grant", "error_description": "PKCE verification failed"}, - status_code=400, - ) - - # Issue access token - access_token = secrets.token_urlsafe(48) - _active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME - - logger.info(f"OAuth token issued via auth code to {request.client.host}") - return JSONResponse({ - "access_token": access_token, - "token_type": "Bearer", - "expires_in": TOKEN_LIFETIME, - }) - - -async def _handle_client_credentials_grant(data: dict, request: Request) -> JSONResponse: - """Direct client_id + client_secret exchange for access token.""" - client_id = data.get("client_id", "") - client_secret = data.get("client_secret", "") - - try: - expected_id, expected_secret = _get_oauth_credentials() - except RuntimeError: - return JSONResponse({"error": "server_error"}, status_code=500) - - if not secrets.compare_digest(client_id, expected_id) or \ - not secrets.compare_digest(client_secret, expected_secret): - logger.warning(f"OAuth client_credentials auth failed from {request.client.host}") - return JSONResponse({"error": "invalid_client"}, status_code=401) - - access_token = secrets.token_urlsafe(48) - _active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME - _cleanup_expired() - - logger.info(f"OAuth token issued via client_credentials to {request.client.host}") - return JSONResponse({ - "access_token": access_token, - "token_type": "Bearer", - "expires_in": TOKEN_LIFETIME, - }) - - -async def oauth_metadata(request: Request) -> JSONResponse: - """OAuth 2.0 Authorization Server Metadata (RFC 8414).""" - base = str(request.base_url).rstrip("/") - return JSONResponse({ - "issuer": base, - "authorization_endpoint": f"{base}/authorize", - "token_endpoint": f"{base}/token", - "registration_endpoint": f"{base}/register", - "token_endpoint_auth_methods_supported": ["none"], - "grant_types_supported": ["authorization_code"], - "response_types_supported": ["code"], - "code_challenge_methods_supported": ["S256"], - "scopes_supported": ["mcp"], - }) - - -async def protected_resource_metadata(request: Request) -> JSONResponse: - """OAuth 2.0 Protected Resource Metadata (RFC 9728).""" - base = str(request.base_url).rstrip("/") - return JSONResponse({ - "resource": base, - "authorization_servers": [base], - "scopes_supported": ["mcp"], - "bearer_methods_supported": ["header"], - }) - - -# Routes to add to the app -async def register_endpoint(request: Request) -> JSONResponse: - """OAuth 2.0 Dynamic Client Registration (stub). - - Claude may try to register dynamically. Since we use pre-configured - credentials, just return the existing client info. - """ - try: - client_id, _ = _get_oauth_credentials() - except RuntimeError: - return JSONResponse({"error": "server_error"}, status_code=500) - - base = str(request.base_url).rstrip("/") - logger.info(f"REGISTER endpoint called: method={request.method}") - if request.method == "POST": - body = await request.body() - logger.info(f"REGISTER body: {body[:500]}") - - return JSONResponse({ - "client_id": client_id, - "client_name": "claude-desktop", - "redirect_uris": ["https://claude.ai/api/mcp/auth_callback"], - "grant_types": ["authorization_code"], - "response_types": ["code"], - "token_endpoint_auth_method": "none", - }) - - -auth_routes = [ - Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]), - Route("/.well-known/oauth-protected-resource", protected_resource_metadata, methods=["GET"]), - Route("/.well-known/oauth-protected-resource/mcp", protected_resource_metadata, methods=["GET"]), - Route("/authorize", authorize_endpoint, methods=["GET"]), - Route("/token", token_endpoint, methods=["GET", "POST"]), - Route("/register", register_endpoint, methods=["GET", "POST"]), -] diff --git a/mcp_bridge/mcp_server.py b/mcp_bridge/mcp_server.py index e50382d..3326229 100644 --- a/mcp_bridge/mcp_server.py +++ b/mcp_bridge/mcp_server.py @@ -5,169 +5,20 @@ import logging from datetime import datetime, timezone from fastmcp import FastMCP -import secrets -import time -import hashlib -from urllib.parse import urlencode - -from fastmcp.server.auth import OAuthProvider, AccessToken +from fastmcp.server.auth.providers.in_memory import InMemoryOAuthProvider from fastmcp.server.auth.auth import ClientRegistrationOptions -from mcp.server.auth.provider import AuthorizationParams -from mcp.shared.auth import OAuthClientInformationFull, OAuthToken from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route from .db import Database -from .config import get_group_chat_id, MCP_HOST, MCP_PORT +from .config import get_group_chat_id logger = logging.getLogger(__name__) -# Will be initialized in __main__ with shared db instance db: Database | None = None -TOKEN_LIFETIME = 3600 - - -class HomelabOAuth(OAuthProvider): - """Concrete OAuth provider with in-memory storage.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._clients: dict[str, OAuthClientInformationFull] = {} - self._auth_codes: dict[str, dict] = {} # code -> {client_id, code_challenge, redirect_uri, expires} - self._tokens: dict[str, dict] = {} # token_hash -> {client_id, scopes, expires} - self._refresh_tokens: dict[str, dict] = {} # refresh_hash -> {client_id, scopes} - - async def register_client(self, client_info: OAuthClientInformationFull) -> None: - self._clients[client_info.client_id] = client_info - logger.info(f"Registered OAuth client: {client_info.client_id} ({client_info.client_name})") - - async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: - return self._clients.get(client_id) - - async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: - """Auto-approve and redirect back with auth code (single-user setup).""" - code = await self.create_authorization_code( - client, params.code_challenge, str(params.redirect_uri), params.scopes, - ) - redirect_params = {"code": code} - if params.state: - redirect_params["state"] = params.state - redirect_uri = str(params.redirect_uri) - separator = "&" if "?" in redirect_uri else "?" - return f"{redirect_uri}{separator}{urlencode(redirect_params)}" - - async def create_authorization_code( - self, client: OAuthClientInformationFull, code_challenge: str | None, - redirect_uri: str | None, scopes: list[str] | None = None, - **kwargs, - ) -> str: - code = secrets.token_urlsafe(32) - self._auth_codes[code] = { - "client_id": client.client_id, - "code_challenge": code_challenge, - "redirect_uri": redirect_uri or str(client.redirect_uris[0]), - "scopes": scopes or [], - "expires": time.time() + 300, - } - logger.info(f"Auth code issued for client {client.client_id}") - return code - - async def exchange_authorization_code( - self, client: OAuthClientInformationFull, code: str, - code_verifier: str | None = None, redirect_uri: str | None = None, - **kwargs, - ) -> OAuthToken: - if code not in self._auth_codes: - raise ValueError("Invalid authorization code") - - code_data = self._auth_codes.pop(code) - - if code_data["expires"] < time.time(): - raise ValueError("Authorization code expired") - - if code_data["client_id"] != client.client_id: - raise ValueError("Client ID mismatch") - - # PKCE verification - if code_data["code_challenge"] and code_verifier: - import base64 - digest = hashlib.sha256(code_verifier.encode("ascii")).digest() - computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") - if computed != code_data["code_challenge"]: - raise ValueError("PKCE verification failed") - - # Issue tokens - access_token = secrets.token_urlsafe(48) - refresh_token = secrets.token_urlsafe(48) - - self._tokens[hashlib.sha256(access_token.encode()).hexdigest()] = { - "client_id": client.client_id, - "scopes": code_data["scopes"], - "expires": time.time() + TOKEN_LIFETIME, - } - self._refresh_tokens[hashlib.sha256(refresh_token.encode()).hexdigest()] = { - "client_id": client.client_id, - "scopes": code_data["scopes"], - } - - logger.info(f"Token issued for client {client.client_id}") - return OAuthToken( - access_token=access_token, - token_type="Bearer", - expires_in=TOKEN_LIFETIME, - refresh_token=refresh_token, - scope=" ".join(code_data["scopes"]) if code_data["scopes"] else None, - ) - - async def load_access_token(self, token: str) -> AccessToken | None: - token_hash = hashlib.sha256(token.encode()).hexdigest() - data = self._tokens.get(token_hash) - if not data or data["expires"] < time.time(): - return None - return AccessToken( - token=token, - client_id=data["client_id"], - scopes=data["scopes"], - ) - - async def exchange_refresh_token( - self, client: OAuthClientInformationFull, - refresh_token: str, scopes: list[str] | None = None, - **kwargs, - ) -> OAuthToken: - refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest() - data = self._refresh_tokens.get(refresh_hash) - if not data or data["client_id"] != client.client_id: - raise ValueError("Invalid refresh token") - - # Issue new tokens - new_access = secrets.token_urlsafe(48) - new_refresh = secrets.token_urlsafe(48) - - self._tokens[hashlib.sha256(new_access.encode()).hexdigest()] = { - "client_id": client.client_id, - "scopes": scopes or data["scopes"], - "expires": time.time() + TOKEN_LIFETIME, - } - # Rotate refresh token - del self._refresh_tokens[refresh_hash] - self._refresh_tokens[hashlib.sha256(new_refresh.encode()).hexdigest()] = { - "client_id": client.client_id, - "scopes": scopes or data["scopes"], - } - - return OAuthToken( - access_token=new_access, - token_type="Bearer", - expires_in=TOKEN_LIFETIME, - refresh_token=new_refresh, - scope=" ".join(scopes or data["scopes"]) if (scopes or data["scopes"]) else None, - ) - - -oauth = HomelabOAuth( +oauth = InMemoryOAuthProvider( base_url="https://mcp.georgsen.dk", client_registration_options=ClientRegistrationOptions(enabled=True), ) @@ -185,7 +36,6 @@ mcp = FastMCP( def init(database: Database): - """Set the shared database instance.""" global db db = database @@ -225,7 +75,6 @@ def pull_updates(since_id: int = 0, since: str | None = None, limit: int = 50) - else: messages = db.get_messages_since_id(since_id, limit) - # Enrich with attachment info for msg in messages: if msg["has_attachment"]: msg["attachments"] = db.get_attachments_for_message(msg["id"]) @@ -252,7 +101,7 @@ def queue_status() -> str: return json.dumps(status) -# Custom non-MCP routes (no auth required) +# Custom non-MCP routes (no auth required - local access only) async def ingest_message(request: Request) -> JSONResponse: """HTTP endpoint for local services to log messages into the bridge.""" try: @@ -285,9 +134,7 @@ async def ingest_message(request: Request) -> JSONResponse: if msg_id is None: return JSONResponse({"ok": True, "duplicate": True}) - logger.info( - f"Ingested message {telegram_message_id} from {data.get('sender_name', 'unknown')}" - ) + logger.info(f"Ingested message {telegram_message_id} from {data.get('sender_name', 'unknown')}") return JSONResponse({"ok": True, "id": msg_id})