telegram-bot-mcp/mcp_bridge/auth.py
Mikkel Georgsen a71595b9d8 feat: replace custom OAuth with FastMCP built-in OAuthProvider
FastMCP's OAuthProvider handles the full OAuth flow including DCR
(Dynamic Client Registration), authorization code + PKCE, token
issuance, and refresh tokens. No more custom auth code.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-30 11:18:16 +00:00

329 lines
12 KiB
Python

"""OAuth 2.0 authorization code + PKCE auth for MCP server."""
import base64
import hashlib
import json
import logging
import secrets
import time
from urllib.parse import urlencode, urlparse, parse_qs
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse
from starlette.routing import Route
from .config import load_credentials
logger = logging.getLogger(__name__)
# In-memory stores
_active_tokens: dict[str, float] = {} # token_hash -> expiry
_auth_codes: dict[str, dict] = {} # code -> {client_id, redirect_uri, code_challenge, expires}
TOKEN_LIFETIME = 3600 # 1 hour
CODE_LIFETIME = 300 # 5 minutes
def _get_oauth_credentials() -> tuple[str, str]:
    """Return the configured ``(client_id, client_secret)`` pair.

    Both values are read from the credentials file via ``load_credentials()``.

    Raises:
        RuntimeError: if either value is missing or empty.
    """
    creds = load_credentials()
    pair = (
        creds.get("OAUTH_CLIENT_ID", ""),
        creds.get("OAUTH_CLIENT_SECRET", ""),
    )
    if not all(pair):
        raise RuntimeError("OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET must be set in credentials")
    return pair
def _hash(value: str) -> str:
return hashlib.sha256(value.encode()).hexdigest()
def _cleanup_expired():
    """Drop expired access tokens and authorization codes from the in-memory stores.

    Token entries hold a float expiry directly. Code entries are dicts whose
    ``"expires"`` key holds the expiry; a missing key counts as already expired.
    """
    cutoff = time.time()
    for bucket in (_active_tokens, _auth_codes):
        stale = []
        for key, entry in bucket.items():
            expires_at = entry if isinstance(entry, float) else entry.get("expires", 0)
            if expires_at < cutoff:
                stale.append(key)
        # Deleting happens after the scan: mutating a dict while iterating it raises.
        for key in stale:
            bucket.pop(key)
def _verify_pkce(code_verifier: str, code_challenge: str) -> bool:
"""Verify PKCE S256 challenge."""
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
return secrets.compare_digest(computed, code_challenge)
def validate_bearer_token(token: str) -> bool:
    """Return True when ``token`` was issued here and has not yet expired.

    Expired entries are purged first; the token is then looked up by its
    SHA-256 digest in ``_active_tokens``.
    """
    _cleanup_expired()
    expires_at = _active_tokens.get(_hash(token))
    if expires_at is None:
        return False
    return expires_at > time.time()
async def authorize_endpoint(request: Request):
    """OAuth 2.0 Authorization endpoint.

    GET /authorize?response_type=code&client_id=...&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...

    Single-user setup: auto-approves if client_id matches, redirects with code.
    There is no login or consent step.

    Returns:
        A 302 redirect to ``redirect_uri`` carrying ``code`` (and ``state`` when
        given), or a JSON error: 400 ``unsupported_response_type`` /
        ``invalid_request``, 401 ``invalid_client``, or 500 ``server_error``
        when the OAuth credentials are not configured.

    NOTE(review): redirect_uri is not matched against a registered allow-list,
    and PKCE is optional (a request without code_challenge yields a code that
    needs no verifier). A missing code_challenge_method is treated as S256.
    Confirm this is acceptable for the deployment.
    """
    params = dict(request.query_params)
    response_type = params.get("response_type")
    client_id = params.get("client_id", "")
    redirect_uri = params.get("redirect_uri", "")
    code_challenge = params.get("code_challenge", "")
    code_challenge_method = params.get("code_challenge_method", "")
    state = params.get("state", "")
    # request.client is None when the ASGI server does not report a peer address.
    client_host = request.client.host if request.client else "unknown"

    if response_type != "code":
        return JSONResponse(
            {"error": "unsupported_response_type"},
            status_code=400,
        )

    try:
        expected_id, _ = _get_oauth_credentials()
    except RuntimeError:
        return JSONResponse({"error": "server_error"}, status_code=500)

    # Compare bytes: compare_digest raises TypeError for non-ASCII str operands.
    if not secrets.compare_digest(client_id.encode(), expected_id.encode()):
        logger.warning(f"OAuth authorize: invalid client_id from {client_host}")
        return JSONResponse(
            {"error": "invalid_client", "error_description": "Unknown client_id"},
            status_code=401,
        )

    # RFC 6749 §3.1.2: the redirection endpoint must be an absolute URI and must
    # not include a fragment. An invalid redirect_uri must not be redirected to.
    try:
        redirect_ok = bool(urlparse(redirect_uri).scheme) and "#" not in redirect_uri
    except ValueError:  # e.g. a malformed IPv6 host such as "http://[::1"
        redirect_ok = False
    if not redirect_ok:
        return JSONResponse(
            {"error": "invalid_request",
             "error_description": "redirect_uri must be an absolute URI without a fragment"},
            status_code=400,
        )

    if code_challenge_method and code_challenge_method != "S256":
        return JSONResponse(
            {"error": "invalid_request", "error_description": "Only S256 code_challenge_method supported"},
            status_code=400,
        )

    # Auto-approve: generate a single-use authorization code.
    code = secrets.token_urlsafe(32)
    _auth_codes[code] = {
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "code_challenge": code_challenge,
        "expires": time.time() + CODE_LIFETIME,
    }
    _cleanup_expired()
    logger.info(f"OAuth authorization code issued for {client_id[:12]}... -> {redirect_uri}")

    # Redirect back with code (fragments were rejected above, so appending to
    # the query string is safe).
    redirect_params = {"code": code}
    if state:
        redirect_params["state"] = state
    separator = "&" if "?" in redirect_uri else "?"
    return RedirectResponse(
        url=f"{redirect_uri}{separator}{urlencode(redirect_params)}",
        status_code=302,
    )
def _parse_token_body(raw_body: bytes, content_type: str) -> dict[str, str]:
    """Decode a token-request body into a ``{name: value}`` dict of strings.

    JSON is used when ``content_type`` mentions ``application/json``; anything
    else is parsed as ``application/x-www-form-urlencoded``. An empty body
    yields ``{}``.

    Raises:
        ValueError: if the body cannot be decoded, is JSON but not an object of
            string values, or repeats a form parameter (RFC 6749 §3.2 forbids
            repeated request parameters).
    """
    if not raw_body:
        return {}
    if "application/json" in content_type:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        payload = json.loads(raw_body)
        if not isinstance(payload, dict) or not all(isinstance(v, str) for v in payload.values()):
            raise ValueError("JSON body must be an object with string values")
        return payload
    parsed = parse_qs(raw_body.decode(), keep_blank_values=True)
    repeated = sorted(name for name, values in parsed.items() if len(values) > 1)
    if repeated:
        raise ValueError(f"repeated parameters: {', '.join(repeated)}")
    return {name: values[0] for name, values in parsed.items()}


async def token_endpoint(request: Request) -> JSONResponse:
    """OAuth 2.0 token endpoint.

    Supports:
    - grant_type=authorization_code (with PKCE)
    - grant_type=client_credentials (direct)

    Parameters are taken from the query string and, for POST, from the body
    (body values win on conflicts). A GET with no parameters returns a small
    capabilities document. A missing grant_type with a ``code`` parameter is
    treated as authorization_code. A malformed body is rejected with 400
    ``invalid_request``.

    Request bodies are never logged verbatim because they carry client_secret,
    code and code_verifier values.
    """
    data: dict[str, str] = dict(request.query_params)

    if request.method == "POST":
        raw_body = await request.body()
        content_type = request.headers.get("content-type", "")
        try:
            data.update(_parse_token_body(raw_body, content_type))
        except ValueError as exc:
            logger.warning(f"TOKEN body parse error ({len(raw_body)} bytes): {exc}")
            return JSONResponse(
                {"error": "invalid_request", "error_description": f"Malformed request body: {exc}"},
                status_code=400,
            )
        logger.info(f"TOKEN POST body: {len(raw_body)} bytes, content-type={content_type!r}")

    logger.info(f"TOKEN {request.method} data_keys={list(data.keys())} "
                f"grant_type={data.get('grant_type', '(missing)')}")

    # Bare GET with no params = probe/ping, return 200 with metadata
    if request.method == "GET" and not data:
        return JSONResponse({
            "grant_types_supported": ["authorization_code", "client_credentials"],
            "token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
            "code_challenge_methods_supported": ["S256"],
        })

    grant_type = data.get("grant_type", "")
    # If there's a code param but no grant_type, assume authorization_code
    if not grant_type and "code" in data:
        grant_type = "authorization_code"

    if grant_type == "authorization_code":
        return await _handle_auth_code_grant(data, request)
    if grant_type == "client_credentials":
        return await _handle_client_credentials_grant(data, request)

    logger.warning(f"Token request with unsupported grant_type={grant_type!r}, data keys={list(data.keys())}")
    return JSONResponse(
        {"error": "unsupported_grant_type",
         "error_description": f"Supported: authorization_code, client_credentials. Got: {grant_type!r}"},
        status_code=400,
    )
async def _handle_auth_code_grant(data: dict, request: Request) -> JSONResponse:
    """Exchange authorization code + PKCE verifier for access token.

    Args:
        data: Merged request parameters; ``code``, ``client_id``,
            ``code_verifier`` and ``redirect_uri`` are read.
        request: The incoming request, used only to log the peer address.

    Returns:
        200 with ``access_token``/``token_type``/``expires_in`` on success,
        otherwise an OAuth error: 400 ``invalid_request`` / ``invalid_grant``
        or 401 ``invalid_client``.

    The code is single-use. It is removed from ``_auth_codes`` before the
    remaining checks, so a failed exchange also consumes it.
    """
    code = data.get("code", "")
    client_id = data.get("client_id", "")
    code_verifier = data.get("code_verifier", "")
    redirect_uri = data.get("redirect_uri", "")
    # request.client is None when the ASGI server does not report a peer address.
    client_host = request.client.host if request.client else "unknown"

    # Non-string values (possible with JSON bodies) would make the dict lookup or
    # compare_digest below raise instead of producing an OAuth error response.
    if not all(isinstance(v, str) for v in (code, client_id, code_verifier, redirect_uri)):
        return JSONResponse(
            {"error": "invalid_request", "error_description": "Parameters must be strings"},
            status_code=400,
        )

    _cleanup_expired()
    code_data = _auth_codes.pop(code, None)
    if code_data is None:
        return JSONResponse(
            {"error": "invalid_grant", "error_description": "Invalid or expired authorization code"},
            status_code=400,
        )

    # Verify client_id matches (bytes: compare_digest rejects non-ASCII str).
    if not secrets.compare_digest(client_id.encode(), code_data["client_id"].encode()):
        return JSONResponse({"error": "invalid_client"}, status_code=401)

    # RFC 6749 §4.1.3: the token request must repeat the redirect_uri used in the
    # authorization request with an identical value; omitting it is a mismatch.
    if redirect_uri != code_data["redirect_uri"]:
        return JSONResponse(
            {"error": "invalid_grant", "error_description": "redirect_uri mismatch"},
            status_code=400,
        )

    # Verify PKCE
    challenge = code_data["code_challenge"]
    if challenge:
        if not code_verifier:
            return JSONResponse(
                {"error": "invalid_request", "error_description": "code_verifier required"},
                status_code=400,
            )
        # A non-ASCII verifier/challenge can never be valid and would make the
        # S256 computation raise, so treat it as a failed verification.
        if not (code_verifier.isascii() and challenge.isascii()) or \
                not _verify_pkce(code_verifier, challenge):
            return JSONResponse(
                {"error": "invalid_grant", "error_description": "PKCE verification failed"},
                status_code=400,
            )

    # Issue access token; only its digest is kept server-side.
    access_token = secrets.token_urlsafe(48)
    _active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME
    logger.info(f"OAuth token issued via auth code to {client_host}")
    return JSONResponse(
        {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": TOKEN_LIFETIME,
        },
        # RFC 6749 §5.1: responses carrying tokens must not be cached.
        headers={"Cache-Control": "no-store", "Pragma": "no-cache"},
    )
async def _handle_client_credentials_grant(data: dict, request: Request) -> JSONResponse:
    """Direct client_id + client_secret exchange for access token.

    Args:
        data: Merged request parameters; ``client_id`` and ``client_secret``
            are read.
        request: The incoming request, used only to log the peer address.

    Returns:
        200 with ``access_token``/``token_type``/``expires_in`` when both values
        match the configured credentials, 401 ``invalid_client`` otherwise, or
        500 ``server_error`` when the credentials are not configured.
    """
    client_id = data.get("client_id", "")
    client_secret = data.get("client_secret", "")
    # request.client is None when the ASGI server does not report a peer address.
    client_host = request.client.host if request.client else "unknown"

    try:
        expected_id, expected_secret = _get_oauth_credentials()
    except RuntimeError:
        return JSONResponse({"error": "server_error"}, status_code=500)

    if isinstance(client_id, str) and isinstance(client_secret, str):
        # Compare bytes (compare_digest rejects non-ASCII str). Both comparisons
        # always run, so timing does not reveal whether client_id alone matched.
        id_ok = secrets.compare_digest(client_id.encode(), expected_id.encode())
        secret_ok = secrets.compare_digest(client_secret.encode(), expected_secret.encode())
        authenticated = id_ok and secret_ok
    else:
        # Non-string values (possible with JSON bodies) can never match.
        authenticated = False

    if not authenticated:
        logger.warning(f"OAuth client_credentials auth failed from {client_host}")
        return JSONResponse({"error": "invalid_client"}, status_code=401)

    access_token = secrets.token_urlsafe(48)
    _active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME
    _cleanup_expired()
    logger.info(f"OAuth token issued via client_credentials to {client_host}")
    return JSONResponse(
        {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": TOKEN_LIFETIME,
        },
        # RFC 6749 §5.1: responses carrying tokens must not be cached.
        headers={"Cache-Control": "no-store", "Pragma": "no-cache"},
    )
