feat: replace custom OAuth with FastMCP built-in OAuthProvider
FastMCP's OAuthProvider handles the full OAuth flow including DCR (Dynamic Client Registration), authorization code + PKCE, token issuance, and refresh tokens. No more custom auth code. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
a21dd3ebbb
commit
a71595b9d8
3 changed files with 138 additions and 142 deletions
|
|
@ -4,15 +4,12 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import signal
|
import signal
|
||||||
|
|
||||||
from starlette.middleware import Middleware
|
import uvicorn
|
||||||
from starlette.requests import Request
|
|
||||||
from starlette.responses import JSONResponse, Response
|
|
||||||
|
|
||||||
from .config import MCP_HOST, MCP_PORT, MEDIA_DIR, DB_PATH
|
from .config import MCP_HOST, MCP_PORT, MEDIA_DIR, DB_PATH
|
||||||
from .db import Database
|
from .db import Database
|
||||||
from .telegram_bot import BridgeBot
|
from .telegram_bot import BridgeBot
|
||||||
from .mcp_server import mcp, init as init_mcp, custom_routes
|
from .mcp_server import mcp, init as init_mcp, custom_routes
|
||||||
from .auth import auth_routes, validate_bearer_token
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
|
|
@ -21,50 +18,6 @@ logging.basicConfig(
|
||||||
)
|
)
|
||||||
logger = logging.getLogger("mcp_bridge")
|
logger = logging.getLogger("mcp_bridge")
|
||||||
|
|
||||||
# Paths that don't require auth
|
|
||||||
PUBLIC_PATHS = {
|
|
||||||
"/.well-known/oauth-authorization-server",
|
|
||||||
"/.well-known/oauth-protected-resource",
|
|
||||||
"/authorize",
|
|
||||||
"/token",
|
|
||||||
"/api/health",
|
|
||||||
"/api/ingest", # Local-only, not exposed via NPM
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AuthMiddleware:
|
|
||||||
"""ASGI middleware that validates Bearer tokens on protected endpoints."""
|
|
||||||
|
|
||||||
def __init__(self, app):
|
|
||||||
self.app = app
|
|
||||||
|
|
||||||
async def __call__(self, scope, receive, send):
|
|
||||||
if scope["type"] != "http":
|
|
||||||
return await self.app(scope, receive, send)
|
|
||||||
|
|
||||||
path = scope.get("path", "")
|
|
||||||
|
|
||||||
# Skip auth for public paths
|
|
||||||
if path in PUBLIC_PATHS:
|
|
||||||
return await self.app(scope, receive, send)
|
|
||||||
|
|
||||||
# Check Authorization header
|
|
||||||
headers = dict(scope.get("headers", []))
|
|
||||||
auth_header = headers.get(b"authorization", b"").decode()
|
|
||||||
|
|
||||||
if auth_header.startswith("Bearer "):
|
|
||||||
token = auth_header[7:]
|
|
||||||
if validate_bearer_token(token):
|
|
||||||
return await self.app(scope, receive, send)
|
|
||||||
|
|
||||||
# Reject — send 401
|
|
||||||
response = JSONResponse(
|
|
||||||
{"error": "unauthorized", "error_description": "Valid Bearer token required"},
|
|
||||||
status_code=401,
|
|
||||||
headers={"WWW-Authenticate": 'Bearer realm="mcp"'},
|
|
||||||
)
|
|
||||||
await response(scope, receive, send)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_telegram_bot(bot: BridgeBot):
|
async def run_telegram_bot(bot: BridgeBot):
|
||||||
"""Run the telegram bot polling loop."""
|
"""Run the telegram bot polling loop."""
|
||||||
|
|
@ -87,21 +40,15 @@ async def run_telegram_bot(bot: BridgeBot):
|
||||||
|
|
||||||
|
|
||||||
async def run_mcp_server():
|
async def run_mcp_server():
|
||||||
"""Run the FastMCP HTTP server with OAuth auth and custom routes."""
|
"""Run the FastMCP HTTP server with built-in OAuth."""
|
||||||
import uvicorn
|
# Get the FastMCP app (includes OAuth routes automatically)
|
||||||
|
|
||||||
# Get the FastMCP Starlette app
|
|
||||||
mcp_app = mcp.http_app()
|
mcp_app = mcp.http_app()
|
||||||
|
|
||||||
# Add custom routes (API + OAuth)
|
# Add our custom non-auth routes
|
||||||
mcp_app.routes.extend(custom_routes)
|
mcp_app.routes.extend(custom_routes)
|
||||||
mcp_app.routes.extend(auth_routes)
|
|
||||||
|
|
||||||
# Wrap with auth middleware
|
logger.info(f"MCP server starting on {MCP_HOST}:{MCP_PORT} (OAuth via FastMCP)")
|
||||||
authed_app = AuthMiddleware(mcp_app)
|
config = uvicorn.Config(mcp_app, host=MCP_HOST, port=MCP_PORT, log_level="info")
|
||||||
|
|
||||||
logger.info(f"MCP server starting on {MCP_HOST}:{MCP_PORT} (OAuth enabled)")
|
|
||||||
config = uvicorn.Config(authed_app, host=MCP_HOST, port=MCP_PORT, log_level="info")
|
|
||||||
server = uvicorn.Server(config)
|
server = uvicorn.Server(config)
|
||||||
await server.serve()
|
await server.serve()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -131,32 +131,55 @@ async def token_endpoint(request: Request) -> JSONResponse:
|
||||||
- grant_type=authorization_code (with PKCE)
|
- grant_type=authorization_code (with PKCE)
|
||||||
- grant_type=client_credentials (direct)
|
- grant_type=client_credentials (direct)
|
||||||
"""
|
"""
|
||||||
try:
|
# Parse request data from any source
|
||||||
if request.method == "GET":
|
data = {}
|
||||||
data = dict(request.query_params)
|
|
||||||
else:
|
|
||||||
content_type = request.headers.get("content-type", "")
|
|
||||||
if "application/json" in content_type:
|
|
||||||
data = await request.json()
|
|
||||||
else:
|
|
||||||
form = await request.form()
|
|
||||||
data = dict(form)
|
|
||||||
except Exception:
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": "invalid_request", "error_description": "Could not parse request body"},
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
|
|
||||||
grant_type = data.get("grant_type")
|
# Always include query params
|
||||||
|
data.update(dict(request.query_params))
|
||||||
|
|
||||||
|
# Parse body for POST
|
||||||
|
if request.method == "POST":
|
||||||
|
try:
|
||||||
|
raw_body = await request.body()
|
||||||
|
logger.info(f"TOKEN POST body ({len(raw_body)} bytes): {raw_body[:500]}")
|
||||||
|
content_type = request.headers.get("content-type", "")
|
||||||
|
if "application/json" in content_type and raw_body:
|
||||||
|
data.update(json.loads(raw_body))
|
||||||
|
elif raw_body:
|
||||||
|
from urllib.parse import parse_qs
|
||||||
|
parsed = parse_qs(raw_body.decode(), keep_blank_values=True)
|
||||||
|
data.update({k: v[0] if len(v) == 1 else v for k, v in parsed.items()})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"TOKEN body parse error: {e}")
|
||||||
|
|
||||||
|
logger.info(f"TOKEN {request.method} data_keys={list(data.keys())} "
|
||||||
|
f"grant_type={data.get('grant_type', '(missing)')}")
|
||||||
|
|
||||||
|
# Bare GET with no params = probe/ping, return 200 with metadata
|
||||||
|
if request.method == "GET" and not data:
|
||||||
|
return JSONResponse({
|
||||||
|
"grant_types_supported": ["authorization_code", "client_credentials"],
|
||||||
|
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
|
||||||
|
"code_challenge_methods_supported": ["S256"],
|
||||||
|
})
|
||||||
|
|
||||||
|
grant_type = data.get("grant_type", "")
|
||||||
|
|
||||||
|
logger.debug(f"Token request: method={request.method} grant_type={grant_type} params={list(data.keys())}")
|
||||||
|
|
||||||
|
# If there's a code param but no grant_type, assume authorization_code
|
||||||
|
if not grant_type and "code" in data:
|
||||||
|
grant_type = "authorization_code"
|
||||||
|
|
||||||
if grant_type == "authorization_code":
|
if grant_type == "authorization_code":
|
||||||
return await _handle_auth_code_grant(data, request)
|
return await _handle_auth_code_grant(data, request)
|
||||||
elif grant_type == "client_credentials":
|
elif grant_type == "client_credentials":
|
||||||
return await _handle_client_credentials_grant(data, request)
|
return await _handle_client_credentials_grant(data, request)
|
||||||
else:
|
else:
|
||||||
|
logger.warning(f"Token request with unsupported grant_type={grant_type!r}, data keys={list(data.keys())}")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{"error": "unsupported_grant_type",
|
{"error": "unsupported_grant_type",
|
||||||
"error_description": "Supported: authorization_code, client_credentials"},
|
"error_description": f"Supported: authorization_code, client_credentials. Got: {grant_type!r}"},
|
||||||
status_code=400,
|
status_code=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -248,8 +271,9 @@ async def oauth_metadata(request: Request) -> JSONResponse:
|
||||||
"issuer": base,
|
"issuer": base,
|
||||||
"authorization_endpoint": f"{base}/authorize",
|
"authorization_endpoint": f"{base}/authorize",
|
||||||
"token_endpoint": f"{base}/token",
|
"token_endpoint": f"{base}/token",
|
||||||
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
|
"registration_endpoint": f"{base}/register",
|
||||||
"grant_types_supported": ["authorization_code", "client_credentials"],
|
"token_endpoint_auth_methods_supported": ["none"],
|
||||||
|
"grant_types_supported": ["authorization_code"],
|
||||||
"response_types_supported": ["code"],
|
"response_types_supported": ["code"],
|
||||||
"code_challenge_methods_supported": ["S256"],
|
"code_challenge_methods_supported": ["S256"],
|
||||||
"scopes_supported": ["mcp"],
|
"scopes_supported": ["mcp"],
|
||||||
|
|
@ -268,9 +292,38 @@ async def protected_resource_metadata(request: Request) -> JSONResponse:
|
||||||
|
|
||||||
|
|
||||||
# Routes to add to the app
|
# Routes to add to the app
|
||||||
|
async def register_endpoint(request: Request) -> JSONResponse:
|
||||||
|
"""OAuth 2.0 Dynamic Client Registration (stub).
|
||||||
|
|
||||||
|
Claude may try to register dynamically. Since we use pre-configured
|
||||||
|
credentials, just return the existing client info.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client_id, _ = _get_oauth_credentials()
|
||||||
|
except RuntimeError:
|
||||||
|
return JSONResponse({"error": "server_error"}, status_code=500)
|
||||||
|
|
||||||
|
base = str(request.base_url).rstrip("/")
|
||||||
|
logger.info(f"REGISTER endpoint called: method={request.method}")
|
||||||
|
if request.method == "POST":
|
||||||
|
body = await request.body()
|
||||||
|
logger.info(f"REGISTER body: {body[:500]}")
|
||||||
|
|
||||||
|
return JSONResponse({
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_name": "claude-desktop",
|
||||||
|
"redirect_uris": ["https://claude.ai/api/mcp/auth_callback"],
|
||||||
|
"grant_types": ["authorization_code"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
"token_endpoint_auth_method": "none",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
auth_routes = [
|
auth_routes = [
|
||||||
Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]),
|
Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]),
|
||||||
Route("/.well-known/oauth-protected-resource", protected_resource_metadata, methods=["GET"]),
|
Route("/.well-known/oauth-protected-resource", protected_resource_metadata, methods=["GET"]),
|
||||||
|
Route("/.well-known/oauth-protected-resource/mcp", protected_resource_metadata, methods=["GET"]),
|
||||||
Route("/authorize", authorize_endpoint, methods=["GET"]),
|
Route("/authorize", authorize_endpoint, methods=["GET"]),
|
||||||
Route("/token", token_endpoint, methods=["GET", "POST"]),
|
Route("/token", token_endpoint, methods=["GET", "POST"]),
|
||||||
|
Route("/register", register_endpoint, methods=["GET", "POST"]),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -5,84 +5,29 @@ import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
from fastmcp.server.auth import OAuthProvider
|
||||||
|
from fastmcp.server.auth.auth import ClientRegistrationOptions
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.routing import Route
|
from starlette.routing import Route
|
||||||
|
|
||||||
from .db import Database
|
from .db import Database
|
||||||
from .config import get_group_chat_id
|
from .config import get_group_chat_id, MCP_HOST, MCP_PORT
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Will be initialized in __main__ with shared db instance
|
# Will be initialized in __main__ with shared db instance
|
||||||
db: Database | None = None
|
db: Database | None = None
|
||||||
|
|
||||||
|
# OAuth provider with Dynamic Client Registration enabled
|
||||||
async def ingest_message(request: Request) -> JSONResponse:
|
oauth = OAuthProvider(
|
||||||
"""HTTP endpoint for local services to log messages into the bridge.
|
base_url="https://mcp.georgsen.dk",
|
||||||
|
client_registration_options=ClientRegistrationOptions(enabled=True),
|
||||||
POST /api/ingest
|
)
|
||||||
{
|
|
||||||
"telegram_message_id": 123, # required
|
|
||||||
"chat_id": -100..., # required
|
|
||||||
"sender_type": "homelab_bot", # required
|
|
||||||
"sender_id": 8521598773, # optional
|
|
||||||
"sender_name": "Homelab Bot", # optional
|
|
||||||
"content": "message text", # optional
|
|
||||||
"reply_to_message_id": null, # optional
|
|
||||||
"created_at": "ISO8601" # optional (defaults to now)
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
data = await request.json()
|
|
||||||
except Exception:
|
|
||||||
return JSONResponse({"error": "invalid JSON"}, status_code=400)
|
|
||||||
|
|
||||||
telegram_message_id = data.get("telegram_message_id")
|
|
||||||
chat_id = data.get("chat_id")
|
|
||||||
if not telegram_message_id or not chat_id:
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": "telegram_message_id and chat_id are required"},
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
|
|
||||||
created_at = data.get("created_at", datetime.now(timezone.utc).isoformat())
|
|
||||||
|
|
||||||
msg_id = db.insert_message(
|
|
||||||
telegram_message_id=telegram_message_id,
|
|
||||||
chat_id=chat_id,
|
|
||||||
sender_type=data.get("sender_type", "unknown"),
|
|
||||||
sender_id=data.get("sender_id"),
|
|
||||||
sender_name=data.get("sender_name"),
|
|
||||||
content=data.get("content"),
|
|
||||||
reply_to_message_id=data.get("reply_to_message_id"),
|
|
||||||
has_attachment=data.get("has_attachment", False),
|
|
||||||
created_at=created_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
if msg_id is None:
|
|
||||||
return JSONResponse({"ok": True, "duplicate": True})
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Ingested message {telegram_message_id} from {data.get('sender_name', 'unknown')}"
|
|
||||||
)
|
|
||||||
return JSONResponse({"ok": True, "id": msg_id})
|
|
||||||
|
|
||||||
|
|
||||||
async def health(request: Request) -> JSONResponse:
|
|
||||||
"""Health check endpoint."""
|
|
||||||
status = db.get_status()
|
|
||||||
return JSONResponse({"status": "ok", **status})
|
|
||||||
|
|
||||||
|
|
||||||
# Custom routes added to the FastMCP app
|
|
||||||
custom_routes = [
|
|
||||||
Route("/api/ingest", ingest_message, methods=["POST"]),
|
|
||||||
Route("/api/health", health, methods=["GET"]),
|
|
||||||
]
|
|
||||||
|
|
||||||
mcp = FastMCP(
|
mcp = FastMCP(
|
||||||
name="homelab-bridge",
|
name="homelab-bridge",
|
||||||
|
auth=oauth,
|
||||||
instructions=(
|
instructions=(
|
||||||
"This MCP server bridges claude.ai to a homelab Telegram group chat. "
|
"This MCP server bridges claude.ai to a homelab Telegram group chat. "
|
||||||
"Use pull_updates to read conversation history (supports cursor-based pagination). "
|
"Use pull_updates to read conversation history (supports cursor-based pagination). "
|
||||||
|
|
@ -158,3 +103,54 @@ def queue_status() -> str:
|
||||||
"""
|
"""
|
||||||
status = db.get_status()
|
status = db.get_status()
|
||||||
return json.dumps(status)
|
return json.dumps(status)
|
||||||
|
|
||||||
|
|
||||||
|
# Custom non-MCP routes (no auth required)
|
||||||
|
async def ingest_message(request: Request) -> JSONResponse:
|
||||||
|
"""HTTP endpoint for local services to log messages into the bridge."""
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
except Exception:
|
||||||
|
return JSONResponse({"error": "invalid JSON"}, status_code=400)
|
||||||
|
|
||||||
|
telegram_message_id = data.get("telegram_message_id")
|
||||||
|
chat_id = data.get("chat_id")
|
||||||
|
if not telegram_message_id or not chat_id:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "telegram_message_id and chat_id are required"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
created_at = data.get("created_at", datetime.now(timezone.utc).isoformat())
|
||||||
|
|
||||||
|
msg_id = db.insert_message(
|
||||||
|
telegram_message_id=telegram_message_id,
|
||||||
|
chat_id=chat_id,
|
||||||
|
sender_type=data.get("sender_type", "unknown"),
|
||||||
|
sender_id=data.get("sender_id"),
|
||||||
|
sender_name=data.get("sender_name"),
|
||||||
|
content=data.get("content"),
|
||||||
|
reply_to_message_id=data.get("reply_to_message_id"),
|
||||||
|
has_attachment=data.get("has_attachment", False),
|
||||||
|
created_at=created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
if msg_id is None:
|
||||||
|
return JSONResponse({"ok": True, "duplicate": True})
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Ingested message {telegram_message_id} from {data.get('sender_name', 'unknown')}"
|
||||||
|
)
|
||||||
|
return JSONResponse({"ok": True, "id": msg_id})
|
||||||
|
|
||||||
|
|
||||||
|
async def health(request: Request) -> JSONResponse:
|
||||||
|
"""Health check endpoint."""
|
||||||
|
status = db.get_status()
|
||||||
|
return JSONResponse({"status": "ok", **status})
|
||||||
|
|
||||||
|
|
||||||
|
custom_routes = [
|
||||||
|
Route("/api/ingest", ingest_message, methods=["POST"]),
|
||||||
|
Route("/api/health", health, methods=["GET"]),
|
||||||
|
]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue