feat: implement concrete OAuth provider with in-memory storage
OAuthProvider is abstract — subclassed as HomelabOAuth with full implementation of register_client, get_client, create/exchange authorization codes, token issuance, PKCE verification, and refresh token rotation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
a71595b9d8
commit
5086716387
1 changed files with 136 additions and 3 deletions
|
|
@ -5,8 +5,13 @@ import logging
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.server.auth import OAuthProvider
|
||||
import secrets
|
||||
import time
|
||||
import hashlib
|
||||
|
||||
from fastmcp.server.auth import OAuthProvider, AccessToken
|
||||
from fastmcp.server.auth.auth import ClientRegistrationOptions
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
|
@ -19,8 +24,136 @@ logger = logging.getLogger(__name__)
|
|||
# Will be initialized in __main__ with shared db instance
|
||||
db: Database | None = None
|
||||
|
||||
# OAuth provider with Dynamic Client Registration enabled
|
||||
oauth = OAuthProvider(
|
||||
TOKEN_LIFETIME = 3600
|
||||
|
||||
|
||||
class HomelabOAuth(OAuthProvider):
|
||||
"""Concrete OAuth provider with in-memory storage."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._clients: dict[str, OAuthClientInformationFull] = {}
|
||||
self._auth_codes: dict[str, dict] = {} # code -> {client_id, code_challenge, redirect_uri, expires}
|
||||
self._tokens: dict[str, dict] = {} # token_hash -> {client_id, scopes, expires}
|
||||
self._refresh_tokens: dict[str, dict] = {} # refresh_hash -> {client_id, scopes}
|
||||
|
||||
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
|
||||
self._clients[client_info.client_id] = client_info
|
||||
logger.info(f"Registered OAuth client: {client_info.client_id} ({client_info.client_name})")
|
||||
|
||||
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||
return self._clients.get(client_id)
|
||||
|
||||
async def create_authorization_code(
|
||||
self, client: OAuthClientInformationFull, code_challenge: str | None,
|
||||
redirect_uri: str | None, scopes: list[str] | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
code = secrets.token_urlsafe(32)
|
||||
self._auth_codes[code] = {
|
||||
"client_id": client.client_id,
|
||||
"code_challenge": code_challenge,
|
||||
"redirect_uri": redirect_uri or str(client.redirect_uris[0]),
|
||||
"scopes": scopes or [],
|
||||
"expires": time.time() + 300,
|
||||
}
|
||||
logger.info(f"Auth code issued for client {client.client_id}")
|
||||
return code
|
||||
|
||||
async def exchange_authorization_code(
|
||||
self, client: OAuthClientInformationFull, code: str,
|
||||
code_verifier: str | None = None, redirect_uri: str | None = None,
|
||||
**kwargs,
|
||||
) -> OAuthToken:
|
||||
if code not in self._auth_codes:
|
||||
raise ValueError("Invalid authorization code")
|
||||
|
||||
code_data = self._auth_codes.pop(code)
|
||||
|
||||
if code_data["expires"] < time.time():
|
||||
raise ValueError("Authorization code expired")
|
||||
|
||||
if code_data["client_id"] != client.client_id:
|
||||
raise ValueError("Client ID mismatch")
|
||||
|
||||
# PKCE verification
|
||||
if code_data["code_challenge"] and code_verifier:
|
||||
import base64
|
||||
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
||||
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
||||
if computed != code_data["code_challenge"]:
|
||||
raise ValueError("PKCE verification failed")
|
||||
|
||||
# Issue tokens
|
||||
access_token = secrets.token_urlsafe(48)
|
||||
refresh_token = secrets.token_urlsafe(48)
|
||||
|
||||
self._tokens[hashlib.sha256(access_token.encode()).hexdigest()] = {
|
||||
"client_id": client.client_id,
|
||||
"scopes": code_data["scopes"],
|
||||
"expires": time.time() + TOKEN_LIFETIME,
|
||||
}
|
||||
self._refresh_tokens[hashlib.sha256(refresh_token.encode()).hexdigest()] = {
|
||||
"client_id": client.client_id,
|
||||
"scopes": code_data["scopes"],
|
||||
}
|
||||
|
||||
logger.info(f"Token issued for client {client.client_id}")
|
||||
return OAuthToken(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
expires_in=TOKEN_LIFETIME,
|
||||
refresh_token=refresh_token,
|
||||
scope=" ".join(code_data["scopes"]) if code_data["scopes"] else None,
|
||||
)
|
||||
|
||||
async def load_access_token(self, token: str) -> AccessToken | None:
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
data = self._tokens.get(token_hash)
|
||||
if not data or data["expires"] < time.time():
|
||||
return None
|
||||
return AccessToken(
|
||||
token=token,
|
||||
client_id=data["client_id"],
|
||||
scopes=data["scopes"],
|
||||
)
|
||||
|
||||
async def exchange_refresh_token(
|
||||
self, client: OAuthClientInformationFull,
|
||||
refresh_token: str, scopes: list[str] | None = None,
|
||||
**kwargs,
|
||||
) -> OAuthToken:
|
||||
refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
||||
data = self._refresh_tokens.get(refresh_hash)
|
||||
if not data or data["client_id"] != client.client_id:
|
||||
raise ValueError("Invalid refresh token")
|
||||
|
||||
# Issue new tokens
|
||||
new_access = secrets.token_urlsafe(48)
|
||||
new_refresh = secrets.token_urlsafe(48)
|
||||
|
||||
self._tokens[hashlib.sha256(new_access.encode()).hexdigest()] = {
|
||||
"client_id": client.client_id,
|
||||
"scopes": scopes or data["scopes"],
|
||||
"expires": time.time() + TOKEN_LIFETIME,
|
||||
}
|
||||
# Rotate refresh token
|
||||
del self._refresh_tokens[refresh_hash]
|
||||
self._refresh_tokens[hashlib.sha256(new_refresh.encode()).hexdigest()] = {
|
||||
"client_id": client.client_id,
|
||||
"scopes": scopes or data["scopes"],
|
||||
}
|
||||
|
||||
return OAuthToken(
|
||||
access_token=new_access,
|
||||
token_type="Bearer",
|
||||
expires_in=TOKEN_LIFETIME,
|
||||
refresh_token=new_refresh,
|
||||
scope=" ".join(scopes or data["scopes"]) if (scopes or data["scopes"]) else None,
|
||||
)
|
||||
|
||||
|
||||
oauth = HomelabOAuth(
|
||||
base_url="https://mcp.georgsen.dk",
|
||||
client_registration_options=ClientRegistrationOptions(enabled=True),
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue