"""Tests for MoAI SQLAlchemy models. Tests verify model creation, relationships, and cascade delete behavior using an in-memory SQLite database. """ import pytest from sqlalchemy import select from moai.core.database import close_db, create_tables, get_session, init_db from moai.core.models import ( Consensus, Discussion, DiscussionStatus, DiscussionType, Message, Project, Round, RoundType, ) @pytest.fixture async def db_session(): """Provide a database session with in-memory SQLite for testing.""" init_db("sqlite+aiosqlite:///:memory:") await create_tables() async with get_session() as session: yield session await close_db() async def test_create_project(db_session): """Test creating a project with basic attributes.""" project = Project(name="Test Project", models=["claude", "gpt"]) db_session.add(project) await db_session.flush() assert project.id is not None assert len(project.id) == 36 # UUID format assert project.name == "Test Project" assert project.models == ["claude", "gpt"] assert project.created_at is not None assert project.updated_at is not None async def test_create_discussion_with_project(db_session): """Test creating a discussion linked to a project.""" project = Project(name="Test Project", models=["claude"]) db_session.add(project) await db_session.flush() discussion = Discussion( project_id=project.id, question="What is the meaning of life?", type=DiscussionType.OPEN, ) db_session.add(discussion) await db_session.flush() assert discussion.id is not None assert discussion.project_id == project.id assert discussion.status == DiscussionStatus.ACTIVE assert discussion.type == DiscussionType.OPEN # Verify relationship await db_session.refresh(project, ["discussions"]) assert len(project.discussions) == 1 assert project.discussions[0].id == discussion.id async def test_create_full_discussion_chain(db_session): """Test creating a full chain: Project -> Discussion -> Round -> Message.""" # Create project project = Project(name="Full Chain Test", models=["claude", "gpt", "gemini"]) db_session.add(project) await db_session.flush() # Create discussion discussion = Discussion( project_id=project.id, question="How should we approach this problem?", type=DiscussionType.DISCUSS, ) db_session.add(discussion) await db_session.flush() # Create round round_ = Round( discussion_id=discussion.id, round_number=1, type=RoundType.SEQUENTIAL, ) db_session.add(round_) await db_session.flush() # Create messages message1 = Message( round_id=round_.id, model="claude", content="I think we should consider option A.", is_direct=False, ) message2 = Message( round_id=round_.id, model="gpt", content="I agree with Claude, option A seems best.", is_direct=False, ) db_session.add_all([message1, message2]) await db_session.flush() # Verify all relationships await db_session.refresh(round_, ["messages"]) assert len(round_.messages) == 2 assert round_.discussion_id == discussion.id await db_session.refresh(discussion, ["rounds"]) assert len(discussion.rounds) == 1 assert discussion.project_id == project.id await db_session.refresh(project, ["discussions"]) assert len(project.discussions) == 1 async def test_create_consensus(db_session): """Test creating a consensus for a discussion.""" project = Project(name="Consensus Test", models=["claude"]) db_session.add(project) await db_session.flush() discussion = Discussion( project_id=project.id, question="What should we do?", type=DiscussionType.OPEN, ) db_session.add(discussion) await db_session.flush() consensus = Consensus( discussion_id=discussion.id, agreements=["We should prioritize user experience", "Performance is important"], disagreements=[{"topic": "Timeline", "positions": {"claude": "2 weeks", "gpt": "3 weeks"}}], generated_by="claude", ) db_session.add(consensus) await db_session.flush() # Verify consensus attributes assert consensus.id is not None assert consensus.discussion_id == discussion.id assert len(consensus.agreements) == 2 assert len(consensus.disagreements) == 1 assert consensus.generated_by == "claude" assert consensus.generated_at is not None # Verify relationship await db_session.refresh(discussion, ["consensus"]) assert discussion.consensus is not None assert discussion.consensus.id == consensus.id async def test_project_cascade_delete(db_session): """Test that deleting a project cascades to all children.""" # Create full hierarchy project = Project(name="Cascade Test", models=["claude"]) db_session.add(project) await db_session.flush() discussion = Discussion( project_id=project.id, question="Test question", type=DiscussionType.OPEN, ) db_session.add(discussion) await db_session.flush() round_ = Round( discussion_id=discussion.id, round_number=1, type=RoundType.PARALLEL, ) db_session.add(round_) await db_session.flush() message = Message( round_id=round_.id, model="claude", content="Test message", ) db_session.add(message) await db_session.flush() # Store IDs for verification project_id = project.id discussion_id = discussion.id round_id = round_.id message_id = message.id # Delete project await db_session.delete(project) await db_session.flush() # Verify all children are deleted (cascade) result = await db_session.execute(select(Project).where(Project.id == project_id)) assert result.scalar_one_or_none() is None result = await db_session.execute(select(Discussion).where(Discussion.id == discussion_id)) assert result.scalar_one_or_none() is None result = await db_session.execute(select(Round).where(Round.id == round_id)) assert result.scalar_one_or_none() is None result = await db_session.execute(select(Message).where(Message.id == message_id)) assert result.scalar_one_or_none() is None