fix: use FastMCP's InMemoryOAuthProvider instead of custom implementation
Replaced hand-rolled OAuth with FastMCP's battle-tested InMemoryOAuthProvider. Handles DCR, PKCE, token exchange, refresh tokens, and revocation out of the box. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
1296310adb
commit
1dff4630fe
2 changed files with 5 additions and 487 deletions
|
|
@ -1,329 +0,0 @@
|
||||||
"""OAuth 2.0 authorization code + PKCE auth for MCP server."""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import secrets
|
|
||||||
import time
|
|
||||||
from urllib.parse import urlencode, urlparse, parse_qs
|
|
||||||
|
|
||||||
from starlette.requests import Request
|
|
||||||
from starlette.responses import JSONResponse, RedirectResponse
|
|
||||||
from starlette.routing import Route
|
|
||||||
|
|
||||||
from .config import load_credentials
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# In-memory stores
|
|
||||||
_active_tokens: dict[str, float] = {} # token_hash -> expiry
|
|
||||||
_auth_codes: dict[str, dict] = {} # code -> {client_id, redirect_uri, code_challenge, expires}
|
|
||||||
TOKEN_LIFETIME = 3600 # 1 hour
|
|
||||||
CODE_LIFETIME = 300 # 5 minutes
|
|
||||||
|
|
||||||
|
|
||||||
def _get_oauth_credentials() -> tuple[str, str]:
|
|
||||||
"""Load OAuth client_id and client_secret from credentials file."""
|
|
||||||
creds = load_credentials()
|
|
||||||
client_id = creds.get("OAUTH_CLIENT_ID", "")
|
|
||||||
client_secret = creds.get("OAUTH_CLIENT_SECRET", "")
|
|
||||||
if not client_id or not client_secret:
|
|
||||||
raise RuntimeError("OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET must be set in credentials")
|
|
||||||
return client_id, client_secret
|
|
||||||
|
|
||||||
|
|
||||||
def _hash(value: str) -> str:
|
|
||||||
return hashlib.sha256(value.encode()).hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_expired():
|
|
||||||
"""Remove expired tokens and codes."""
|
|
||||||
now = time.time()
|
|
||||||
for store in (_active_tokens, _auth_codes):
|
|
||||||
expired = [k for k, v in store.items()
|
|
||||||
if (v if isinstance(v, float) else v.get("expires", 0)) < now]
|
|
||||||
for k in expired:
|
|
||||||
del store[k]
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_pkce(code_verifier: str, code_challenge: str) -> bool:
|
|
||||||
"""Verify PKCE S256 challenge."""
|
|
||||||
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
|
||||||
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
|
||||||
return secrets.compare_digest(computed, code_challenge)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_bearer_token(token: str) -> bool:
|
|
||||||
"""Check if a bearer token is valid and not expired."""
|
|
||||||
_cleanup_expired()
|
|
||||||
token_hash = _hash(token)
|
|
||||||
return token_hash in _active_tokens and _active_tokens[token_hash] > time.time()
|
|
||||||
|
|
||||||
|
|
||||||
async def authorize_endpoint(request: Request):
|
|
||||||
"""OAuth 2.0 Authorization endpoint.
|
|
||||||
|
|
||||||
GET /authorize?response_type=code&client_id=...&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...
|
|
||||||
|
|
||||||
Single-user setup: auto-approves if client_id matches, redirects with code.
|
|
||||||
"""
|
|
||||||
params = dict(request.query_params)
|
|
||||||
|
|
||||||
response_type = params.get("response_type")
|
|
||||||
client_id = params.get("client_id", "")
|
|
||||||
redirect_uri = params.get("redirect_uri", "")
|
|
||||||
code_challenge = params.get("code_challenge", "")
|
|
||||||
code_challenge_method = params.get("code_challenge_method", "")
|
|
||||||
state = params.get("state", "")
|
|
||||||
|
|
||||||
if response_type != "code":
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": "unsupported_response_type"},
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
expected_id, _ = _get_oauth_credentials()
|
|
||||||
except RuntimeError:
|
|
||||||
return JSONResponse({"error": "server_error"}, status_code=500)
|
|
||||||
|
|
||||||
if not secrets.compare_digest(client_id, expected_id):
|
|
||||||
logger.warning(f"OAuth authorize: invalid client_id from {request.client.host}")
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": "invalid_client", "error_description": "Unknown client_id"},
|
|
||||||
status_code=401,
|
|
||||||
)
|
|
||||||
|
|
||||||
if code_challenge_method and code_challenge_method != "S256":
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": "invalid_request", "error_description": "Only S256 code_challenge_method supported"},
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Auto-approve: generate authorization code
|
|
||||||
code = secrets.token_urlsafe(32)
|
|
||||||
_auth_codes[code] = {
|
|
||||||
"client_id": client_id,
|
|
||||||
"redirect_uri": redirect_uri,
|
|
||||||
"code_challenge": code_challenge,
|
|
||||||
"expires": time.time() + CODE_LIFETIME,
|
|
||||||
}
|
|
||||||
_cleanup_expired()
|
|
||||||
|
|
||||||
logger.info(f"OAuth authorization code issued for {client_id[:12]}... -> {redirect_uri}")
|
|
||||||
|
|
||||||
# Redirect back with code
|
|
||||||
redirect_params = {"code": code}
|
|
||||||
if state:
|
|
||||||
redirect_params["state"] = state
|
|
||||||
separator = "&" if "?" in redirect_uri else "?"
|
|
||||||
return RedirectResponse(
|
|
||||||
url=f"{redirect_uri}{separator}{urlencode(redirect_params)}",
|
|
||||||
status_code=302,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def token_endpoint(request: Request) -> JSONResponse:
|
|
||||||
"""OAuth 2.0 token endpoint.
|
|
||||||
|
|
||||||
Supports:
|
|
||||||
- grant_type=authorization_code (with PKCE)
|
|
||||||
- grant_type=client_credentials (direct)
|
|
||||||
"""
|
|
||||||
# Parse request data from any source
|
|
||||||
data = {}
|
|
||||||
|
|
||||||
# Always include query params
|
|
||||||
data.update(dict(request.query_params))
|
|
||||||
|
|
||||||
# Parse body for POST
|
|
||||||
if request.method == "POST":
|
|
||||||
try:
|
|
||||||
raw_body = await request.body()
|
|
||||||
logger.info(f"TOKEN POST body ({len(raw_body)} bytes): {raw_body[:500]}")
|
|
||||||
content_type = request.headers.get("content-type", "")
|
|
||||||
if "application/json" in content_type and raw_body:
|
|
||||||
data.update(json.loads(raw_body))
|
|
||||||
elif raw_body:
|
|
||||||
from urllib.parse import parse_qs
|
|
||||||
parsed = parse_qs(raw_body.decode(), keep_blank_values=True)
|
|
||||||
data.update({k: v[0] if len(v) == 1 else v for k, v in parsed.items()})
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"TOKEN body parse error: {e}")
|
|
||||||
|
|
||||||
logger.info(f"TOKEN {request.method} data_keys={list(data.keys())} "
|
|
||||||
f"grant_type={data.get('grant_type', '(missing)')}")
|
|
||||||
|
|
||||||
# Bare GET with no params = probe/ping, return 200 with metadata
|
|
||||||
if request.method == "GET" and not data:
|
|
||||||
return JSONResponse({
|
|
||||||
"grant_types_supported": ["authorization_code", "client_credentials"],
|
|
||||||
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
|
|
||||||
"code_challenge_methods_supported": ["S256"],
|
|
||||||
})
|
|
||||||
|
|
||||||
grant_type = data.get("grant_type", "")
|
|
||||||
|
|
||||||
logger.debug(f"Token request: method={request.method} grant_type={grant_type} params={list(data.keys())}")
|
|
||||||
|
|
||||||
# If there's a code param but no grant_type, assume authorization_code
|
|
||||||
if not grant_type and "code" in data:
|
|
||||||
grant_type = "authorization_code"
|
|
||||||
|
|
||||||
if grant_type == "authorization_code":
|
|
||||||
return await _handle_auth_code_grant(data, request)
|
|
||||||
elif grant_type == "client_credentials":
|
|
||||||
return await _handle_client_credentials_grant(data, request)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Token request with unsupported grant_type={grant_type!r}, data keys={list(data.keys())}")
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": "unsupported_grant_type",
|
|
||||||
"error_description": f"Supported: authorization_code, client_credentials. Got: {grant_type!r}"},
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_auth_code_grant(data: dict, request: Request) -> JSONResponse:
|
|
||||||
"""Exchange authorization code + PKCE verifier for access token."""
|
|
||||||
code = data.get("code", "")
|
|
||||||
client_id = data.get("client_id", "")
|
|
||||||
code_verifier = data.get("code_verifier", "")
|
|
||||||
redirect_uri = data.get("redirect_uri", "")
|
|
||||||
|
|
||||||
_cleanup_expired()
|
|
||||||
|
|
||||||
if code not in _auth_codes:
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": "invalid_grant", "error_description": "Invalid or expired authorization code"},
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
|
|
||||||
code_data = _auth_codes.pop(code)
|
|
||||||
|
|
||||||
# Verify client_id matches
|
|
||||||
if not secrets.compare_digest(client_id, code_data["client_id"]):
|
|
||||||
return JSONResponse({"error": "invalid_client"}, status_code=401)
|
|
||||||
|
|
||||||
# Verify redirect_uri matches
|
|
||||||
if redirect_uri and redirect_uri != code_data["redirect_uri"]:
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": "invalid_grant", "error_description": "redirect_uri mismatch"},
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify PKCE
|
|
||||||
if code_data["code_challenge"]:
|
|
||||||
if not code_verifier:
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": "invalid_request", "error_description": "code_verifier required"},
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
if not _verify_pkce(code_verifier, code_data["code_challenge"]):
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": "invalid_grant", "error_description": "PKCE verification failed"},
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Issue access token
|
|
||||||
access_token = secrets.token_urlsafe(48)
|
|
||||||
_active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME
|
|
||||||
|
|
||||||
logger.info(f"OAuth token issued via auth code to {request.client.host}")
|
|
||||||
return JSONResponse({
|
|
||||||
"access_token": access_token,
|
|
||||||
"token_type": "Bearer",
|
|
||||||
"expires_in": TOKEN_LIFETIME,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_client_credentials_grant(data: dict, request: Request) -> JSONResponse:
|
|
||||||
"""Direct client_id + client_secret exchange for access token."""
|
|
||||||
client_id = data.get("client_id", "")
|
|
||||||
client_secret = data.get("client_secret", "")
|
|
||||||
|
|
||||||
try:
|
|
||||||
expected_id, expected_secret = _get_oauth_credentials()
|
|
||||||
except RuntimeError:
|
|
||||||
return JSONResponse({"error": "server_error"}, status_code=500)
|
|
||||||
|
|
||||||
if not secrets.compare_digest(client_id, expected_id) or \
|
|
||||||
not secrets.compare_digest(client_secret, expected_secret):
|
|
||||||
logger.warning(f"OAuth client_credentials auth failed from {request.client.host}")
|
|
||||||
return JSONResponse({"error": "invalid_client"}, status_code=401)
|
|
||||||
|
|
||||||
access_token = secrets.token_urlsafe(48)
|
|
||||||
_active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME
|
|
||||||
_cleanup_expired()
|
|
||||||
|
|
||||||
logger.info(f"OAuth token issued via client_credentials to {request.client.host}")
|
|
||||||
return JSONResponse({
|
|
||||||
"access_token": access_token,
|
|
||||||
"token_type": "Bearer",
|
|
||||||
"expires_in": TOKEN_LIFETIME,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
async def oauth_metadata(request: Request) -> JSONResponse:
|
|
||||||
"""OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
|
|
||||||
base = str(request.base_url).rstrip("/")
|
|
||||||
return JSONResponse({
|
|
||||||
"issuer": base,
|
|
||||||
"authorization_endpoint": f"{base}/authorize",
|
|
||||||
"token_endpoint": f"{base}/token",
|
|
||||||
"registration_endpoint": f"{base}/register",
|
|
||||||
"token_endpoint_auth_methods_supported": ["none"],
|
|
||||||
"grant_types_supported": ["authorization_code"],
|
|
||||||
"response_types_supported": ["code"],
|
|
||||||
"code_challenge_methods_supported": ["S256"],
|
|
||||||
"scopes_supported": ["mcp"],
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
async def protected_resource_metadata(request: Request) -> JSONResponse:
|
|
||||||
"""OAuth 2.0 Protected Resource Metadata (RFC 9728)."""
|
|
||||||
base = str(request.base_url).rstrip("/")
|
|
||||||
return JSONResponse({
|
|
||||||
"resource": base,
|
|
||||||
"authorization_servers": [base],
|
|
||||||
"scopes_supported": ["mcp"],
|
|
||||||
"bearer_methods_supported": ["header"],
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
# Routes to add to the app
|
|
||||||
async def register_endpoint(request: Request) -> JSONResponse:
|
|
||||||
"""OAuth 2.0 Dynamic Client Registration (stub).
|
|
||||||
|
|
||||||
Claude may try to register dynamically. Since we use pre-configured
|
|
||||||
credentials, just return the existing client info.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
client_id, _ = _get_oauth_credentials()
|
|
||||||
except RuntimeError:
|
|
||||||
return JSONResponse({"error": "server_error"}, status_code=500)
|
|
||||||
|
|
||||||
base = str(request.base_url).rstrip("/")
|
|
||||||
logger.info(f"REGISTER endpoint called: method={request.method}")
|
|
||||||
if request.method == "POST":
|
|
||||||
body = await request.body()
|
|
||||||
logger.info(f"REGISTER body: {body[:500]}")
|
|
||||||
|
|
||||||
return JSONResponse({
|
|
||||||
"client_id": client_id,
|
|
||||||
"client_name": "claude-desktop",
|
|
||||||
"redirect_uris": ["https://claude.ai/api/mcp/auth_callback"],
|
|
||||||
"grant_types": ["authorization_code"],
|
|
||||||
"response_types": ["code"],
|
|
||||||
"token_endpoint_auth_method": "none",
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
auth_routes = [
|
|
||||||
Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]),
|
|
||||||
Route("/.well-known/oauth-protected-resource", protected_resource_metadata, methods=["GET"]),
|
|
||||||
Route("/.well-known/oauth-protected-resource/mcp", protected_resource_metadata, methods=["GET"]),
|
|
||||||
Route("/authorize", authorize_endpoint, methods=["GET"]),
|
|
||||||
Route("/token", token_endpoint, methods=["GET", "POST"]),
|
|
||||||
Route("/register", register_endpoint, methods=["GET", "POST"]),
|
|
||||||
]
|
|
||||||
|
|
@ -5,169 +5,20 @@ import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
import secrets
|
from fastmcp.server.auth.providers.in_memory import InMemoryOAuthProvider
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from urllib.parse import urlencode
|
|
||||||
|
|
||||||
from fastmcp.server.auth import OAuthProvider, AccessToken
|
|
||||||
from fastmcp.server.auth.auth import ClientRegistrationOptions
|
from fastmcp.server.auth.auth import ClientRegistrationOptions
|
||||||
from mcp.server.auth.provider import AuthorizationParams
|
|
||||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.routing import Route
|
from starlette.routing import Route
|
||||||
|
|
||||||
from .db import Database
|
from .db import Database
|
||||||
from .config import get_group_chat_id, MCP_HOST, MCP_PORT
|
from .config import get_group_chat_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Will be initialized in __main__ with shared db instance
|
|
||||||
db: Database | None = None
|
db: Database | None = None
|
||||||
|
|
||||||
TOKEN_LIFETIME = 3600
|
oauth = InMemoryOAuthProvider(
|
||||||
|
|
||||||
|
|
||||||
class HomelabOAuth(OAuthProvider):
|
|
||||||
"""Concrete OAuth provider with in-memory storage."""
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._clients: dict[str, OAuthClientInformationFull] = {}
|
|
||||||
self._auth_codes: dict[str, dict] = {} # code -> {client_id, code_challenge, redirect_uri, expires}
|
|
||||||
self._tokens: dict[str, dict] = {} # token_hash -> {client_id, scopes, expires}
|
|
||||||
self._refresh_tokens: dict[str, dict] = {} # refresh_hash -> {client_id, scopes}
|
|
||||||
|
|
||||||
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
|
|
||||||
self._clients[client_info.client_id] = client_info
|
|
||||||
logger.info(f"Registered OAuth client: {client_info.client_id} ({client_info.client_name})")
|
|
||||||
|
|
||||||
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
|
||||||
return self._clients.get(client_id)
|
|
||||||
|
|
||||||
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
|
|
||||||
"""Auto-approve and redirect back with auth code (single-user setup)."""
|
|
||||||
code = await self.create_authorization_code(
|
|
||||||
client, params.code_challenge, str(params.redirect_uri), params.scopes,
|
|
||||||
)
|
|
||||||
redirect_params = {"code": code}
|
|
||||||
if params.state:
|
|
||||||
redirect_params["state"] = params.state
|
|
||||||
redirect_uri = str(params.redirect_uri)
|
|
||||||
separator = "&" if "?" in redirect_uri else "?"
|
|
||||||
return f"{redirect_uri}{separator}{urlencode(redirect_params)}"
|
|
||||||
|
|
||||||
async def create_authorization_code(
|
|
||||||
self, client: OAuthClientInformationFull, code_challenge: str | None,
|
|
||||||
redirect_uri: str | None, scopes: list[str] | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
|
||||||
code = secrets.token_urlsafe(32)
|
|
||||||
self._auth_codes[code] = {
|
|
||||||
"client_id": client.client_id,
|
|
||||||
"code_challenge": code_challenge,
|
|
||||||
"redirect_uri": redirect_uri or str(client.redirect_uris[0]),
|
|
||||||
"scopes": scopes or [],
|
|
||||||
"expires": time.time() + 300,
|
|
||||||
}
|
|
||||||
logger.info(f"Auth code issued for client {client.client_id}")
|
|
||||||
return code
|
|
||||||
|
|
||||||
async def exchange_authorization_code(
|
|
||||||
self, client: OAuthClientInformationFull, code: str,
|
|
||||||
code_verifier: str | None = None, redirect_uri: str | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> OAuthToken:
|
|
||||||
if code not in self._auth_codes:
|
|
||||||
raise ValueError("Invalid authorization code")
|
|
||||||
|
|
||||||
code_data = self._auth_codes.pop(code)
|
|
||||||
|
|
||||||
if code_data["expires"] < time.time():
|
|
||||||
raise ValueError("Authorization code expired")
|
|
||||||
|
|
||||||
if code_data["client_id"] != client.client_id:
|
|
||||||
raise ValueError("Client ID mismatch")
|
|
||||||
|
|
||||||
# PKCE verification
|
|
||||||
if code_data["code_challenge"] and code_verifier:
|
|
||||||
import base64
|
|
||||||
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
|
||||||
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
|
||||||
if computed != code_data["code_challenge"]:
|
|
||||||
raise ValueError("PKCE verification failed")
|
|
||||||
|
|
||||||
# Issue tokens
|
|
||||||
access_token = secrets.token_urlsafe(48)
|
|
||||||
refresh_token = secrets.token_urlsafe(48)
|
|
||||||
|
|
||||||
self._tokens[hashlib.sha256(access_token.encode()).hexdigest()] = {
|
|
||||||
"client_id": client.client_id,
|
|
||||||
"scopes": code_data["scopes"],
|
|
||||||
"expires": time.time() + TOKEN_LIFETIME,
|
|
||||||
}
|
|
||||||
self._refresh_tokens[hashlib.sha256(refresh_token.encode()).hexdigest()] = {
|
|
||||||
"client_id": client.client_id,
|
|
||||||
"scopes": code_data["scopes"],
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"Token issued for client {client.client_id}")
|
|
||||||
return OAuthToken(
|
|
||||||
access_token=access_token,
|
|
||||||
token_type="Bearer",
|
|
||||||
expires_in=TOKEN_LIFETIME,
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
scope=" ".join(code_data["scopes"]) if code_data["scopes"] else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def load_access_token(self, token: str) -> AccessToken | None:
|
|
||||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
|
||||||
data = self._tokens.get(token_hash)
|
|
||||||
if not data or data["expires"] < time.time():
|
|
||||||
return None
|
|
||||||
return AccessToken(
|
|
||||||
token=token,
|
|
||||||
client_id=data["client_id"],
|
|
||||||
scopes=data["scopes"],
|
|
||||||
)
|
|
||||||
|
|
||||||
async def exchange_refresh_token(
|
|
||||||
self, client: OAuthClientInformationFull,
|
|
||||||
refresh_token: str, scopes: list[str] | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> OAuthToken:
|
|
||||||
refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
|
||||||
data = self._refresh_tokens.get(refresh_hash)
|
|
||||||
if not data or data["client_id"] != client.client_id:
|
|
||||||
raise ValueError("Invalid refresh token")
|
|
||||||
|
|
||||||
# Issue new tokens
|
|
||||||
new_access = secrets.token_urlsafe(48)
|
|
||||||
new_refresh = secrets.token_urlsafe(48)
|
|
||||||
|
|
||||||
self._tokens[hashlib.sha256(new_access.encode()).hexdigest()] = {
|
|
||||||
"client_id": client.client_id,
|
|
||||||
"scopes": scopes or data["scopes"],
|
|
||||||
"expires": time.time() + TOKEN_LIFETIME,
|
|
||||||
}
|
|
||||||
# Rotate refresh token
|
|
||||||
del self._refresh_tokens[refresh_hash]
|
|
||||||
self._refresh_tokens[hashlib.sha256(new_refresh.encode()).hexdigest()] = {
|
|
||||||
"client_id": client.client_id,
|
|
||||||
"scopes": scopes or data["scopes"],
|
|
||||||
}
|
|
||||||
|
|
||||||
return OAuthToken(
|
|
||||||
access_token=new_access,
|
|
||||||
token_type="Bearer",
|
|
||||||
expires_in=TOKEN_LIFETIME,
|
|
||||||
refresh_token=new_refresh,
|
|
||||||
scope=" ".join(scopes or data["scopes"]) if (scopes or data["scopes"]) else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
oauth = HomelabOAuth(
|
|
||||||
base_url="https://mcp.georgsen.dk",
|
base_url="https://mcp.georgsen.dk",
|
||||||
client_registration_options=ClientRegistrationOptions(enabled=True),
|
client_registration_options=ClientRegistrationOptions(enabled=True),
|
||||||
)
|
)
|
||||||
|
|
@ -185,7 +36,6 @@ mcp = FastMCP(
|
||||||
|
|
||||||
|
|
||||||
def init(database: Database):
|
def init(database: Database):
|
||||||
"""Set the shared database instance."""
|
|
||||||
global db
|
global db
|
||||||
db = database
|
db = database
|
||||||
|
|
||||||
|
|
@ -225,7 +75,6 @@ def pull_updates(since_id: int = 0, since: str | None = None, limit: int = 50) -
|
||||||
else:
|
else:
|
||||||
messages = db.get_messages_since_id(since_id, limit)
|
messages = db.get_messages_since_id(since_id, limit)
|
||||||
|
|
||||||
# Enrich with attachment info
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if msg["has_attachment"]:
|
if msg["has_attachment"]:
|
||||||
msg["attachments"] = db.get_attachments_for_message(msg["id"])
|
msg["attachments"] = db.get_attachments_for_message(msg["id"])
|
||||||
|
|
@ -252,7 +101,7 @@ def queue_status() -> str:
|
||||||
return json.dumps(status)
|
return json.dumps(status)
|
||||||
|
|
||||||
|
|
||||||
# Custom non-MCP routes (no auth required)
|
# Custom non-MCP routes (no auth required - local access only)
|
||||||
async def ingest_message(request: Request) -> JSONResponse:
|
async def ingest_message(request: Request) -> JSONResponse:
|
||||||
"""HTTP endpoint for local services to log messages into the bridge."""
|
"""HTTP endpoint for local services to log messages into the bridge."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -285,9 +134,7 @@ async def ingest_message(request: Request) -> JSONResponse:
|
||||||
if msg_id is None:
|
if msg_id is None:
|
||||||
return JSONResponse({"ok": True, "duplicate": True})
|
return JSONResponse({"ok": True, "duplicate": True})
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"Ingested message {telegram_message_id} from {data.get('sender_name', 'unknown')}")
|
||||||
f"Ingested message {telegram_message_id} from {data.get('sender_name', 'unknown')}"
|
|
||||||
)
|
|
||||||
return JSONResponse({"ok": True, "id": msg_id})
|
return JSONResponse({"ok": True, "id": msg_id})
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue