OAuthProvider is abstract — subclassed as HomelabOAuth with full implementation of register_client, get_client, create/exchange authorization codes, token issuance, PKCE verification, and refresh token rotation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
289 lines
9.8 KiB
Python
289 lines
9.8 KiB
Python
"""FastMCP server exposing bridge tools to claude.ai."""
|
|
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
|
|
from fastmcp import FastMCP
|
|
import secrets
|
|
import time
|
|
import hashlib
|
|
|
|
from fastmcp.server.auth import OAuthProvider, AccessToken
|
|
from fastmcp.server.auth.auth import ClientRegistrationOptions
|
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse
|
|
from starlette.routing import Route
|
|
|
|
from .db import Database
|
|
from .config import get_group_chat_id, MCP_HOST, MCP_PORT
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Will be initialized in __main__ with shared db instance
# (init() in this module performs the assignment). It stays None until then,
# so tools and routes that use `db` must not run before init() is called.
db: Database | None = None

# Lifetime of an issued access token, in seconds. Used both for the stored
# expiry timestamp and for the `expires_in` value returned to the client.
TOKEN_LIFETIME = 3600
|
|
|
|
|
|
class HomelabOAuth(OAuthProvider):
    """Concrete OAuth provider with in-memory storage.

    Registered clients, authorization codes, access tokens and refresh tokens
    are kept in plain dictionaries on the instance, so nothing survives a
    process restart. Access and refresh tokens are stored only under their
    SHA-256 hex digest; the raw token value is returned to the client once
    and never kept.

    Access tokens expire ``TOKEN_LIFETIME`` seconds after they are issued.
    Refresh tokens carry no expiry of their own but are single-use: each
    successful refresh replaces the presented token with a new one.
    """

    # Authorization codes are single-use and expire this many seconds after
    # they are issued.
    _AUTH_CODE_LIFETIME = 300

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``OAuthProvider`` and create the empty stores."""
        super().__init__(**kwargs)
        self._clients: dict[str, OAuthClientInformationFull] = {}
        # code -> {client_id, code_challenge, redirect_uri, scopes, expires}
        self._auth_codes: dict[str, dict] = {}
        # token_hash -> {client_id, scopes, expires}
        self._tokens: dict[str, dict] = {}
        # refresh_hash -> {client_id, scopes}
        self._refresh_tokens: dict[str, dict] = {}

    @staticmethod
    def _hash_token(token: str) -> str:
        """Return the SHA-256 hex digest used as the storage key for ``token``."""
        return hashlib.sha256(token.encode()).hexdigest()

    @staticmethod
    def _pkce_matches(code_verifier: str, code_challenge: str) -> bool:
        """Return True if ``code_verifier`` hashes to ``code_challenge`` (S256).

        The challenge is expected to be the unpadded URL-safe base64 encoding
        of SHA-256(verifier). A non-ASCII verifier raises ``UnicodeEncodeError``
        (a ``ValueError`` subclass).
        """
        import base64

        digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
        computed = base64.urlsafe_b64encode(digest).rstrip(b"=")
        # Compare as bytes in constant time; bytes also avoids the TypeError
        # compare_digest raises for non-ASCII str input.
        return secrets.compare_digest(computed, str(code_challenge).encode("utf-8"))

    def _issue_token_pair(self, client_id: str, scopes: list[str]) -> OAuthToken:
        """Mint a fresh access/refresh token pair, store both, and return them."""
        access_token = secrets.token_urlsafe(48)
        refresh_token = secrets.token_urlsafe(48)

        self._tokens[self._hash_token(access_token)] = {
            "client_id": client_id,
            "scopes": scopes,
            "expires": time.time() + TOKEN_LIFETIME,
        }
        self._refresh_tokens[self._hash_token(refresh_token)] = {
            "client_id": client_id,
            "scopes": scopes,
        }

        return OAuthToken(
            access_token=access_token,
            token_type="Bearer",
            expires_in=TOKEN_LIFETIME,
            refresh_token=refresh_token,
            scope=" ".join(scopes) if scopes else None,
        )

    async def register_client(self, client_info: OAuthClientInformationFull) -> None:
        """Store ``client_info`` under its ``client_id``, replacing any previous entry."""
        self._clients[client_info.client_id] = client_info
        logger.info(
            "Registered OAuth client: %s (%s)",
            client_info.client_id,
            client_info.client_name,
        )

    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        """Return the registered client for ``client_id``, or None if unknown."""
        return self._clients.get(client_id)

    async def create_authorization_code(
        self, client: OAuthClientInformationFull, code_challenge: str | None,
        redirect_uri: str | None, scopes: list[str] | None = None,
        **kwargs,
    ) -> str:
        """Issue a single-use authorization code for ``client``.

        Args:
            client: The client the code is bound to.
            code_challenge: PKCE challenge to verify at exchange time, if any.
            redirect_uri: Redirect URI for this request; when falsy, the
                client's first registered redirect URI is recorded instead.
            scopes: Scopes to attach to tokens issued from this code.
            **kwargs: Accepted and ignored.

        Returns:
            The newly generated authorization code.
        """
        code = secrets.token_urlsafe(32)
        self._auth_codes[code] = {
            "client_id": client.client_id,
            "code_challenge": code_challenge,
            "redirect_uri": redirect_uri or str(client.redirect_uris[0]),
            "scopes": scopes or [],
            "expires": time.time() + self._AUTH_CODE_LIFETIME,
        }
        logger.info("Auth code issued for client %s", client.client_id)
        return code

    async def exchange_authorization_code(
        self, client: OAuthClientInformationFull, code: str,
        code_verifier: str | None = None, redirect_uri: str | None = None,
        **kwargs,
    ) -> OAuthToken:
        """Exchange an authorization code for an access/refresh token pair.

        The code is consumed as soon as it is looked up, so it cannot be
        replayed even when a later check fails.

        Args:
            client: The client presenting the code.
            code: The authorization code to redeem.
            code_verifier: PKCE verifier; required whenever the code was
                issued with a ``code_challenge``.
            redirect_uri: Accepted but not checked here.
            **kwargs: Accepted and ignored.

        Returns:
            The issued token pair.

        Raises:
            ValueError: If the code is unknown, expired, bound to a different
                client, or fails PKCE verification.
        """
        # NOTE(review): redirect_uri is not compared with the value stored
        # alongside the code. RFC 6749 section 4.1.3 requires a match when one
        # was sent in the authorization request -- confirm whether the
        # framework enforces this before the provider is called.
        code_data = self._auth_codes.pop(code, None)
        if code_data is None:
            raise ValueError("Invalid authorization code")

        if code_data["expires"] < time.time():
            raise ValueError("Authorization code expired")

        if code_data["client_id"] != client.client_id:
            raise ValueError("Client ID mismatch")

        # PKCE verification: a stored challenge makes the verifier mandatory.
        # Skipping the check when the verifier is simply omitted would let an
        # intercepted code be redeemed without it.
        code_challenge = code_data["code_challenge"]
        if code_challenge:
            if not code_verifier or not self._pkce_matches(code_verifier, code_challenge):
                raise ValueError("PKCE verification failed")

        token = self._issue_token_pair(client.client_id, code_data["scopes"])
        logger.info("Token issued for client %s", client.client_id)
        return token

    async def load_access_token(self, token: str) -> AccessToken | None:
        """Return the ``AccessToken`` for ``token``, or None if unknown or expired."""
        token_hash = self._hash_token(token)
        data = self._tokens.get(token_hash)
        if not data:
            return None
        if data["expires"] < time.time():
            # Evict on discovery so expired entries do not pile up forever.
            self._tokens.pop(token_hash, None)
            return None
        return AccessToken(
            token=token,
            client_id=data["client_id"],
            scopes=data["scopes"],
        )

    async def exchange_refresh_token(
        self, client: OAuthClientInformationFull,
        refresh_token: str, scopes: list[str] | None = None,
        **kwargs,
    ) -> OAuthToken:
        """Rotate a refresh token and issue a new access/refresh token pair.

        Args:
            client: The client presenting the refresh token.
            refresh_token: The refresh token to redeem; it is invalidated on
                success.
            scopes: Optional narrower scope set for the new tokens. When
                omitted or empty, the scopes of the original grant are reused.
            **kwargs: Accepted and ignored.

        Returns:
            The newly issued token pair.

        Raises:
            ValueError: If the refresh token is unknown, belongs to another
                client, or ``scopes`` asks for anything outside the original
                grant.
        """
        refresh_hash = self._hash_token(refresh_token)
        data = self._refresh_tokens.get(refresh_hash)
        if not data or data["client_id"] != client.client_id:
            raise ValueError("Invalid refresh token")

        granted = data["scopes"]
        if scopes:
            # A refresh may narrow the scope set but must never widen it.
            if not set(scopes) <= set(granted):
                raise ValueError("Requested scope exceeds the original grant")
            new_scopes = list(scopes)
        else:
            new_scopes = granted

        # Rotate refresh token: the presented one is single-use.
        del self._refresh_tokens[refresh_hash]
        return self._issue_token_pair(client.client_id, new_scopes)
|
|
|
|
|
|
# OAuth provider instance shared by the MCP server below. It is built at
# import time with a hard-coded base URL.
# NOTE(review): enabled=True presumably turns on dynamic client
# registration -- confirm against the FastMCP documentation.
oauth = HomelabOAuth(
    base_url="https://mcp.georgsen.dk",
    client_registration_options=ClientRegistrationOptions(enabled=True),
)

# The FastMCP server object. The @mcp.tool() functions in this module register
# themselves on it, and `auth=oauth` attaches the provider defined above.
mcp = FastMCP(
    name="homelab-bridge",
    auth=oauth,
    instructions=(
        "This MCP server bridges claude.ai to a homelab Telegram group chat. "
        "Use pull_updates to read conversation history (supports cursor-based pagination). "
        "Use send_message to post messages to the group (attributed as [claude.ai]). "
        "Use queue_status for a quick summary."
    ),
)
|
|
|
|
|
|
def init(database: Database) -> None:
    """Set the shared database instance.

    Rebinds the module-level ``db`` that the tools and HTTP routes in this
    module read, so it must be called before any of them run.

    Args:
        database: The Database instance to share with this module.
    """
    global db
    db = database
|
|
|
|
|
|
@mcp.tool()
def send_message(message: str) -> str:
    """Send a message to the homelab Telegram group chat.

    The message will be posted with [claude.ai] attribution so participants
    know the message came from claude.ai.

    Args:
        message: The text to send to the group chat.
    """
    # The message is only queued here via the shared database; the id of the
    # queued row is reported back to the caller as a JSON string.
    queued_id = db.queue_outbound(get_group_chat_id(), message)
    receipt = {"sent": True, "id": queued_id}
    return json.dumps(receipt)
|
|
|
|
|
|
@mcp.tool()
def pull_updates(since_id: int = 0, since: str | None = None, limit: int = 50) -> str:
    """Pull conversation messages from the Telegram group.

    Returns messages from all participants (Mikkel, homelab bot, MCP bot).
    Supports cursor-based pagination: use the returned 'cursor' value as
    'since_id' in the next call to get only new messages.

    Args:
        since_id: Return messages with id > this value. Use cursor from previous response.
        since: ISO 8601 timestamp. Alternative to since_id — returns messages after this time.
        limit: Maximum number of messages to return (default 50, max 200).
    """
    # Clamp to [0, 200]. Without the lower bound a negative limit would reach
    # the database layer unchanged and could defeat the documented maximum
    # (SQLite, for one, treats a negative LIMIT as "no limit").
    limit = max(0, min(limit, 200))

    # A non-empty `since` takes precedence over `since_id`.
    if since:
        messages = db.get_messages_since_timestamp(since, limit)
    else:
        messages = db.get_messages_since_id(since_id, limit)

    # Enrich with attachment info: replace the boolean flag on each message
    # with the actual attachment list (empty when there are none).
    for msg in messages:
        has_attachment = msg.pop("has_attachment")
        msg["attachments"] = (
            db.get_attachments_for_message(msg["id"]) if has_attachment else []
        )

    # With no results the caller's since_id is echoed back so the cursor
    # never moves backwards for id-based paging.
    cursor = messages[-1]["id"] if messages else since_id

    return json.dumps({
        "messages": messages,
        "cursor": cursor,
        "count": len(messages),
    })
|
|
|
|
|
|
@mcp.tool()
def queue_status() -> str:
    """Get current status of the bridge.

    Returns message counts, last activity, and pending outbound messages.
    """
    # Whatever the shared database reports is serialized to JSON unchanged.
    return json.dumps(db.get_status())
|
|
|
|
|
|
# Custom non-MCP routes (no auth required)
|
|
async def ingest_message(request: Request) -> JSONResponse:
    """HTTP endpoint for local services to log messages into the bridge.

    Expects a JSON object in the request body. ``telegram_message_id`` and
    ``chat_id`` are required. ``sender_type`` defaults to ``"unknown"``,
    ``has_attachment`` to ``False`` and ``created_at`` to the current UTC time
    in ISO 8601 form; the remaining fields are passed through as given
    (``None`` when absent).

    Returns:
        A 400 response with an ``error`` field when the body is not valid
        JSON, is not a JSON object, or lacks a required field. Otherwise
        ``{"ok": True, "id": <row id>}``, or ``{"ok": True, "duplicate": True}``
        when ``db.insert_message`` returns ``None``.
    """
    try:
        data = await request.json()
    except Exception:
        return JSONResponse({"error": "invalid JSON"}, status_code=400)

    # Valid JSON is not necessarily an object: a list, string or number would
    # fail on the .get() calls below and surface as a 500 instead of a 400.
    if not isinstance(data, dict):
        return JSONResponse({"error": "JSON body must be an object"}, status_code=400)

    telegram_message_id = data.get("telegram_message_id")
    chat_id = data.get("chat_id")
    if not telegram_message_id or not chat_id:
        return JSONResponse(
            {"error": "telegram_message_id and chat_id are required"},
            status_code=400,
        )

    created_at = data.get("created_at", datetime.now(timezone.utc).isoformat())

    msg_id = db.insert_message(
        telegram_message_id=telegram_message_id,
        chat_id=chat_id,
        sender_type=data.get("sender_type", "unknown"),
        sender_id=data.get("sender_id"),
        sender_name=data.get("sender_name"),
        content=data.get("content"),
        reply_to_message_id=data.get("reply_to_message_id"),
        has_attachment=data.get("has_attachment", False),
        created_at=created_at,
    )

    # insert_message returning None is reported to the caller as a duplicate.
    if msg_id is None:
        return JSONResponse({"ok": True, "duplicate": True})

    logger.info(
        "Ingested message %s from %s",
        telegram_message_id,
        data.get("sender_name", "unknown"),
    )
    return JSONResponse({"ok": True, "id": msg_id})
|
|
|
|
|
|
async def health(request: Request) -> JSONResponse:
    """Report liveness together with the bridge's current status fields."""
    bridge_status = db.get_status()
    # Start from the fixed "ok" marker, then overlay whatever the database
    # reports (a "status" key in that mapping would override the marker).
    body = {"status": "ok"}
    body.update(bridge_status)
    return JSONResponse(body)
|
|
|
|
|
|
# Starlette routes for the plain HTTP endpoints defined in this module.
# NOTE(review): presumably mounted next to the MCP app by the entry point,
# which is not shown here -- confirm where this list is consumed.
custom_routes = [
    # POST: log a message into the bridge.
    Route("/api/ingest", ingest_message, methods=["POST"]),
    # GET: health check that also returns the database status.
    Route("/api/health", health, methods=["GET"]),
]
|