feat: implement concrete OAuth provider with in-memory storage
OAuthProvider is abstract — subclassed as HomelabOAuth with full implementation of register_client, get_client, create/exchange authorization codes, token issuance, PKCE verification, and refresh token rotation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
a71595b9d8
commit
5086716387
1 changed files with 136 additions and 3 deletions
|
|
@ -5,8 +5,13 @@ import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
from fastmcp.server.auth import OAuthProvider
|
import secrets
|
||||||
|
import time
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
from fastmcp.server.auth import OAuthProvider, AccessToken
|
||||||
from fastmcp.server.auth.auth import ClientRegistrationOptions
|
from fastmcp.server.auth.auth import ClientRegistrationOptions
|
||||||
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.routing import Route
|
from starlette.routing import Route
|
||||||
|
|
@ -19,8 +24,136 @@ logger = logging.getLogger(__name__)
|
||||||
# Will be initialized in __main__ with shared db instance
|
# Will be initialized in __main__ with shared db instance
|
||||||
db: Database | None = None
|
db: Database | None = None
|
||||||
|
|
||||||
# OAuth provider with Dynamic Client Registration enabled
|
TOKEN_LIFETIME = 3600
|
||||||
oauth = OAuthProvider(
|
|
||||||
|
|
||||||
|
class HomelabOAuth(OAuthProvider):
|
||||||
|
"""Concrete OAuth provider with in-memory storage."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._clients: dict[str, OAuthClientInformationFull] = {}
|
||||||
|
self._auth_codes: dict[str, dict] = {} # code -> {client_id, code_challenge, redirect_uri, expires}
|
||||||
|
self._tokens: dict[str, dict] = {} # token_hash -> {client_id, scopes, expires}
|
||||||
|
self._refresh_tokens: dict[str, dict] = {} # refresh_hash -> {client_id, scopes}
|
||||||
|
|
||||||
|
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
|
||||||
|
self._clients[client_info.client_id] = client_info
|
||||||
|
logger.info(f"Registered OAuth client: {client_info.client_id} ({client_info.client_name})")
|
||||||
|
|
||||||
|
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||||
|
return self._clients.get(client_id)
|
||||||
|
|
||||||
|
async def create_authorization_code(
|
||||||
|
self, client: OAuthClientInformationFull, code_challenge: str | None,
|
||||||
|
redirect_uri: str | None, scopes: list[str] | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
code = secrets.token_urlsafe(32)
|
||||||
|
self._auth_codes[code] = {
|
||||||
|
"client_id": client.client_id,
|
||||||
|
"code_challenge": code_challenge,
|
||||||
|
"redirect_uri": redirect_uri or str(client.redirect_uris[0]),
|
||||||
|
"scopes": scopes or [],
|
||||||
|
"expires": time.time() + 300,
|
||||||
|
}
|
||||||
|
logger.info(f"Auth code issued for client {client.client_id}")
|
||||||
|
return code
|
||||||
|
|
||||||
|
async def exchange_authorization_code(
|
||||||
|
self, client: OAuthClientInformationFull, code: str,
|
||||||
|
code_verifier: str | None = None, redirect_uri: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> OAuthToken:
|
||||||
|
if code not in self._auth_codes:
|
||||||
|
raise ValueError("Invalid authorization code")
|
||||||
|
|
||||||
|
code_data = self._auth_codes.pop(code)
|
||||||
|
|
||||||
|
if code_data["expires"] < time.time():
|
||||||
|
raise ValueError("Authorization code expired")
|
||||||
|
|
||||||
|
if code_data["client_id"] != client.client_id:
|
||||||
|
raise ValueError("Client ID mismatch")
|
||||||
|
|
||||||
|
# PKCE verification
|
||||||
|
if code_data["code_challenge"] and code_verifier:
|
||||||
|
import base64
|
||||||
|
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
||||||
|
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
||||||
|
if computed != code_data["code_challenge"]:
|
||||||
|
raise ValueError("PKCE verification failed")
|
||||||
|
|
||||||
|
# Issue tokens
|
||||||
|
access_token = secrets.token_urlsafe(48)
|
||||||
|
refresh_token = secrets.token_urlsafe(48)
|
||||||
|
|
||||||
|
self._tokens[hashlib.sha256(access_token.encode()).hexdigest()] = {
|
||||||
|
"client_id": client.client_id,
|
||||||
|
"scopes": code_data["scopes"],
|
||||||
|
"expires": time.time() + TOKEN_LIFETIME,
|
||||||
|
}
|
||||||
|
self._refresh_tokens[hashlib.sha256(refresh_token.encode()).hexdigest()] = {
|
||||||
|
"client_id": client.client_id,
|
||||||
|
"scopes": code_data["scopes"],
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Token issued for client {client.client_id}")
|
||||||
|
return OAuthToken(
|
||||||
|
access_token=access_token,
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=TOKEN_LIFETIME,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
scope=" ".join(code_data["scopes"]) if code_data["scopes"] else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_access_token(self, token: str) -> AccessToken | None:
|
||||||
|
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||||
|
data = self._tokens.get(token_hash)
|
||||||
|
if not data or data["expires"] < time.time():
|
||||||
|
return None
|
||||||
|
return AccessToken(
|
||||||
|
token=token,
|
||||||
|
client_id=data["client_id"],
|
||||||
|
scopes=data["scopes"],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def exchange_refresh_token(
|
||||||
|
self, client: OAuthClientInformationFull,
|
||||||
|
refresh_token: str, scopes: list[str] | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> OAuthToken:
|
||||||
|
refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
||||||
|
data = self._refresh_tokens.get(refresh_hash)
|
||||||
|
if not data or data["client_id"] != client.client_id:
|
||||||
|
raise ValueError("Invalid refresh token")
|
||||||
|
|
||||||
|
# Issue new tokens
|
||||||
|
new_access = secrets.token_urlsafe(48)
|
||||||
|
new_refresh = secrets.token_urlsafe(48)
|
||||||
|
|
||||||
|
self._tokens[hashlib.sha256(new_access.encode()).hexdigest()] = {
|
||||||
|
"client_id": client.client_id,
|
||||||
|
"scopes": scopes or data["scopes"],
|
||||||
|
"expires": time.time() + TOKEN_LIFETIME,
|
||||||
|
}
|
||||||
|
# Rotate refresh token
|
||||||
|
del self._refresh_tokens[refresh_hash]
|
||||||
|
self._refresh_tokens[hashlib.sha256(new_refresh.encode()).hexdigest()] = {
|
||||||
|
"client_id": client.client_id,
|
||||||
|
"scopes": scopes or data["scopes"],
|
||||||
|
}
|
||||||
|
|
||||||
|
return OAuthToken(
|
||||||
|
access_token=new_access,
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=TOKEN_LIFETIME,
|
||||||
|
refresh_token=new_refresh,
|
||||||
|
scope=" ".join(scopes or data["scopes"]) if (scopes or data["scopes"]) else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
oauth = HomelabOAuth(
|
||||||
base_url="https://mcp.georgsen.dk",
|
base_url="https://mcp.georgsen.dk",
|
||||||
client_registration_options=ClientRegistrationOptions(enabled=True),
|
client_registration_options=ClientRegistrationOptions(enabled=True),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue