telegram-bot-mcp/mcp_bridge/mcp_server.py
Mikkel Georgsen 1296310adb fix: implement authorize() to auto-approve and redirect with code
Parent's authorize() is abstract and returned None, causing /None redirect.
Override creates auth code and redirects to callback immediately.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-30 11:29:13 +00:00

303 lines
10 KiB
Python

"""FastMCP server exposing bridge tools to claude.ai."""
import json
import logging
from datetime import datetime, timezone
from fastmcp import FastMCP
import secrets
import time
import hashlib
from urllib.parse import urlencode
from fastmcp.server.auth import OAuthProvider, AccessToken
from fastmcp.server.auth.auth import ClientRegistrationOptions
from mcp.server.auth.provider import AuthorizationParams
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from .db import Database
from .config import get_group_chat_id, MCP_HOST, MCP_PORT
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Shared database handle. Stays None until init() installs the real instance
# (per the original note, that happens from __main__).
db: Database | None = None
# Access-token lifetime in seconds: added to time.time() to compute expiry and
# reported to clients as `expires_in`.
TOKEN_LIFETIME = 3600
class HomelabOAuth(OAuthProvider):
    """Concrete OAuth provider with in-memory storage.

    All state (registered clients, authorization codes, access and refresh
    tokens) lives in plain dicts on the instance, so it is lost on restart.
    Access and refresh tokens are stored only as SHA-256 hex digests; the
    plaintext values are returned to the client once and never kept.
    """

    # NOTE(review): only the methods below are overridden. The MCP SDK provider
    # protocol appears to define further hooks (e.g. load_authorization_code,
    # load_refresh_token, revoke_token) - confirm the base-class behaviour for
    # those is acceptable for this deployment.

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base provider and create empty stores."""
        super().__init__(**kwargs)
        # client_id -> registered client metadata
        self._clients: dict[str, OAuthClientInformationFull] = {}
        # code -> {client_id, code_challenge, redirect_uri, scopes, expires}
        self._auth_codes: dict[str, dict] = {}
        # sha256(access token) -> {client_id, scopes, expires}
        self._tokens: dict[str, dict] = {}
        # sha256(refresh token) -> {client_id, scopes}
        self._refresh_tokens: dict[str, dict] = {}

    @staticmethod
    def _hash_token(token: str) -> str:
        """Return the SHA-256 hex digest used as the storage key for a token."""
        return hashlib.sha256(token.encode()).hexdigest()

    def _purge_expired(self) -> None:
        """Drop expired authorization codes and access tokens from memory."""
        now = time.time()
        for store in (self._auth_codes, self._tokens):
            # Collect keys first: a dict must not be mutated while iterating it.
            stale = [key for key, entry in store.items() if entry["expires"] < now]
            for key in stale:
                del store[key]

    def _issue_tokens(self, client_id: str, scopes: list[str]) -> OAuthToken:
        """Mint and store a fresh access/refresh token pair for *client_id*."""
        self._purge_expired()
        access_token = secrets.token_urlsafe(48)
        refresh_token = secrets.token_urlsafe(48)
        self._tokens[self._hash_token(access_token)] = {
            "client_id": client_id,
            "scopes": list(scopes),
            "expires": time.time() + TOKEN_LIFETIME,
        }
        self._refresh_tokens[self._hash_token(refresh_token)] = {
            "client_id": client_id,
            "scopes": list(scopes),
        }
        return OAuthToken(
            access_token=access_token,
            token_type="Bearer",
            expires_in=TOKEN_LIFETIME,
            refresh_token=refresh_token,
            scope=" ".join(scopes) if scopes else None,
        )

    async def register_client(self, client_info: OAuthClientInformationFull) -> None:
        """Store (or overwrite) a dynamically registered client, keyed by its client_id."""
        self._clients[client_info.client_id] = client_info
        logger.info(
            "Registered OAuth client: %s (%s)",
            client_info.client_id,
            client_info.client_name,
        )

    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        """Return the registered client for *client_id*, or None if unknown."""
        return self._clients.get(client_id)

    async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
        """Auto-approve and redirect back with auth code (single-user setup).

        Returns the client's redirect URI with ``code`` (and ``state``, when the
        request carried one) appended as query parameters.
        """
        redirect_uri = str(params.redirect_uri)
        code = await self.create_authorization_code(
            client, params.code_challenge, redirect_uri, params.scopes,
        )
        redirect_params = {"code": code}
        if params.state:
            redirect_params["state"] = params.state
        # Extend an existing query string instead of starting a second one.
        separator = "&" if "?" in redirect_uri else "?"
        return f"{redirect_uri}{separator}{urlencode(redirect_params)}"

    async def create_authorization_code(
        self, client: OAuthClientInformationFull, code_challenge: str | None,
        redirect_uri: str | None, scopes: list[str] | None = None,
        **kwargs,
    ) -> str:
        """Create, store and return a single-use authorization code.

        The code is valid for 300 seconds. When *redirect_uri* is falsy the
        client's first registered redirect URI is recorded instead. Extra
        keyword arguments are accepted and ignored.
        """
        self._purge_expired()
        code = secrets.token_urlsafe(32)
        self._auth_codes[code] = {
            "client_id": client.client_id,
            "code_challenge": code_challenge,
            "redirect_uri": redirect_uri or str(client.redirect_uris[0]),
            "scopes": scopes or [],
            "expires": time.time() + 300,
        }
        logger.info("Auth code issued for client %s", client.client_id)
        return code

    async def exchange_authorization_code(
        self, client: OAuthClientInformationFull, code: str,
        code_verifier: str | None = None, redirect_uri: str | None = None,
        **kwargs,
    ) -> OAuthToken:
        """Exchange an authorization code for an access/refresh token pair.

        The code is consumed on the first attempt, whether or not the exchange
        succeeds. *redirect_uri* and extra keyword arguments are accepted but
        not checked here.

        Raises:
            ValueError: if the code is unknown, expired, was issued to another
                client, or PKCE verification fails (including a missing
                verifier when a challenge was recorded).
        """
        import base64

        code_data = self._auth_codes.pop(code, None)
        if code_data is None:
            raise ValueError("Invalid authorization code")
        if code_data["expires"] < time.time():
            raise ValueError("Authorization code expired")
        if code_data["client_id"] != client.client_id:
            raise ValueError("Client ID mismatch")

        # PKCE (S256): once a challenge was recorded, a verifier is mandatory.
        # Skipping the check when the verifier is absent would let a stolen
        # code be redeemed without proof of possession.
        challenge = code_data["code_challenge"]
        if challenge:
            if not code_verifier:
                raise ValueError("PKCE verification failed")
            digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
            computed = base64.urlsafe_b64encode(digest).rstrip(b"=")
            # Constant-time comparison; bytes on both sides so non-ASCII input
            # cannot trigger a TypeError.
            if not secrets.compare_digest(computed, str(challenge).encode("utf-8")):
                raise ValueError("PKCE verification failed")

        token = self._issue_tokens(client.client_id, code_data["scopes"])
        logger.info("Token issued for client %s", client.client_id)
        return token

    async def load_access_token(self, token: str) -> AccessToken | None:
        """Return the AccessToken for a bearer token, or None if unknown or expired."""
        token_hash = self._hash_token(token)
        data = self._tokens.get(token_hash)
        if not data:
            return None
        if data["expires"] < time.time():
            # Forget the expired entry so the store does not grow without bound.
            self._tokens.pop(token_hash, None)
            return None
        return AccessToken(
            token=token,
            client_id=data["client_id"],
            scopes=data["scopes"],
        )

    async def exchange_refresh_token(
        self, client: OAuthClientInformationFull,
        refresh_token: str, scopes: list[str] | None = None,
        **kwargs,
    ) -> OAuthToken:
        """Rotate a refresh token and issue a new access/refresh token pair.

        When *scopes* is empty or None the originally granted scopes are
        reused. Extra keyword arguments are accepted and ignored.

        Raises:
            ValueError: if the refresh token is unknown, belongs to another
                client, or *scopes* asks for more than was originally granted.
        """
        refresh_hash = self._hash_token(refresh_token)
        data = self._refresh_tokens.get(refresh_hash)
        if not data or data["client_id"] != client.client_id:
            raise ValueError("Invalid refresh token")

        granted = scopes or data["scopes"]
        # A refresh may narrow the scope set but must never widen it.
        if set(granted) - set(data["scopes"]):
            raise ValueError("Requested scopes exceed the original grant")

        # Rotate: the presented refresh token becomes invalid immediately.
        del self._refresh_tokens[refresh_hash]
        return self._issue_tokens(client.client_id, granted)
# OAuth provider for the public endpoint, with dynamic client registration on.
oauth = HomelabOAuth(
    client_registration_options=ClientRegistrationOptions(enabled=True),
    base_url="https://mcp.georgsen.dk",
)

# The FastMCP application; the instructions text is assembled from one
# sentence per entry and joined with single spaces.
mcp = FastMCP(
    name="homelab-bridge",
    auth=oauth,
    instructions=" ".join(
        (
            "This MCP server bridges claude.ai to a homelab Telegram group chat.",
            "Use pull_updates to read conversation history (supports cursor-based pagination).",
            "Use send_message to post messages to the group (attributed as [claude.ai]).",
            "Use queue_status for a quick summary.",
        )
    ),
)
def init(database: Database) -> None:
    """Set the shared database instance.

    Rebinds the module-level ``db`` that the tools and HTTP handlers in this
    module read; it must be called before any of them run, because ``db``
    starts out as None.

    Args:
        database: The Database instance to share across this module.
    """
    global db
    db = database
@mcp.tool()
def send_message(message: str) -> str:
    """Send a message to the homelab Telegram group chat.

    The message will be posted with [claude.ai] attribution so participants
    know the message came from claude.ai.

    Args:
        message: The text to send to the group chat.
    """
    # Queue the text for the configured group chat; the id returned by the
    # database is echoed back in the JSON reply.
    outbound_id = db.queue_outbound(get_group_chat_id(), message)
    reply = {"sent": True, "id": outbound_id}
    return json.dumps(reply)
@mcp.tool()
def pull_updates(since_id: int = 0, since: str | None = None, limit: int = 50) -> str:
    """Pull conversation messages from the Telegram group.

    Returns messages from all participants (Mikkel, homelab bot, MCP bot).
    Supports cursor-based pagination: use the returned 'cursor' value as
    'since_id' in the next call to get only new messages.

    Args:
        since_id: Return messages with id > this value. Use cursor from previous response.
        since: ISO 8601 timestamp. Alternative to since_id — returns messages after this time.
        limit: Maximum number of messages to return (default 50, max 200).
    """
    # Clamp to [0, 200]. Without the lower bound a negative limit would be
    # handed to the database layer unchanged, where it may be read as
    # "no limit" and defeat the documented maximum.
    limit = max(0, min(limit, 200))

    # A non-empty `since` timestamp takes precedence over `since_id`.
    if since:
        messages = db.get_messages_since_timestamp(since, limit)
    else:
        messages = db.get_messages_since_id(since_id, limit)

    # Replace the internal has_attachment flag with the attachment records;
    # a row that lacks the flag is treated as having no attachments.
    for msg in messages:
        has_attachment = msg.pop("has_attachment", False)
        msg["attachments"] = (
            db.get_attachments_for_message(msg["id"]) if has_attachment else []
        )

    # The cursor is the id of the last returned message. With no results it
    # falls back to the caller's since_id (which is 0 when paging by `since`).
    cursor = messages[-1]["id"] if messages else since_id
    return json.dumps({
        "messages": messages,
        "cursor": cursor,
        "count": len(messages),
    })
@mcp.tool()
def queue_status() -> str:
    """Get current status of the bridge.

    Returns message counts, last activity, and pending outbound messages.
    """
    # Serialise whatever the database reports, unchanged.
    return json.dumps(db.get_status())
# Custom non-MCP routes (no auth required)
async def ingest_message(request: Request) -> JSONResponse:
    """HTTP endpoint for local services to log messages into the bridge.

    Expects a JSON object with at least ``telegram_message_id`` and
    ``chat_id``. Optional keys: ``sender_type`` (default "unknown"),
    ``sender_id``, ``sender_name``, ``content``, ``reply_to_message_id``,
    ``has_attachment`` (default False) and ``created_at`` (defaults to the
    current UTC time in ISO 8601 form).

    Returns:
        400 with an ``error`` message for malformed input; otherwise
        ``{"ok": True, "id": ...}``, or ``{"ok": True, "duplicate": True}``
        when the database reports no new row (insert_message returned None).
    """
    try:
        data = await request.json()
    except Exception:
        # Boundary handler: any failure to read or parse the body is a client error.
        return JSONResponse({"error": "invalid JSON"}, status_code=400)

    # Valid JSON that is not an object (list, string, number, null) has no
    # .get(); reject it explicitly instead of failing with a server error.
    if not isinstance(data, dict):
        return JSONResponse({"error": "JSON body must be an object"}, status_code=400)

    telegram_message_id = data.get("telegram_message_id")
    chat_id = data.get("chat_id")
    if not telegram_message_id or not chat_id:
        return JSONResponse(
            {"error": "telegram_message_id and chat_id are required"},
            status_code=400,
        )

    # `or` rather than a .get() default so an explicit null also falls back to now.
    created_at = data.get("created_at") or datetime.now(timezone.utc).isoformat()
    msg_id = db.insert_message(
        telegram_message_id=telegram_message_id,
        chat_id=chat_id,
        sender_type=data.get("sender_type", "unknown"),
        sender_id=data.get("sender_id"),
        sender_name=data.get("sender_name"),
        content=data.get("content"),
        reply_to_message_id=data.get("reply_to_message_id"),
        has_attachment=data.get("has_attachment", False),
        created_at=created_at,
    )
    if msg_id is None:
        return JSONResponse({"ok": True, "duplicate": True})

    logger.info(
        "Ingested message %s from %s",
        telegram_message_id,
        data.get("sender_name", "unknown"),
    )
    return JSONResponse({"ok": True, "id": msg_id})
async def health(request: Request) -> JSONResponse:
    """Health check endpoint.

    Responds with ``"status": "ok"`` merged with the fields reported by the
    database status call; a ``status`` key in that report takes precedence.
    """
    payload = {"status": "ok"}
    payload.update(db.get_status())
    return JSONResponse(payload)
# Plain Starlette routes for the two HTTP handlers above. They are only
# declared here; presumably the entry point mounts them on the app - TODO confirm.
custom_routes = [
    # POST /api/ingest: local services push messages into the bridge.
    Route("/api/ingest", ingest_message, methods=["POST"]),
    # GET /api/health: status probe.
    Route("/api/health", health, methods=["GET"]),
]