feat: replace custom OAuth with FastMCP built-in OAuthProvider
FastMCP's OAuthProvider handles the full OAuth flow including DCR (Dynamic Client Registration), authorization code + PKCE, token issuance, and refresh tokens. No more custom auth code. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
a21dd3ebbb
commit
a71595b9d8
3 changed files with 138 additions and 142 deletions
|
|
@ -4,15 +4,12 @@ import asyncio
|
|||
import logging
|
||||
import signal
|
||||
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
import uvicorn
|
||||
|
||||
from .config import MCP_HOST, MCP_PORT, MEDIA_DIR, DB_PATH
|
||||
from .db import Database
|
||||
from .telegram_bot import BridgeBot
|
||||
from .mcp_server import mcp, init as init_mcp, custom_routes
|
||||
from .auth import auth_routes, validate_bearer_token
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
|
@ -21,50 +18,6 @@ logging.basicConfig(
|
|||
)
|
||||
logger = logging.getLogger("mcp_bridge")
|
||||
|
||||
# Paths that don't require auth
|
||||
PUBLIC_PATHS = {
|
||||
"/.well-known/oauth-authorization-server",
|
||||
"/.well-known/oauth-protected-resource",
|
||||
"/authorize",
|
||||
"/token",
|
||||
"/api/health",
|
||||
"/api/ingest", # Local-only, not exposed via NPM
|
||||
}
|
||||
|
||||
|
||||
class AuthMiddleware:
|
||||
"""ASGI middleware that validates Bearer tokens on protected endpoints."""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] != "http":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Skip auth for public paths
|
||||
if path in PUBLIC_PATHS:
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Check Authorization header
|
||||
headers = dict(scope.get("headers", []))
|
||||
auth_header = headers.get(b"authorization", b"").decode()
|
||||
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
if validate_bearer_token(token):
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Reject — send 401
|
||||
response = JSONResponse(
|
||||
{"error": "unauthorized", "error_description": "Valid Bearer token required"},
|
||||
status_code=401,
|
||||
headers={"WWW-Authenticate": 'Bearer realm="mcp"'},
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
|
||||
|
||||
async def run_telegram_bot(bot: BridgeBot):
|
||||
"""Run the telegram bot polling loop."""
|
||||
|
|
@ -87,21 +40,15 @@ async def run_telegram_bot(bot: BridgeBot):
|
|||
|
||||
|
||||
async def run_mcp_server():
|
||||
"""Run the FastMCP HTTP server with OAuth auth and custom routes."""
|
||||
import uvicorn
|
||||
|
||||
# Get the FastMCP Starlette app
|
||||
"""Run the FastMCP HTTP server with built-in OAuth."""
|
||||
# Get the FastMCP app (includes OAuth routes automatically)
|
||||
mcp_app = mcp.http_app()
|
||||
|
||||
# Add custom routes (API + OAuth)
|
||||
# Add our custom non-auth routes
|
||||
mcp_app.routes.extend(custom_routes)
|
||||
mcp_app.routes.extend(auth_routes)
|
||||
|
||||
# Wrap with auth middleware
|
||||
authed_app = AuthMiddleware(mcp_app)
|
||||
|
||||
logger.info(f"MCP server starting on {MCP_HOST}:{MCP_PORT} (OAuth enabled)")
|
||||
config = uvicorn.Config(authed_app, host=MCP_HOST, port=MCP_PORT, log_level="info")
|
||||
logger.info(f"MCP server starting on {MCP_HOST}:{MCP_PORT} (OAuth via FastMCP)")
|
||||
config = uvicorn.Config(mcp_app, host=MCP_HOST, port=MCP_PORT, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
|
|
|||
|
|
@ -131,32 +131,55 @@ async def token_endpoint(request: Request) -> JSONResponse:
|
|||
- grant_type=authorization_code (with PKCE)
|
||||
- grant_type=client_credentials (direct)
|
||||
"""
|
||||
try:
|
||||
if request.method == "GET":
|
||||
data = dict(request.query_params)
|
||||
else:
|
||||
content_type = request.headers.get("content-type", "")
|
||||
if "application/json" in content_type:
|
||||
data = await request.json()
|
||||
else:
|
||||
form = await request.form()
|
||||
data = dict(form)
|
||||
except Exception:
|
||||
return JSONResponse(
|
||||
{"error": "invalid_request", "error_description": "Could not parse request body"},
|
||||
status_code=400,
|
||||
)
|
||||
# Parse request data from any source
|
||||
data = {}
|
||||
|
||||
grant_type = data.get("grant_type")
|
||||
# Always include query params
|
||||
data.update(dict(request.query_params))
|
||||
|
||||
# Parse body for POST
|
||||
if request.method == "POST":
|
||||
try:
|
||||
raw_body = await request.body()
|
||||
logger.info(f"TOKEN POST body ({len(raw_body)} bytes): {raw_body[:500]}")
|
||||
content_type = request.headers.get("content-type", "")
|
||||
if "application/json" in content_type and raw_body:
|
||||
data.update(json.loads(raw_body))
|
||||
elif raw_body:
|
||||
from urllib.parse import parse_qs
|
||||
parsed = parse_qs(raw_body.decode(), keep_blank_values=True)
|
||||
data.update({k: v[0] if len(v) == 1 else v for k, v in parsed.items()})
|
||||
except Exception as e:
|
||||
logger.warning(f"TOKEN body parse error: {e}")
|
||||
|
||||
logger.info(f"TOKEN {request.method} data_keys={list(data.keys())} "
|
||||
f"grant_type={data.get('grant_type', '(missing)')}")
|
||||
|
||||
# Bare GET with no params = probe/ping, return 200 with metadata
|
||||
if request.method == "GET" and not data:
|
||||
return JSONResponse({
|
||||
"grant_types_supported": ["authorization_code", "client_credentials"],
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
})
|
||||
|
||||
grant_type = data.get("grant_type", "")
|
||||
|
||||
logger.debug(f"Token request: method={request.method} grant_type={grant_type} params={list(data.keys())}")
|
||||
|
||||
# If there's a code param but no grant_type, assume authorization_code
|
||||
if not grant_type and "code" in data:
|
||||
grant_type = "authorization_code"
|
||||
|
||||
if grant_type == "authorization_code":
|
||||
return await _handle_auth_code_grant(data, request)
|
||||
elif grant_type == "client_credentials":
|
||||
return await _handle_client_credentials_grant(data, request)
|
||||
else:
|
||||
logger.warning(f"Token request with unsupported grant_type={grant_type!r}, data keys={list(data.keys())}")
|
||||
return JSONResponse(
|
||||
{"error": "unsupported_grant_type",
|
||||
"error_description": "Supported: authorization_code, client_credentials"},
|
||||
"error_description": f"Supported: authorization_code, client_credentials. Got: {grant_type!r}"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
|
|
@ -248,8 +271,9 @@ async def oauth_metadata(request: Request) -> JSONResponse:
|
|||
"issuer": base,
|
||||
"authorization_endpoint": f"{base}/authorize",
|
||||
"token_endpoint": f"{base}/token",
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
|
||||
"grant_types_supported": ["authorization_code", "client_credentials"],
|
||||
"registration_endpoint": f"{base}/register",
|
||||
"token_endpoint_auth_methods_supported": ["none"],
|
||||
"grant_types_supported": ["authorization_code"],
|
||||
"response_types_supported": ["code"],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
"scopes_supported": ["mcp"],
|
||||
|
|
@ -268,9 +292,38 @@ async def protected_resource_metadata(request: Request) -> JSONResponse:
|
|||
|
||||
|
||||
# Routes to add to the app
|
||||
async def register_endpoint(request: Request) -> JSONResponse:
|
||||
"""OAuth 2.0 Dynamic Client Registration (stub).
|
||||
|
||||
Claude may try to register dynamically. Since we use pre-configured
|
||||
credentials, just return the existing client info.
|
||||
"""
|
||||
try:
|
||||
client_id, _ = _get_oauth_credentials()
|
||||
except RuntimeError:
|
||||
return JSONResponse({"error": "server_error"}, status_code=500)
|
||||
|
||||
base = str(request.base_url).rstrip("/")
|
||||
logger.info(f"REGISTER endpoint called: method={request.method}")
|
||||
if request.method == "POST":
|
||||
body = await request.body()
|
||||
logger.info(f"REGISTER body: {body[:500]}")
|
||||
|
||||
return JSONResponse({
|
||||
"client_id": client_id,
|
||||
"client_name": "claude-desktop",
|
||||
"redirect_uris": ["https://claude.ai/api/mcp/auth_callback"],
|
||||
"grant_types": ["authorization_code"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
})
|
||||
|
||||
|
||||
auth_routes = [
|
||||
Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]),
|
||||
Route("/.well-known/oauth-protected-resource", protected_resource_metadata, methods=["GET"]),
|
||||
Route("/.well-known/oauth-protected-resource/mcp", protected_resource_metadata, methods=["GET"]),
|
||||
Route("/authorize", authorize_endpoint, methods=["GET"]),
|
||||
Route("/token", token_endpoint, methods=["GET", "POST"]),
|
||||
Route("/register", register_endpoint, methods=["GET", "POST"]),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,84 +5,29 @@ import logging
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.server.auth import OAuthProvider
|
||||
from fastmcp.server.auth.auth import ClientRegistrationOptions
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from .db import Database
|
||||
from .config import get_group_chat_id
|
||||
from .config import get_group_chat_id, MCP_HOST, MCP_PORT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will be initialized in __main__ with shared db instance
|
||||
db: Database | None = None
|
||||
|
||||
|
||||
async def ingest_message(request: Request) -> JSONResponse:
|
||||
"""HTTP endpoint for local services to log messages into the bridge.
|
||||
|
||||
POST /api/ingest
|
||||
{
|
||||
"telegram_message_id": 123, # required
|
||||
"chat_id": -100..., # required
|
||||
"sender_type": "homelab_bot", # required
|
||||
"sender_id": 8521598773, # optional
|
||||
"sender_name": "Homelab Bot", # optional
|
||||
"content": "message text", # optional
|
||||
"reply_to_message_id": null, # optional
|
||||
"created_at": "ISO8601" # optional (defaults to now)
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
return JSONResponse({"error": "invalid JSON"}, status_code=400)
|
||||
|
||||
telegram_message_id = data.get("telegram_message_id")
|
||||
chat_id = data.get("chat_id")
|
||||
if not telegram_message_id or not chat_id:
|
||||
return JSONResponse(
|
||||
{"error": "telegram_message_id and chat_id are required"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
created_at = data.get("created_at", datetime.now(timezone.utc).isoformat())
|
||||
|
||||
msg_id = db.insert_message(
|
||||
telegram_message_id=telegram_message_id,
|
||||
chat_id=chat_id,
|
||||
sender_type=data.get("sender_type", "unknown"),
|
||||
sender_id=data.get("sender_id"),
|
||||
sender_name=data.get("sender_name"),
|
||||
content=data.get("content"),
|
||||
reply_to_message_id=data.get("reply_to_message_id"),
|
||||
has_attachment=data.get("has_attachment", False),
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
if msg_id is None:
|
||||
return JSONResponse({"ok": True, "duplicate": True})
|
||||
|
||||
logger.info(
|
||||
f"Ingested message {telegram_message_id} from {data.get('sender_name', 'unknown')}"
|
||||
)
|
||||
return JSONResponse({"ok": True, "id": msg_id})
|
||||
|
||||
|
||||
async def health(request: Request) -> JSONResponse:
|
||||
"""Health check endpoint."""
|
||||
status = db.get_status()
|
||||
return JSONResponse({"status": "ok", **status})
|
||||
|
||||
|
||||
# Custom routes added to the FastMCP app
|
||||
custom_routes = [
|
||||
Route("/api/ingest", ingest_message, methods=["POST"]),
|
||||
Route("/api/health", health, methods=["GET"]),
|
||||
]
|
||||
# OAuth provider with Dynamic Client Registration enabled
|
||||
oauth = OAuthProvider(
|
||||
base_url="https://mcp.georgsen.dk",
|
||||
client_registration_options=ClientRegistrationOptions(enabled=True),
|
||||
)
|
||||
|
||||
mcp = FastMCP(
|
||||
name="homelab-bridge",
|
||||
auth=oauth,
|
||||
instructions=(
|
||||
"This MCP server bridges claude.ai to a homelab Telegram group chat. "
|
||||
"Use pull_updates to read conversation history (supports cursor-based pagination). "
|
||||
|
|
@ -158,3 +103,54 @@ def queue_status() -> str:
|
|||
"""
|
||||
status = db.get_status()
|
||||
return json.dumps(status)
|
||||
|
||||
|
||||
# Custom non-MCP routes (no auth required)
|
||||
async def ingest_message(request: Request) -> JSONResponse:
|
||||
"""HTTP endpoint for local services to log messages into the bridge."""
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
return JSONResponse({"error": "invalid JSON"}, status_code=400)
|
||||
|
||||
telegram_message_id = data.get("telegram_message_id")
|
||||
chat_id = data.get("chat_id")
|
||||
if not telegram_message_id or not chat_id:
|
||||
return JSONResponse(
|
||||
{"error": "telegram_message_id and chat_id are required"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
created_at = data.get("created_at", datetime.now(timezone.utc).isoformat())
|
||||
|
||||
msg_id = db.insert_message(
|
||||
telegram_message_id=telegram_message_id,
|
||||
chat_id=chat_id,
|
||||
sender_type=data.get("sender_type", "unknown"),
|
||||
sender_id=data.get("sender_id"),
|
||||
sender_name=data.get("sender_name"),
|
||||
content=data.get("content"),
|
||||
reply_to_message_id=data.get("reply_to_message_id"),
|
||||
has_attachment=data.get("has_attachment", False),
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
if msg_id is None:
|
||||
return JSONResponse({"ok": True, "duplicate": True})
|
||||
|
||||
logger.info(
|
||||
f"Ingested message {telegram_message_id} from {data.get('sender_name', 'unknown')}"
|
||||
)
|
||||
return JSONResponse({"ok": True, "id": msg_id})
|
||||
|
||||
|
||||
async def health(request: Request) -> JSONResponse:
|
||||
"""Health check endpoint."""
|
||||
status = db.get_status()
|
||||
return JSONResponse({"status": "ok", **status})
|
||||
|
||||
|
||||
custom_routes = [
|
||||
Route("/api/ingest", ingest_message, methods=["POST"]),
|
||||
Route("/api/health", health, methods=["GET"]),
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue