fix: use FastMCP's InMemoryOAuthProvider instead of custom implementation

Replaced hand-rolled OAuth with FastMCP's battle-tested
InMemoryOAuthProvider. Handles DCR, PKCE, token exchange,
refresh tokens, and revocation out of the box.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Mikkel Georgsen 2026-03-30 11:32:01 +00:00
parent 1296310adb
commit 1dff4630fe
2 changed files with 5 additions and 487 deletions

View file

@ -1,329 +0,0 @@
"""OAuth 2.0 authorization code + PKCE auth for MCP server."""
import base64
import hashlib
import json
import logging
import secrets
import time
from urllib.parse import urlencode, urlparse, parse_qs
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse
from starlette.routing import Route
from .config import load_credentials
logger = logging.getLogger(__name__)
# In-memory stores
_active_tokens: dict[str, float] = {} # token_hash -> expiry
_auth_codes: dict[str, dict] = {} # code -> {client_id, redirect_uri, code_challenge, expires}
TOKEN_LIFETIME = 3600 # 1 hour
CODE_LIFETIME = 300 # 5 minutes
def _get_oauth_credentials() -> tuple[str, str]:
"""Load OAuth client_id and client_secret from credentials file."""
creds = load_credentials()
client_id = creds.get("OAUTH_CLIENT_ID", "")
client_secret = creds.get("OAUTH_CLIENT_SECRET", "")
if not client_id or not client_secret:
raise RuntimeError("OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET must be set in credentials")
return client_id, client_secret
def _hash(value: str) -> str:
return hashlib.sha256(value.encode()).hexdigest()
def _cleanup_expired():
"""Remove expired tokens and codes."""
now = time.time()
for store in (_active_tokens, _auth_codes):
expired = [k for k, v in store.items()
if (v if isinstance(v, float) else v.get("expires", 0)) < now]
for k in expired:
del store[k]
def _verify_pkce(code_verifier: str, code_challenge: str) -> bool:
"""Verify PKCE S256 challenge."""
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
return secrets.compare_digest(computed, code_challenge)
def validate_bearer_token(token: str) -> bool:
"""Check if a bearer token is valid and not expired."""
_cleanup_expired()
token_hash = _hash(token)
return token_hash in _active_tokens and _active_tokens[token_hash] > time.time()
async def authorize_endpoint(request: Request):
"""OAuth 2.0 Authorization endpoint.
GET /authorize?response_type=code&client_id=...&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...
Single-user setup: auto-approves if client_id matches, redirects with code.
"""
params = dict(request.query_params)
response_type = params.get("response_type")
client_id = params.get("client_id", "")
redirect_uri = params.get("redirect_uri", "")
code_challenge = params.get("code_challenge", "")
code_challenge_method = params.get("code_challenge_method", "")
state = params.get("state", "")
if response_type != "code":
return JSONResponse(
{"error": "unsupported_response_type"},
status_code=400,
)
try:
expected_id, _ = _get_oauth_credentials()
except RuntimeError:
return JSONResponse({"error": "server_error"}, status_code=500)
if not secrets.compare_digest(client_id, expected_id):
logger.warning(f"OAuth authorize: invalid client_id from {request.client.host}")
return JSONResponse(
{"error": "invalid_client", "error_description": "Unknown client_id"},
status_code=401,
)
if code_challenge_method and code_challenge_method != "S256":
return JSONResponse(
{"error": "invalid_request", "error_description": "Only S256 code_challenge_method supported"},
status_code=400,
)
# Auto-approve: generate authorization code
code = secrets.token_urlsafe(32)
_auth_codes[code] = {
"client_id": client_id,
"redirect_uri": redirect_uri,
"code_challenge": code_challenge,
"expires": time.time() + CODE_LIFETIME,
}
_cleanup_expired()
logger.info(f"OAuth authorization code issued for {client_id[:12]}... -> {redirect_uri}")
# Redirect back with code
redirect_params = {"code": code}
if state:
redirect_params["state"] = state
separator = "&" if "?" in redirect_uri else "?"
return RedirectResponse(
url=f"{redirect_uri}{separator}{urlencode(redirect_params)}",
status_code=302,
)
async def token_endpoint(request: Request) -> JSONResponse:
"""OAuth 2.0 token endpoint.
Supports:
- grant_type=authorization_code (with PKCE)
- grant_type=client_credentials (direct)
"""
# Parse request data from any source
data = {}
# Always include query params
data.update(dict(request.query_params))
# Parse body for POST
if request.method == "POST":
try:
raw_body = await request.body()
logger.info(f"TOKEN POST body ({len(raw_body)} bytes): {raw_body[:500]}")
content_type = request.headers.get("content-type", "")
if "application/json" in content_type and raw_body:
data.update(json.loads(raw_body))
elif raw_body:
from urllib.parse import parse_qs
parsed = parse_qs(raw_body.decode(), keep_blank_values=True)
data.update({k: v[0] if len(v) == 1 else v for k, v in parsed.items()})
except Exception as e:
logger.warning(f"TOKEN body parse error: {e}")
logger.info(f"TOKEN {request.method} data_keys={list(data.keys())} "
f"grant_type={data.get('grant_type', '(missing)')}")
# Bare GET with no params = probe/ping, return 200 with metadata
if request.method == "GET" and not data:
return JSONResponse({
"grant_types_supported": ["authorization_code", "client_credentials"],
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
"code_challenge_methods_supported": ["S256"],
})
grant_type = data.get("grant_type", "")
logger.debug(f"Token request: method={request.method} grant_type={grant_type} params={list(data.keys())}")
# If there's a code param but no grant_type, assume authorization_code
if not grant_type and "code" in data:
grant_type = "authorization_code"
if grant_type == "authorization_code":
return await _handle_auth_code_grant(data, request)
elif grant_type == "client_credentials":
return await _handle_client_credentials_grant(data, request)
else:
logger.warning(f"Token request with unsupported grant_type={grant_type!r}, data keys={list(data.keys())}")
return JSONResponse(
{"error": "unsupported_grant_type",
"error_description": f"Supported: authorization_code, client_credentials. Got: {grant_type!r}"},
status_code=400,
)
async def _handle_auth_code_grant(data: dict, request: Request) -> JSONResponse:
"""Exchange authorization code + PKCE verifier for access token."""
code = data.get("code", "")
client_id = data.get("client_id", "")
code_verifier = data.get("code_verifier", "")
redirect_uri = data.get("redirect_uri", "")
_cleanup_expired()
if code not in _auth_codes:
return JSONResponse(
{"error": "invalid_grant", "error_description": "Invalid or expired authorization code"},
status_code=400,
)
code_data = _auth_codes.pop(code)
# Verify client_id matches
if not secrets.compare_digest(client_id, code_data["client_id"]):
return JSONResponse({"error": "invalid_client"}, status_code=401)
# Verify redirect_uri matches
if redirect_uri and redirect_uri != code_data["redirect_uri"]:
return JSONResponse(
{"error": "invalid_grant", "error_description": "redirect_uri mismatch"},
status_code=400,
)
# Verify PKCE
if code_data["code_challenge"]:
if not code_verifier:
return JSONResponse(
{"error": "invalid_request", "error_description": "code_verifier required"},
status_code=400,
)
if not _verify_pkce(code_verifier, code_data["code_challenge"]):
return JSONResponse(
{"error": "invalid_grant", "error_description": "PKCE verification failed"},
status_code=400,
)
# Issue access token
access_token = secrets.token_urlsafe(48)
_active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME
logger.info(f"OAuth token issued via auth code to {request.client.host}")
return JSONResponse({
"access_token": access_token,
"token_type": "Bearer",
"expires_in": TOKEN_LIFETIME,
})
async def _handle_client_credentials_grant(data: dict, request: Request) -> JSONResponse:
"""Direct client_id + client_secret exchange for access token."""
client_id = data.get("client_id", "")
client_secret = data.get("client_secret", "")
try:
expected_id, expected_secret = _get_oauth_credentials()
except RuntimeError:
return JSONResponse({"error": "server_error"}, status_code=500)
if not secrets.compare_digest(client_id, expected_id) or \
not secrets.compare_digest(client_secret, expected_secret):
logger.warning(f"OAuth client_credentials auth failed from {request.client.host}")
return JSONResponse({"error": "invalid_client"}, status_code=401)
access_token = secrets.token_urlsafe(48)
_active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME
_cleanup_expired()
logger.info(f"OAuth token issued via client_credentials to {request.client.host}")
return JSONResponse({
"access_token": access_token,
"token_type": "Bearer",
"expires_in": TOKEN_LIFETIME,
})
async def oauth_metadata(request: Request) -> JSONResponse:
"""OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
base = str(request.base_url).rstrip("/")
return JSONResponse({
"issuer": base,
"authorization_endpoint": f"{base}/authorize",
"token_endpoint": f"{base}/token",
"registration_endpoint": f"{base}/register",
"token_endpoint_auth_methods_supported": ["none"],
"grant_types_supported": ["authorization_code"],
"response_types_supported": ["code"],
"code_challenge_methods_supported": ["S256"],
"scopes_supported": ["mcp"],
})
async def protected_resource_metadata(request: Request) -> JSONResponse:
"""OAuth 2.0 Protected Resource Metadata (RFC 9728)."""
base = str(request.base_url).rstrip("/")
return JSONResponse({
"resource": base,
"authorization_servers": [base],
"scopes_supported": ["mcp"],
"bearer_methods_supported": ["header"],
})
# Routes to add to the app
async def register_endpoint(request: Request) -> JSONResponse:
"""OAuth 2.0 Dynamic Client Registration (stub).
Claude may try to register dynamically. Since we use pre-configured
credentials, just return the existing client info.
"""
try:
client_id, _ = _get_oauth_credentials()
except RuntimeError:
return JSONResponse({"error": "server_error"}, status_code=500)
base = str(request.base_url).rstrip("/")
logger.info(f"REGISTER endpoint called: method={request.method}")
if request.method == "POST":
body = await request.body()
logger.info(f"REGISTER body: {body[:500]}")
return JSONResponse({
"client_id": client_id,
"client_name": "claude-desktop",
"redirect_uris": ["https://claude.ai/api/mcp/auth_callback"],
"grant_types": ["authorization_code"],
"response_types": ["code"],
"token_endpoint_auth_method": "none",
})
auth_routes = [
Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]),
Route("/.well-known/oauth-protected-resource", protected_resource_metadata, methods=["GET"]),
Route("/.well-known/oauth-protected-resource/mcp", protected_resource_metadata, methods=["GET"]),
Route("/authorize", authorize_endpoint, methods=["GET"]),
Route("/token", token_endpoint, methods=["GET", "POST"]),
Route("/register", register_endpoint, methods=["GET", "POST"]),
]

View file

@ -5,169 +5,20 @@ import logging
from datetime import datetime, timezone
from fastmcp import FastMCP
import secrets
import time
import hashlib
from urllib.parse import urlencode
from fastmcp.server.auth import OAuthProvider, AccessToken
from fastmcp.server.auth.providers.in_memory import InMemoryOAuthProvider
from fastmcp.server.auth.auth import ClientRegistrationOptions
from mcp.server.auth.provider import AuthorizationParams
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from .db import Database
from .config import get_group_chat_id, MCP_HOST, MCP_PORT
from .config import get_group_chat_id
logger = logging.getLogger(__name__)
# Will be initialized in __main__ with shared db instance
db: Database | None = None
TOKEN_LIFETIME = 3600
class HomelabOAuth(OAuthProvider):
"""Concrete OAuth provider with in-memory storage."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._clients: dict[str, OAuthClientInformationFull] = {}
self._auth_codes: dict[str, dict] = {} # code -> {client_id, code_challenge, redirect_uri, expires}
self._tokens: dict[str, dict] = {} # token_hash -> {client_id, scopes, expires}
self._refresh_tokens: dict[str, dict] = {} # refresh_hash -> {client_id, scopes}
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
self._clients[client_info.client_id] = client_info
logger.info(f"Registered OAuth client: {client_info.client_id} ({client_info.client_name})")
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
return self._clients.get(client_id)
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
"""Auto-approve and redirect back with auth code (single-user setup)."""
code = await self.create_authorization_code(
client, params.code_challenge, str(params.redirect_uri), params.scopes,
)
redirect_params = {"code": code}
if params.state:
redirect_params["state"] = params.state
redirect_uri = str(params.redirect_uri)
separator = "&" if "?" in redirect_uri else "?"
return f"{redirect_uri}{separator}{urlencode(redirect_params)}"
async def create_authorization_code(
self, client: OAuthClientInformationFull, code_challenge: str | None,
redirect_uri: str | None, scopes: list[str] | None = None,
**kwargs,
) -> str:
code = secrets.token_urlsafe(32)
self._auth_codes[code] = {
"client_id": client.client_id,
"code_challenge": code_challenge,
"redirect_uri": redirect_uri or str(client.redirect_uris[0]),
"scopes": scopes or [],
"expires": time.time() + 300,
}
logger.info(f"Auth code issued for client {client.client_id}")
return code
async def exchange_authorization_code(
self, client: OAuthClientInformationFull, code: str,
code_verifier: str | None = None, redirect_uri: str | None = None,
**kwargs,
) -> OAuthToken:
if code not in self._auth_codes:
raise ValueError("Invalid authorization code")
code_data = self._auth_codes.pop(code)
if code_data["expires"] < time.time():
raise ValueError("Authorization code expired")
if code_data["client_id"] != client.client_id:
raise ValueError("Client ID mismatch")
# PKCE verification
if code_data["code_challenge"] and code_verifier:
import base64
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
if computed != code_data["code_challenge"]:
raise ValueError("PKCE verification failed")
# Issue tokens
access_token = secrets.token_urlsafe(48)
refresh_token = secrets.token_urlsafe(48)
self._tokens[hashlib.sha256(access_token.encode()).hexdigest()] = {
"client_id": client.client_id,
"scopes": code_data["scopes"],
"expires": time.time() + TOKEN_LIFETIME,
}
self._refresh_tokens[hashlib.sha256(refresh_token.encode()).hexdigest()] = {
"client_id": client.client_id,
"scopes": code_data["scopes"],
}
logger.info(f"Token issued for client {client.client_id}")
return OAuthToken(
access_token=access_token,
token_type="Bearer",
expires_in=TOKEN_LIFETIME,
refresh_token=refresh_token,
scope=" ".join(code_data["scopes"]) if code_data["scopes"] else None,
)
async def load_access_token(self, token: str) -> AccessToken | None:
token_hash = hashlib.sha256(token.encode()).hexdigest()
data = self._tokens.get(token_hash)
if not data or data["expires"] < time.time():
return None
return AccessToken(
token=token,
client_id=data["client_id"],
scopes=data["scopes"],
)
async def exchange_refresh_token(
self, client: OAuthClientInformationFull,
refresh_token: str, scopes: list[str] | None = None,
**kwargs,
) -> OAuthToken:
refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
data = self._refresh_tokens.get(refresh_hash)
if not data or data["client_id"] != client.client_id:
raise ValueError("Invalid refresh token")
# Issue new tokens
new_access = secrets.token_urlsafe(48)
new_refresh = secrets.token_urlsafe(48)
self._tokens[hashlib.sha256(new_access.encode()).hexdigest()] = {
"client_id": client.client_id,
"scopes": scopes or data["scopes"],
"expires": time.time() + TOKEN_LIFETIME,
}
# Rotate refresh token
del self._refresh_tokens[refresh_hash]
self._refresh_tokens[hashlib.sha256(new_refresh.encode()).hexdigest()] = {
"client_id": client.client_id,
"scopes": scopes or data["scopes"],
}
return OAuthToken(
access_token=new_access,
token_type="Bearer",
expires_in=TOKEN_LIFETIME,
refresh_token=new_refresh,
scope=" ".join(scopes or data["scopes"]) if (scopes or data["scopes"]) else None,
)
oauth = HomelabOAuth(
oauth = InMemoryOAuthProvider(
base_url="https://mcp.georgsen.dk",
client_registration_options=ClientRegistrationOptions(enabled=True),
)
@ -185,7 +36,6 @@ mcp = FastMCP(
def init(database: Database):
"""Set the shared database instance."""
global db
db = database
@ -225,7 +75,6 @@ def pull_updates(since_id: int = 0, since: str | None = None, limit: int = 50) -
else:
messages = db.get_messages_since_id(since_id, limit)
# Enrich with attachment info
for msg in messages:
if msg["has_attachment"]:
msg["attachments"] = db.get_attachments_for_message(msg["id"])
@ -252,7 +101,7 @@ def queue_status() -> str:
return json.dumps(status)
# Custom non-MCP routes (no auth required)
# Custom non-MCP routes (no auth required - local access only)
async def ingest_message(request: Request) -> JSONResponse:
"""HTTP endpoint for local services to log messages into the bridge."""
try:
@ -285,9 +134,7 @@ async def ingest_message(request: Request) -> JSONResponse:
if msg_id is None:
return JSONResponse({"ok": True, "duplicate": True})
logger.info(
f"Ingested message {telegram_message_id} from {data.get('sender_name', 'unknown')}"
)
logger.info(f"Ingested message {telegram_message_id} from {data.get('sender_name', 'unknown')}")
return JSONResponse({"ok": True, "id": msg_id})