feat: add OAuth client credentials auth to MCP server
- OAuth 2.0 discovery at /.well-known/oauth-authorization-server - Token endpoint at /token (client_credentials grant) - Bearer token middleware on /mcp (all MCP requests require auth) - Health, ingest, and OAuth endpoints remain public - Tokens expire after 1 hour, stored hashed in memory Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
494bb510d3
commit
205b978b89
3 changed files with 196 additions and 6 deletions
|
|
@ -11,3 +11,7 @@ GROUP_CHAT_ID=
|
|||
# (Optional) Bot ID of the existing homelab bot, for sender classification
|
||||
# Find it: https://api.telegram.org/bot<HOMELAB_TOKEN>/getMe
|
||||
HOMELAB_BOT_ID=8521598773
|
||||
|
||||
# OAuth client credentials for MCP auth (generate with: python3 -c "import secrets; print(secrets.token_urlsafe(32))")
|
||||
OAUTH_CLIENT_ID=
|
||||
OAUTH_CLIENT_SECRET=
|
||||
|
|
|
|||
|
|
@ -4,10 +4,15 @@ import asyncio
|
|||
import logging
|
||||
import signal
|
||||
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from .config import MCP_HOST, MCP_PORT, MEDIA_DIR, DB_PATH
|
||||
from .db import Database
|
||||
from .telegram_bot import BridgeBot
|
||||
from .mcp_server import mcp, init as init_mcp, custom_routes
|
||||
from .auth import auth_routes, validate_bearer_token
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
|
@ -16,12 +21,53 @@ logging.basicConfig(
|
|||
)
|
||||
logger = logging.getLogger("mcp_bridge")
|
||||
|
||||
# Paths that don't require auth
|
||||
PUBLIC_PATHS = {
|
||||
"/.well-known/oauth-authorization-server",
|
||||
"/token",
|
||||
"/api/health",
|
||||
"/api/ingest", # Local-only, not exposed via NPM
|
||||
}
|
||||
|
||||
|
||||
class AuthMiddleware:
|
||||
"""ASGI middleware that validates Bearer tokens on protected endpoints."""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] != "http":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Skip auth for public paths
|
||||
if path in PUBLIC_PATHS:
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Check Authorization header
|
||||
headers = dict(scope.get("headers", []))
|
||||
auth_header = headers.get(b"authorization", b"").decode()
|
||||
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
if validate_bearer_token(token):
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Reject — send 401
|
||||
response = JSONResponse(
|
||||
{"error": "unauthorized", "error_description": "Valid Bearer token required"},
|
||||
status_code=401,
|
||||
headers={"WWW-Authenticate": 'Bearer realm="mcp"'},
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
|
||||
|
||||
async def run_telegram_bot(bot: BridgeBot):
|
||||
"""Run the telegram bot polling loop."""
|
||||
app = bot.build_application()
|
||||
await app.initialize()
|
||||
# post_init isn't called automatically when we manually start
|
||||
await bot._post_init(app)
|
||||
await app.start()
|
||||
updater = app.updater
|
||||
|
|
@ -39,15 +85,21 @@ async def run_telegram_bot(bot: BridgeBot):
|
|||
|
||||
|
||||
async def run_mcp_server():
|
||||
"""Run the FastMCP HTTP server with custom API routes."""
|
||||
"""Run the FastMCP HTTP server with OAuth auth and custom routes."""
|
||||
import uvicorn
|
||||
|
||||
# Get the FastMCP Starlette app and add our custom routes to it
|
||||
# Get the FastMCP Starlette app
|
||||
mcp_app = mcp.http_app()
|
||||
mcp_app.routes.extend(custom_routes)
|
||||
|
||||
logger.info(f"MCP server starting on {MCP_HOST}:{MCP_PORT}")
|
||||
config = uvicorn.Config(mcp_app, host=MCP_HOST, port=MCP_PORT, log_level="info")
|
||||
# Add custom routes (API + OAuth)
|
||||
mcp_app.routes.extend(custom_routes)
|
||||
mcp_app.routes.extend(auth_routes)
|
||||
|
||||
# Wrap with auth middleware
|
||||
authed_app = AuthMiddleware(mcp_app)
|
||||
|
||||
logger.info(f"MCP server starting on {MCP_HOST}:{MCP_PORT} (OAuth enabled)")
|
||||
config = uvicorn.Config(authed_app, host=MCP_HOST, port=MCP_PORT, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
|
|
|||
134
mcp_bridge/auth.py
Normal file
134
mcp_bridge/auth.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""OAuth 2.0 client credentials auth for MCP server."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from .config import load_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# In-memory token store: token_hash -> expiry timestamp
|
||||
_active_tokens: dict[str, float] = {}
|
||||
TOKEN_LIFETIME = 3600 # 1 hour
|
||||
|
||||
|
||||
def _get_oauth_credentials() -> tuple[str, str]:
|
||||
"""Load OAuth client_id and client_secret from credentials file."""
|
||||
creds = load_credentials()
|
||||
client_id = creds.get("OAUTH_CLIENT_ID", "")
|
||||
client_secret = creds.get("OAUTH_CLIENT_SECRET", "")
|
||||
if not client_id or not client_secret:
|
||||
raise RuntimeError("OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET must be set in credentials")
|
||||
return client_id, client_secret
|
||||
|
||||
|
||||
def _hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def _cleanup_expired():
|
||||
"""Remove expired tokens."""
|
||||
now = time.time()
|
||||
expired = [h for h, exp in _active_tokens.items() if exp < now]
|
||||
for h in expired:
|
||||
del _active_tokens[h]
|
||||
|
||||
|
||||
def validate_bearer_token(token: str) -> bool:
|
||||
"""Check if a bearer token is valid and not expired."""
|
||||
_cleanup_expired()
|
||||
token_hash = _hash_token(token)
|
||||
return token_hash in _active_tokens and _active_tokens[token_hash] > time.time()
|
||||
|
||||
|
||||
async def token_endpoint(request: Request) -> JSONResponse:
|
||||
"""OAuth 2.0 token endpoint (client_credentials grant).
|
||||
|
||||
POST /token
|
||||
Content-Type: application/x-www-form-urlencoded
|
||||
|
||||
grant_type=client_credentials&client_id=...&client_secret=...
|
||||
"""
|
||||
try:
|
||||
# Accept both form-encoded and JSON
|
||||
content_type = request.headers.get("content-type", "")
|
||||
if "application/json" in content_type:
|
||||
data = await request.json()
|
||||
else:
|
||||
form = await request.form()
|
||||
data = dict(form)
|
||||
except Exception:
|
||||
return JSONResponse(
|
||||
{"error": "invalid_request", "error_description": "Could not parse request body"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
grant_type = data.get("grant_type")
|
||||
if grant_type != "client_credentials":
|
||||
return JSONResponse(
|
||||
{"error": "unsupported_grant_type", "error_description": "Only client_credentials is supported"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
client_id = data.get("client_id", "")
|
||||
client_secret = data.get("client_secret", "")
|
||||
|
||||
try:
|
||||
expected_id, expected_secret = _get_oauth_credentials()
|
||||
except RuntimeError:
|
||||
logger.error("OAuth credentials not configured")
|
||||
return JSONResponse(
|
||||
{"error": "server_error", "error_description": "Auth not configured"},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(client_id, expected_id) or \
|
||||
not secrets.compare_digest(client_secret, expected_secret):
|
||||
logger.warning(f"OAuth auth failed from {request.client.host}")
|
||||
return JSONResponse(
|
||||
{"error": "invalid_client", "error_description": "Invalid client credentials"},
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
# Issue token
|
||||
access_token = secrets.token_urlsafe(48)
|
||||
_active_tokens[_hash_token(access_token)] = time.time() + TOKEN_LIFETIME
|
||||
_cleanup_expired()
|
||||
|
||||
logger.info(f"OAuth token issued to {request.client.host}")
|
||||
return JSONResponse({
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": TOKEN_LIFETIME,
|
||||
})
|
||||
|
||||
|
||||
async def oauth_metadata(request: Request) -> JSONResponse:
|
||||
"""OAuth 2.0 Authorization Server Metadata (RFC 8414).
|
||||
|
||||
GET /.well-known/oauth-authorization-server
|
||||
"""
|
||||
# Build base URL from request
|
||||
base = str(request.base_url).rstrip("/")
|
||||
return JSONResponse({
|
||||
"issuer": base,
|
||||
"token_endpoint": f"{base}/token",
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_post"],
|
||||
"grant_types_supported": ["client_credentials"],
|
||||
"response_types_supported": [],
|
||||
"scopes_supported": ["mcp"],
|
||||
})
|
||||
|
||||
|
||||
# Routes to add to the app
|
||||
auth_routes = [
|
||||
Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]),
|
||||
Route("/token", token_endpoint, methods=["POST"]),
|
||||
]
|
||||
Loading…
Add table
Reference in a new issue