feat: implement concrete OAuth provider with in-memory storage

OAuthProvider is abstract — subclassed as HomelabOAuth with full
implementation of register_client, get_client, create/exchange
authorization codes, token issuance, PKCE verification, and
refresh token rotation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Mikkel Georgsen 2026-03-30 11:22:20 +00:00
parent a71595b9d8
commit 5086716387

View file

@ -5,8 +5,13 @@ import logging
from datetime import datetime, timezone
from fastmcp import FastMCP
from fastmcp.server.auth import OAuthProvider
import secrets
import time
import hashlib
from fastmcp.server.auth import OAuthProvider, AccessToken
from fastmcp.server.auth.auth import ClientRegistrationOptions
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
@ -19,8 +24,136 @@ logger = logging.getLogger(__name__)
# Will be initialized in __main__ with shared db instance
db: Database | None = None
# OAuth provider with Dynamic Client Registration enabled
oauth = OAuthProvider(
TOKEN_LIFETIME = 3600
class HomelabOAuth(OAuthProvider):
"""Concrete OAuth provider with in-memory storage."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._clients: dict[str, OAuthClientInformationFull] = {}
self._auth_codes: dict[str, dict] = {} # code -> {client_id, code_challenge, redirect_uri, expires}
self._tokens: dict[str, dict] = {} # token_hash -> {client_id, scopes, expires}
self._refresh_tokens: dict[str, dict] = {} # refresh_hash -> {client_id, scopes}
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
self._clients[client_info.client_id] = client_info
logger.info(f"Registered OAuth client: {client_info.client_id} ({client_info.client_name})")
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
return self._clients.get(client_id)
async def create_authorization_code(
self, client: OAuthClientInformationFull, code_challenge: str | None,
redirect_uri: str | None, scopes: list[str] | None = None,
**kwargs,
) -> str:
code = secrets.token_urlsafe(32)
self._auth_codes[code] = {
"client_id": client.client_id,
"code_challenge": code_challenge,
"redirect_uri": redirect_uri or str(client.redirect_uris[0]),
"scopes": scopes or [],
"expires": time.time() + 300,
}
logger.info(f"Auth code issued for client {client.client_id}")
return code
async def exchange_authorization_code(
self, client: OAuthClientInformationFull, code: str,
code_verifier: str | None = None, redirect_uri: str | None = None,
**kwargs,
) -> OAuthToken:
if code not in self._auth_codes:
raise ValueError("Invalid authorization code")
code_data = self._auth_codes.pop(code)
if code_data["expires"] < time.time():
raise ValueError("Authorization code expired")
if code_data["client_id"] != client.client_id:
raise ValueError("Client ID mismatch")
# PKCE verification
if code_data["code_challenge"] and code_verifier:
import base64
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
if computed != code_data["code_challenge"]:
raise ValueError("PKCE verification failed")
# Issue tokens
access_token = secrets.token_urlsafe(48)
refresh_token = secrets.token_urlsafe(48)
self._tokens[hashlib.sha256(access_token.encode()).hexdigest()] = {
"client_id": client.client_id,
"scopes": code_data["scopes"],
"expires": time.time() + TOKEN_LIFETIME,
}
self._refresh_tokens[hashlib.sha256(refresh_token.encode()).hexdigest()] = {
"client_id": client.client_id,
"scopes": code_data["scopes"],
}
logger.info(f"Token issued for client {client.client_id}")
return OAuthToken(
access_token=access_token,
token_type="Bearer",
expires_in=TOKEN_LIFETIME,
refresh_token=refresh_token,
scope=" ".join(code_data["scopes"]) if code_data["scopes"] else None,
)
async def load_access_token(self, token: str) -> AccessToken | None:
token_hash = hashlib.sha256(token.encode()).hexdigest()
data = self._tokens.get(token_hash)
if not data or data["expires"] < time.time():
return None
return AccessToken(
token=token,
client_id=data["client_id"],
scopes=data["scopes"],
)
async def exchange_refresh_token(
self, client: OAuthClientInformationFull,
refresh_token: str, scopes: list[str] | None = None,
**kwargs,
) -> OAuthToken:
refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
data = self._refresh_tokens.get(refresh_hash)
if not data or data["client_id"] != client.client_id:
raise ValueError("Invalid refresh token")
# Issue new tokens
new_access = secrets.token_urlsafe(48)
new_refresh = secrets.token_urlsafe(48)
self._tokens[hashlib.sha256(new_access.encode()).hexdigest()] = {
"client_id": client.client_id,
"scopes": scopes or data["scopes"],
"expires": time.time() + TOKEN_LIFETIME,
}
# Rotate refresh token
del self._refresh_tokens[refresh_hash]
self._refresh_tokens[hashlib.sha256(new_refresh.encode()).hexdigest()] = {
"client_id": client.client_id,
"scopes": scopes or data["scopes"],
}
return OAuthToken(
access_token=new_access,
token_type="Bearer",
expires_in=TOKEN_LIFETIME,
refresh_token=new_refresh,
scope=" ".join(scopes or data["scopes"]) if (scopes or data["scopes"]) else None,
)
oauth = HomelabOAuth(
base_url="https://mcp.georgsen.dk",
client_registration_options=ClientRegistrationOptions(enabled=True),
)