async def oauth_metadata(request: Request) -> JSONResponse:
    """Serve OAuth 2.0 Authorization Server Metadata (RFC 8414).

    Every endpoint URL is derived from the request's base URL, which also
    serves as the issuer.

    NOTE(review): the token endpoint also accepts grant_type=client_credentials
    with client_secret_post, which is not advertised here.
    """
    issuer = str(request.base_url).rstrip("/")
    metadata = {"issuer": issuer}
    for field, path in (
        ("authorization_endpoint", "authorize"),
        ("token_endpoint", "token"),
        ("registration_endpoint", "register"),
    ):
        metadata[field] = f"{issuer}/{path}"
    metadata.update(
        token_endpoint_auth_methods_supported=["none"],
        grant_types_supported=["authorization_code"],
        response_types_supported=["code"],
        code_challenge_methods_supported=["S256"],
        scopes_supported=["mcp"],
    )
    return JSONResponse(metadata)
async def protected_resource_metadata(request: Request) -> JSONResponse:
    """Serve OAuth 2.0 Protected Resource Metadata (RFC 9728).

    The request's base URL is used both as the resource identifier and as the
    single authorization server.

    NOTE(review): RFC 9728 §3.3 expects ``resource`` to equal the identifier
    the well-known URL was built from. The ``/mcp``-suffixed route also lands
    here and still returns the bare base URL; confirm clients accept this.
    """
    origin = str(request.base_url).rstrip("/")
    document = dict(
        resource=origin,
        authorization_servers=[origin],
        scopes_supported=["mcp"],
        bearer_methods_supported=["header"],
    )
    return JSONResponse(document)
# OAuth 2.0 Dynamic Client Registration stub (the route table `auth_routes` is defined after it)
async def register_endpoint(request: Request) -> JSONResponse:
    """OAuth 2.0 Dynamic Client Registration (stub).

    Claude may try to register dynamically. Since we use pre-configured
    credentials, the existing client info is returned. Any submitted client
    metadata is only logged, never stored.

    Returns:
        The static client-information document, with 201 Created for POST
        (RFC 7591 §3.2.1) and 200 for GET. Returns 500 ``server_error`` when the
        OAuth credentials are not configured.
    """
    try:
        client_id, _ = _get_oauth_credentials()
    except RuntimeError:
        return JSONResponse({"error": "server_error"}, status_code=500)

    logger.info(f"REGISTER endpoint called: method={request.method}")
    status_code = 200
    if request.method == "POST":
        body = await request.body()
        logger.info(f"REGISTER body: {body[:500]}")
        status_code = 201

    return JSONResponse(
        {
            "client_id": client_id,
            "client_name": "claude-desktop",
            "redirect_uris": ["https://claude.ai/api/mcp/auth_callback"],
            "grant_types": ["authorization_code"],
            "response_types": ["code"],
            "token_endpoint_auth_method": "none",
        },
        status_code=status_code,
    )
# Routes to add to the app: OAuth discovery documents plus the
# authorize / token / register endpoints defined above.
auth_routes = [
    Route(path, endpoint, methods=methods)
    for path, endpoint, methods in (
        # Discovery documents; the /mcp variant serves the same resource metadata.
        ("/.well-known/oauth-authorization-server", oauth_metadata, ["GET"]),
        ("/.well-known/oauth-protected-resource", protected_resource_metadata, ["GET"]),
        ("/.well-known/oauth-protected-resource/mcp", protected_resource_metadata, ["GET"]),
        # OAuth flow endpoints.
        ("/authorize", authorize_endpoint, ["GET"]),
        ("/token", token_endpoint, ["GET", "POST"]),
        ("/register", register_endpoint, ["GET", "POST"]),
    )
]