feat: switch OAuth to authorization code + PKCE flow
Claude Desktop uses authorization code flow, not client credentials. Added /authorize endpoint that auto-approves (single-user setup) and redirects with code. Token endpoint now supports both grant types. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
205b978b89
commit
15e3582787
2 changed files with 168 additions and 40 deletions
|
|
@ -24,6 +24,7 @@ logger = logging.getLogger("mcp_bridge")
|
||||||
# Paths that don't require auth
|
# Paths that don't require auth
|
||||||
PUBLIC_PATHS = {
|
PUBLIC_PATHS = {
|
||||||
"/.well-known/oauth-authorization-server",
|
"/.well-known/oauth-authorization-server",
|
||||||
|
"/authorize",
|
||||||
"/token",
|
"/token",
|
||||||
"/api/health",
|
"/api/health",
|
||||||
"/api/ingest", # Local-only, not exposed via NPM
|
"/api/ingest", # Local-only, not exposed via NPM
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,26 @@
|
||||||
"""OAuth 2.0 client credentials auth for MCP server."""
|
"""OAuth 2.0 authorization code + PKCE auth for MCP server."""
|
||||||
|
|
||||||
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
import time
|
import time
|
||||||
|
from urllib.parse import urlencode, urlparse, parse_qs
|
||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse, RedirectResponse
|
||||||
from starlette.routing import Route
|
from starlette.routing import Route
|
||||||
|
|
||||||
from .config import load_credentials
|
from .config import load_credentials
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# In-memory token store: token_hash -> expiry timestamp
|
# In-memory stores
|
||||||
_active_tokens: dict[str, float] = {}
|
_active_tokens: dict[str, float] = {} # token_hash -> expiry
|
||||||
|
_auth_codes: dict[str, dict] = {} # code -> {client_id, redirect_uri, code_challenge, expires}
|
||||||
TOKEN_LIFETIME = 3600 # 1 hour
|
TOKEN_LIFETIME = 3600 # 1 hour
|
||||||
|
CODE_LIFETIME = 300 # 5 minutes
|
||||||
|
|
||||||
|
|
||||||
def _get_oauth_credentials() -> tuple[str, str]:
|
def _get_oauth_credentials() -> tuple[str, str]:
|
||||||
|
|
@ -29,35 +33,105 @@ def _get_oauth_credentials() -> tuple[str, str]:
|
||||||
return client_id, client_secret
|
return client_id, client_secret
|
||||||
|
|
||||||
|
|
||||||
def _hash_token(token: str) -> str:
|
def _hash(value: str) -> str:
|
||||||
return hashlib.sha256(token.encode()).hexdigest()
|
return hashlib.sha256(value.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_expired():
|
def _cleanup_expired():
|
||||||
"""Remove expired tokens."""
|
"""Remove expired tokens and codes."""
|
||||||
now = time.time()
|
now = time.time()
|
||||||
expired = [h for h, exp in _active_tokens.items() if exp < now]
|
for store in (_active_tokens, _auth_codes):
|
||||||
for h in expired:
|
expired = [k for k, v in store.items()
|
||||||
del _active_tokens[h]
|
if (v if isinstance(v, float) else v.get("expires", 0)) < now]
|
||||||
|
for k in expired:
|
||||||
|
del store[k]
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_pkce(code_verifier: str, code_challenge: str) -> bool:
|
||||||
|
"""Verify PKCE S256 challenge."""
|
||||||
|
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
||||||
|
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
||||||
|
return secrets.compare_digest(computed, code_challenge)
|
||||||
|
|
||||||
|
|
||||||
def validate_bearer_token(token: str) -> bool:
|
def validate_bearer_token(token: str) -> bool:
|
||||||
"""Check if a bearer token is valid and not expired."""
|
"""Check if a bearer token is valid and not expired."""
|
||||||
_cleanup_expired()
|
_cleanup_expired()
|
||||||
token_hash = _hash_token(token)
|
token_hash = _hash(token)
|
||||||
return token_hash in _active_tokens and _active_tokens[token_hash] > time.time()
|
return token_hash in _active_tokens and _active_tokens[token_hash] > time.time()
|
||||||
|
|
||||||
|
|
||||||
|
async def authorize_endpoint(request: Request):
|
||||||
|
"""OAuth 2.0 Authorization endpoint.
|
||||||
|
|
||||||
|
GET /authorize?response_type=code&client_id=...&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...
|
||||||
|
|
||||||
|
Single-user setup: auto-approves if client_id matches, redirects with code.
|
||||||
|
"""
|
||||||
|
params = dict(request.query_params)
|
||||||
|
|
||||||
|
response_type = params.get("response_type")
|
||||||
|
client_id = params.get("client_id", "")
|
||||||
|
redirect_uri = params.get("redirect_uri", "")
|
||||||
|
code_challenge = params.get("code_challenge", "")
|
||||||
|
code_challenge_method = params.get("code_challenge_method", "")
|
||||||
|
state = params.get("state", "")
|
||||||
|
|
||||||
|
if response_type != "code":
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "unsupported_response_type"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
expected_id, _ = _get_oauth_credentials()
|
||||||
|
except RuntimeError:
|
||||||
|
return JSONResponse({"error": "server_error"}, status_code=500)
|
||||||
|
|
||||||
|
if not secrets.compare_digest(client_id, expected_id):
|
||||||
|
logger.warning(f"OAuth authorize: invalid client_id from {request.client.host}")
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "invalid_client", "error_description": "Unknown client_id"},
|
||||||
|
status_code=401,
|
||||||
|
)
|
||||||
|
|
||||||
|
if code_challenge_method and code_challenge_method != "S256":
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "invalid_request", "error_description": "Only S256 code_challenge_method supported"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Auto-approve: generate authorization code
|
||||||
|
code = secrets.token_urlsafe(32)
|
||||||
|
_auth_codes[code] = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"code_challenge": code_challenge,
|
||||||
|
"expires": time.time() + CODE_LIFETIME,
|
||||||
|
}
|
||||||
|
_cleanup_expired()
|
||||||
|
|
||||||
|
logger.info(f"OAuth authorization code issued for {client_id[:12]}... -> {redirect_uri}")
|
||||||
|
|
||||||
|
# Redirect back with code
|
||||||
|
redirect_params = {"code": code}
|
||||||
|
if state:
|
||||||
|
redirect_params["state"] = state
|
||||||
|
separator = "&" if "?" in redirect_uri else "?"
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}{separator}{urlencode(redirect_params)}",
|
||||||
|
status_code=302,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def token_endpoint(request: Request) -> JSONResponse:
|
async def token_endpoint(request: Request) -> JSONResponse:
|
||||||
"""OAuth 2.0 token endpoint (client_credentials grant).
|
"""OAuth 2.0 token endpoint.
|
||||||
|
|
||||||
POST /token
|
Supports:
|
||||||
Content-Type: application/x-www-form-urlencoded
|
- grant_type=authorization_code (with PKCE)
|
||||||
|
- grant_type=client_credentials (direct)
|
||||||
grant_type=client_credentials&client_id=...&client_secret=...
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Accept both form-encoded and JSON
|
|
||||||
content_type = request.headers.get("content-type", "")
|
content_type = request.headers.get("content-type", "")
|
||||||
if "application/json" in content_type:
|
if "application/json" in content_type:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
|
|
@ -71,38 +145,92 @@ async def token_endpoint(request: Request) -> JSONResponse:
|
||||||
)
|
)
|
||||||
|
|
||||||
grant_type = data.get("grant_type")
|
grant_type = data.get("grant_type")
|
||||||
if grant_type != "client_credentials":
|
|
||||||
|
if grant_type == "authorization_code":
|
||||||
|
return await _handle_auth_code_grant(data, request)
|
||||||
|
elif grant_type == "client_credentials":
|
||||||
|
return await _handle_client_credentials_grant(data, request)
|
||||||
|
else:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{"error": "unsupported_grant_type", "error_description": "Only client_credentials is supported"},
|
{"error": "unsupported_grant_type",
|
||||||
|
"error_description": "Supported: authorization_code, client_credentials"},
|
||||||
status_code=400,
|
status_code=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_auth_code_grant(data: dict, request: Request) -> JSONResponse:
|
||||||
|
"""Exchange authorization code + PKCE verifier for access token."""
|
||||||
|
code = data.get("code", "")
|
||||||
|
client_id = data.get("client_id", "")
|
||||||
|
code_verifier = data.get("code_verifier", "")
|
||||||
|
redirect_uri = data.get("redirect_uri", "")
|
||||||
|
|
||||||
|
_cleanup_expired()
|
||||||
|
|
||||||
|
if code not in _auth_codes:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "invalid_grant", "error_description": "Invalid or expired authorization code"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
code_data = _auth_codes.pop(code)
|
||||||
|
|
||||||
|
# Verify client_id matches
|
||||||
|
if not secrets.compare_digest(client_id, code_data["client_id"]):
|
||||||
|
return JSONResponse({"error": "invalid_client"}, status_code=401)
|
||||||
|
|
||||||
|
# Verify redirect_uri matches
|
||||||
|
if redirect_uri and redirect_uri != code_data["redirect_uri"]:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "invalid_grant", "error_description": "redirect_uri mismatch"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify PKCE
|
||||||
|
if code_data["code_challenge"]:
|
||||||
|
if not code_verifier:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "invalid_request", "error_description": "code_verifier required"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
if not _verify_pkce(code_verifier, code_data["code_challenge"]):
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "invalid_grant", "error_description": "PKCE verification failed"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Issue access token
|
||||||
|
access_token = secrets.token_urlsafe(48)
|
||||||
|
_active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME
|
||||||
|
|
||||||
|
logger.info(f"OAuth token issued via auth code to {request.client.host}")
|
||||||
|
return JSONResponse({
|
||||||
|
"access_token": access_token,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": TOKEN_LIFETIME,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_client_credentials_grant(data: dict, request: Request) -> JSONResponse:
|
||||||
|
"""Direct client_id + client_secret exchange for access token."""
|
||||||
client_id = data.get("client_id", "")
|
client_id = data.get("client_id", "")
|
||||||
client_secret = data.get("client_secret", "")
|
client_secret = data.get("client_secret", "")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
expected_id, expected_secret = _get_oauth_credentials()
|
expected_id, expected_secret = _get_oauth_credentials()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
logger.error("OAuth credentials not configured")
|
return JSONResponse({"error": "server_error"}, status_code=500)
|
||||||
return JSONResponse(
|
|
||||||
{"error": "server_error", "error_description": "Auth not configured"},
|
|
||||||
status_code=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not secrets.compare_digest(client_id, expected_id) or \
|
if not secrets.compare_digest(client_id, expected_id) or \
|
||||||
not secrets.compare_digest(client_secret, expected_secret):
|
not secrets.compare_digest(client_secret, expected_secret):
|
||||||
logger.warning(f"OAuth auth failed from {request.client.host}")
|
logger.warning(f"OAuth client_credentials auth failed from {request.client.host}")
|
||||||
return JSONResponse(
|
return JSONResponse({"error": "invalid_client"}, status_code=401)
|
||||||
{"error": "invalid_client", "error_description": "Invalid client credentials"},
|
|
||||||
status_code=401,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Issue token
|
|
||||||
access_token = secrets.token_urlsafe(48)
|
access_token = secrets.token_urlsafe(48)
|
||||||
_active_tokens[_hash_token(access_token)] = time.time() + TOKEN_LIFETIME
|
_active_tokens[_hash(access_token)] = time.time() + TOKEN_LIFETIME
|
||||||
_cleanup_expired()
|
_cleanup_expired()
|
||||||
|
|
||||||
logger.info(f"OAuth token issued to {request.client.host}")
|
logger.info(f"OAuth token issued via client_credentials to {request.client.host}")
|
||||||
return JSONResponse({
|
return JSONResponse({
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"token_type": "Bearer",
|
"token_type": "Bearer",
|
||||||
|
|
@ -111,18 +239,16 @@ async def token_endpoint(request: Request) -> JSONResponse:
|
||||||
|
|
||||||
|
|
||||||
async def oauth_metadata(request: Request) -> JSONResponse:
|
async def oauth_metadata(request: Request) -> JSONResponse:
|
||||||
"""OAuth 2.0 Authorization Server Metadata (RFC 8414).
|
"""OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
|
||||||
|
|
||||||
GET /.well-known/oauth-authorization-server
|
|
||||||
"""
|
|
||||||
# Build base URL from request
|
|
||||||
base = str(request.base_url).rstrip("/")
|
base = str(request.base_url).rstrip("/")
|
||||||
return JSONResponse({
|
return JSONResponse({
|
||||||
"issuer": base,
|
"issuer": base,
|
||||||
|
"authorization_endpoint": f"{base}/authorize",
|
||||||
"token_endpoint": f"{base}/token",
|
"token_endpoint": f"{base}/token",
|
||||||
"token_endpoint_auth_methods_supported": ["client_secret_post"],
|
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
|
||||||
"grant_types_supported": ["client_credentials"],
|
"grant_types_supported": ["authorization_code", "client_credentials"],
|
||||||
"response_types_supported": [],
|
"response_types_supported": ["code"],
|
||||||
|
"code_challenge_methods_supported": ["S256"],
|
||||||
"scopes_supported": ["mcp"],
|
"scopes_supported": ["mcp"],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -130,5 +256,6 @@ async def oauth_metadata(request: Request) -> JSONResponse:
|
||||||
# Routes to add to the app
|
# Routes to add to the app
|
||||||
auth_routes = [
|
auth_routes = [
|
||||||
Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]),
|
Route("/.well-known/oauth-authorization-server", oauth_metadata, methods=["GET"]),
|
||||||
|
Route("/authorize", authorize_endpoint, methods=["GET"]),
|
||||||
Route("/token", token_endpoint, methods=["POST"]),
|
Route("/token", token_endpoint, methods=["POST"]),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue