106 lines
3 KiB
Python
106 lines
3 KiB
Python
"""Database session management for MoAI.
|
|
|
|
Provides async session management using SQLAlchemy 2.0 async support.
|
|
|
|
Usage pattern:
|
|
from moai.core.database import init_db, create_tables, get_session, close_db
|
|
|
|
# Initialize on startup
|
|
init_db("sqlite+aiosqlite:///./moai.db")
|
|
await create_tables()
|
|
|
|
# Use sessions for database operations
|
|
async with get_session() as session:
|
|
project = Project(name="My Project")
|
|
session.add(project)
|
|
# Auto-commits on context exit, rollback on exception
|
|
|
|
# Cleanup on shutdown
|
|
await close_db()
|
|
"""
|
|
|
|
from collections.abc import AsyncGenerator
|
|
from contextlib import asynccontextmanager
|
|
|
|
from sqlalchemy.ext.asyncio import (
|
|
AsyncEngine,
|
|
AsyncSession,
|
|
async_sessionmaker,
|
|
create_async_engine,
|
|
)
|
|
|
|
# Module-level state. init_db() populates ``engine`` and
# ``async_session_factory`` (and overwrites DATABASE_URL when given a URL);
# close_db() resets ``engine`` and ``async_session_factory`` back to None.

# Connection URL that init_db() uses when it is called without an argument.
DATABASE_URL: str = "sqlite+aiosqlite:///./moai.db"

# Shared async engine; None until init_db() has run.
engine: AsyncEngine | None = None

# Factory producing AsyncSession objects bound to ``engine``;
# None until init_db() has run.
async_session_factory: async_sessionmaker[AsyncSession] | None = None
|
|
|
|
|
|
def init_db(url: str | None = None) -> None:
    """Build the module-wide async engine and session factory.

    Calling this again replaces both objects. The previous engine is not
    disposed here, so call close_db() first if one already exists.

    Args:
        url: Database URL to connect to. When provided, it also overwrites the
            module-level DATABASE_URL. When None, the current DATABASE_URL is
            used unchanged. Pass "sqlite+aiosqlite:///:memory:" for an
            in-memory database in tests.
    """
    global engine, async_session_factory, DATABASE_URL

    DATABASE_URL = DATABASE_URL if url is None else url

    new_engine = create_async_engine(DATABASE_URL, echo=False)
    engine = new_engine

    # expire_on_commit=False keeps loaded attribute values available after a
    # commit, so ORM objects stay readable once get_session() has committed.
    async_session_factory = async_sessionmaker(
        bind=new_engine,
        expire_on_commit=False,
    )
|
|
|
|
|
|
async def create_tables() -> None:
    """Create every table registered on the models' declarative Base.

    Must be called after init_db(). Existing tables are skipped, because
    ``MetaData.create_all`` checks for each table before creating it by
    default.

    Raises:
        RuntimeError: If init_db() has not been called yet.
    """
    # Imported lazily inside the function; presumably this avoids a circular
    # import between the models module and this one (TODO confirm).
    from moai.core.models import Base

    current_engine = engine
    if current_engine is None:
        raise RuntimeError("Database not initialized. Call init_db() first.")

    # create_all is a synchronous API, so it is executed via run_sync on the
    # async connection, inside a transaction opened by begin().
    async with current_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
|
|
|
|
|
|
@asynccontextmanager
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Provide a transactional AsyncSession scoped to an ``async with`` block.

    When the block exits normally, the session is committed. If the block, or
    the commit itself, raises an ``Exception``, the session is rolled back and
    the exception is re-raised. In every case the session is closed when the
    block ends.

    Yields:
        AsyncSession: A session created from the module's session factory.

    Raises:
        RuntimeError: If the session factory is not set, i.e. init_db() has
            not been called (or close_db() has since cleared it).

    Example:
        async with get_session() as session:
            project = Project(name="Test")
            session.add(project)
            # Committed automatically on exit
    """
    factory = async_session_factory
    if factory is None:
        raise RuntimeError("Database not initialized. Call init_db() first.")

    async with factory() as db_session:
        try:
            yield db_session
            await db_session.commit()
        except Exception:
            # Reached for failures in the caller's block and in commit() alike.
            await db_session.rollback()
            raise
|
|
|
|
|
|
async def close_db() -> None:
    """Dispose of the shared engine and forget the session factory.

    Intended for application shutdown. If no engine exists, this does
    nothing (the session factory is left as-is). Otherwise the engine's
    connection pool is disposed, and both the engine and the session factory
    are reset to None.
    """
    global engine, async_session_factory

    if engine is None:
        return

    await engine.dispose()
    engine = None
    async_session_factory = None
|