From 50867163874464d0f4e28a7006ff519ec7e62104 Mon Sep 17 00:00:00 2001 From: Mikkel Georgsen Date: Mon, 30 Mar 2026 11:22:20 +0000 Subject: [PATCH] feat: implement concrete OAuth provider with in-memory storage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit OAuthProvider is abstract — subclassed as HomelabOAuth with full implementation of register_client, get_client, create/exchange authorization codes, token issuance, PKCE verification, and refresh token rotation. Co-Authored-By: Claude Opus 4.6 (1M context) --- mcp_bridge/mcp_server.py | 139 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 136 insertions(+), 3 deletions(-) diff --git a/mcp_bridge/mcp_server.py b/mcp_bridge/mcp_server.py index 84092d9..3ba13d4 100644 --- a/mcp_bridge/mcp_server.py +++ b/mcp_bridge/mcp_server.py @@ -5,8 +5,13 @@ import logging from datetime import datetime, timezone from fastmcp import FastMCP -from fastmcp.server.auth import OAuthProvider +import secrets +import time +import hashlib + +from fastmcp.server.auth import OAuthProvider, AccessToken from fastmcp.server.auth.auth import ClientRegistrationOptions +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route @@ -19,8 +24,136 @@ logger = logging.getLogger(__name__) # Will be initialized in __main__ with shared db instance db: Database | None = None -# OAuth provider with Dynamic Client Registration enabled -oauth = OAuthProvider( +TOKEN_LIFETIME = 3600 + + +class HomelabOAuth(OAuthProvider): + """Concrete OAuth provider with in-memory storage.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._clients: dict[str, OAuthClientInformationFull] = {} + self._auth_codes: dict[str, dict] = {} # code -> {client_id, code_challenge, redirect_uri, expires} + self._tokens: dict[str, dict] = {} # token_hash -> {client_id, scopes, expires} + self._refresh_tokens: dict[str, dict] = {} # refresh_hash -> {client_id, scopes} + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + self._clients[client_info.client_id] = client_info + logger.info(f"Registered OAuth client: {client_info.client_id} ({client_info.client_name})") + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + return self._clients.get(client_id) + + async def create_authorization_code( + self, client: OAuthClientInformationFull, code_challenge: str | None, + redirect_uri: str | None, scopes: list[str] | None = None, + **kwargs, + ) -> str: + code = secrets.token_urlsafe(32) + self._auth_codes[code] = { + "client_id": client.client_id, + "code_challenge": code_challenge, + "redirect_uri": redirect_uri or str(client.redirect_uris[0]), + "scopes": scopes or [], + "expires": time.time() + 300, + } + logger.info(f"Auth code issued for client {client.client_id}") + return code + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, code: str, + code_verifier: str | None = None, redirect_uri: str | None = None, + **kwargs, + ) -> OAuthToken: + if code not in self._auth_codes: + raise ValueError("Invalid authorization code") + + code_data = self._auth_codes.pop(code) + + if code_data["expires"] < time.time(): + raise ValueError("Authorization code expired") + + if code_data["client_id"] != client.client_id: + raise ValueError("Client ID mismatch") + + # PKCE verification + if code_data["code_challenge"] and code_verifier: + import base64 + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + if computed != code_data["code_challenge"]: + raise ValueError("PKCE verification failed") + + # Issue tokens + access_token = secrets.token_urlsafe(48) + refresh_token = secrets.token_urlsafe(48) + + self._tokens[hashlib.sha256(access_token.encode()).hexdigest()] = { + "client_id": client.client_id, + "scopes": code_data["scopes"], + "expires": time.time() + TOKEN_LIFETIME, + } + self._refresh_tokens[hashlib.sha256(refresh_token.encode()).hexdigest()] = { + "client_id": client.client_id, + "scopes": code_data["scopes"], + } + + logger.info(f"Token issued for client {client.client_id}") + return OAuthToken( + access_token=access_token, + token_type="Bearer", + expires_in=TOKEN_LIFETIME, + refresh_token=refresh_token, + scope=" ".join(code_data["scopes"]) if code_data["scopes"] else None, + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + token_hash = hashlib.sha256(token.encode()).hexdigest() + data = self._tokens.get(token_hash) + if not data or data["expires"] < time.time(): + return None + return AccessToken( + token=token, + client_id=data["client_id"], + scopes=data["scopes"], + ) + + async def exchange_refresh_token( + self, client: OAuthClientInformationFull, + refresh_token: str, scopes: list[str] | None = None, + **kwargs, + ) -> OAuthToken: + refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest() + data = self._refresh_tokens.get(refresh_hash) + if not data or data["client_id"] != client.client_id: + raise ValueError("Invalid refresh token") + + # Issue new tokens + new_access = secrets.token_urlsafe(48) + new_refresh = secrets.token_urlsafe(48) + + self._tokens[hashlib.sha256(new_access.encode()).hexdigest()] = { + "client_id": client.client_id, + "scopes": scopes or data["scopes"], + "expires": time.time() + TOKEN_LIFETIME, + } + # Rotate refresh token + del self._refresh_tokens[refresh_hash] + self._refresh_tokens[hashlib.sha256(new_refresh.encode()).hexdigest()] = { + "client_id": client.client_id, + "scopes": scopes or data["scopes"], + } + + return OAuthToken( + access_token=new_access, + token_type="Bearer", + expires_in=TOKEN_LIFETIME, + refresh_token=new_refresh, + scope=" ".join(scopes or data["scopes"]) if (scopes or data["scopes"]) else None, + ) + + +oauth = HomelabOAuth( base_url="https://mcp.georgsen.dk", client_registration_options=ClientRegistrationOptions(enabled=True), )