# Short aliases accepted by the /model command, mapped to full model identifiers.
MODEL_ALIASES = {
    "sonnet": "claude-sonnet-4-5-20250929",
    "opus": "claude-opus-4-5-20251101",
    "haiku": "claude-haiku-4-5-20251001",
}


def load_persona_for_session(session_name: str) -> dict:
    """Load the session's persona with any session-level model override applied.

    The persona name is read from the session metadata (falling back to
    ``'default'``) and loaded through ``session_manager.load_persona``.  When
    the metadata carries a truthy ``model_override``, the returned persona has
    ``settings['model']`` set to that value.

    The override is written into a copy, never into the object returned by
    ``load_persona``, so a persona shared by several sessions is not altered
    by one session's override.

    Args:
        session_name: Name of the session whose metadata should be consulted.

    Returns:
        The persona dict.  Without an override this is the object returned by
        ``load_persona`` unchanged; with an override it is a new dict whose
        ``settings`` mapping is also a fresh copy.
    """
    # NOTE(review): what get_session returns for an unknown session is not
    # visible here; a falsy result is treated as "no metadata".
    session_data = session_manager.get_session(session_name) or {}
    persona_name = session_data.get('persona', 'default')
    persona = session_manager.load_persona(persona_name)

    model_override = session_data.get('model_override')
    if not model_override:
        return persona

    # Copy both levels that are touched so the loaded persona stays pristine.
    # `or {}` also covers a persona whose 'settings' key is present but None.
    settings = dict(persona.get('settings') or {})
    settings['model'] = model_override
    return {**persona, 'settings': settings}
@@ -379,8 +400,7 @@ async def new_session(update: Update, context: ContextTypes.DEFAULT_TYPE): # Spawn subprocess for the new session (but don't send message yet) session_dir = session_manager.get_session_dir(name) - persona_config = session_manager.get_session(name) - persona_data = session_manager.load_persona(persona or 'default') + persona_data = load_persona_for_session(name) # Create callbacks bound to this chat (typing looked up dynamically) callbacks = make_callbacks(context.bot, update.effective_chat.id, name) @@ -462,9 +482,7 @@ async def switch_session_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE) # Auto-spawn subprocess if not alive if name not in subprocesses or not subprocesses[name].is_alive: session_dir = session_manager.get_session_dir(name) - session_data = session_manager.get_session(name) - persona_name = session_data.get('persona', 'default') - persona_data = session_manager.load_persona(persona_name) + persona_data = load_persona_for_session(name) # Create callbacks bound to this chat (typing looked up dynamically) callbacks = make_callbacks(context.bot, update.effective_chat.id, name) @@ -582,9 +600,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE): already_alive = active_session in subprocesses and subprocesses[active_session].is_alive if not already_alive: session_dir = session_manager.get_session_dir(active_session) - session_data = session_manager.get_session(active_session) - persona_name = session_data.get('persona', 'default') - persona_data = session_manager.load_persona(persona_name) + persona_data = load_persona_for_session(active_session) # Create callbacks bound to this chat (typing looked up dynamically) callbacks = make_callbacks(context.bot, update.effective_chat.id, active_session) @@ -681,9 +697,7 @@ async def handle_photo(update: Update, context: ContextTypes.DEFAULT_TYPE): # Get or create subprocess if active_session not in subprocesses or not subprocesses[active_session].is_alive: 
async def model_cmd(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Switch model for current session. Persists across session switches.

    With no argument, replies with the model currently in effect (the
    session's ``model_override`` if set, otherwise the persona's
    ``settings.model``, otherwise the literal ``default``) plus the list of
    known aliases.

    With an argument, resolves it through ``MODEL_ALIASES`` (case-insensitive;
    unknown names are passed through unchanged), stores it as the session's
    ``model_override``, and tears down the session's running subprocess and
    batcher so the next message spawns a subprocess with the new model.

    Unauthorized users are ignored silently.
    """
    if not is_authorized(update.effective_user.id):
        return

    active_session = session_manager.get_active_session()
    if not active_session:
        await update.message.reply_text("No active session. Use /new to start one.")
        return

    if not context.args:
        # Show current model.
        # NOTE(review): get_session's result for a missing session is not
        # visible here; a falsy result is treated as empty metadata.
        session_data = session_manager.get_session(active_session) or {}
        model_override = session_data.get('model_override')
        persona_name = session_data.get('persona', 'default')
        persona = session_manager.load_persona(persona_name)
        current = model_override or persona.get('settings', {}).get('model', 'default')
        aliases = "\n".join(f"  {k} → {v}" for k, v in MODEL_ALIASES.items())
        await update.message.reply_text(
            f"Current model: {current}\n\nUsage: /model <name>\n\nAliases:\n{aliases}"
        )
        return

    model = context.args[0]
    # Resolve alias case-insensitively so "/model Opus" works like "/model opus";
    # anything that is not an alias is stored exactly as typed.
    resolved = MODEL_ALIASES.get(model.lower(), model)

    # Persist to session metadata
    session_manager.update_session(active_session, model_override=resolved)

    # Terminate current subprocess so next message spawns with new model.
    # The registry entry is removed *before* awaiting so that a handler running
    # concurrently never picks up a subprocess that is being shut down, and so
    # a concurrent removal cannot turn into a KeyError here.
    old_process = subprocesses.pop(active_session, None)
    if old_process is not None and old_process.is_alive:
        await old_process.terminate()

    # Clean up batcher too.
    # NOTE(review): as in the original, the flush runs after the subprocess is
    # gone; confirm the batcher's flush target tolerates that.
    batcher = batchers.get(active_session)
    if batcher is not None:
        await batcher.flush_immediately()
        batchers.pop(active_session, None)

    await update.message.reply_text(f"Model set to {resolved} for session '{active_session}'.")
    logger.info("Model changed to %s for session '%s'", resolved, active_session)
10b0173..3e7d757 100644 --- a/telegram/claude_subprocess.py +++ b/telegram/claude_subprocess.py @@ -174,7 +174,7 @@ class ClaudeSubprocess: ) try: - # Spawn subprocess + # Spawn subprocess (10MB stdout limit for large stream-json lines e.g. image tool results) self._process = await asyncio.create_subprocess_exec( *cmd, stdin=asyncio.subprocess.PIPE, @@ -182,6 +182,7 @@ class ClaudeSubprocess: stderr=asyncio.subprocess.PIPE, cwd=str(self._session_dir), env=env, + limit=10 * 1024 * 1024, ) elapsed = time.monotonic() - self._spawn_time