diff --git a/internal/store/conversations.go b/internal/store/conversations.go new file mode 100644 index 0000000..6190132 --- /dev/null +++ b/internal/store/conversations.go @@ -0,0 +1,172 @@ +package store + +import ( + "context" + "errors" + "fmt" + "time" +) + +// ErrNotFound is returned when a requested resource does not exist. +var ErrNotFound = errors.New("store: not found") + +// Message represents a single chat message stored in the messages table. +type Message struct { + ID string + ConversationID string + Role string + Content string + CreatedAt time.Time +} + +// Conversation is a full conversation with all its messages. +type Conversation struct { + ID string + Model string + StartedAt time.Time + Messages []Message +} + +// ConversationSummary is a lightweight view of a conversation used in list responses. +// It omits individual messages to reduce query cost (T-06-01-04: LIMIT 100 guard). +type ConversationSummary struct { + ID string + Model string + StartedAt time.Time + MessageCount int +} + +// CreateConversation inserts a new conversation row and returns its UUID. +// All parameters are passed as $N placeholders (T-06-01-01). +func (s *Store) CreateConversation(ctx context.Context, model string) (string, error) { + var id string + err := s.pool.QueryRow(ctx, + `INSERT INTO conversations (model) VALUES ($1) RETURNING id`, + model, + ).Scan(&id) + if err != nil { + return "", fmt.Errorf("store: create conversation: %w", err) + } + return id, nil +} + +// AddMessage appends a message to an existing conversation. +// An invalid role value will cause a CHECK constraint violation from the DB (T-06-01-03). +func (s *Store) AddMessage(ctx context.Context, conversationID, role, content string) (string, error) { + var id string + err := s.pool.QueryRow(ctx, + `INSERT INTO messages (conversation_id, role, content) VALUES ($1, $2, $3) RETURNING id`, + conversationID, role, content, + ).Scan(&id) + if err != nil { + return "", fmt.Errorf("store: add message: %w", err) + } + return id, nil +} + +// GetConversation fetches a conversation and all its messages ordered by created_at ASC. +// Returns nil, ErrNotFound if the conversation ID does not exist. +func (s *Store) GetConversation(ctx context.Context, id string) (*Conversation, error) { + rows, err := s.pool.Query(ctx, + `SELECT c.id, c.started_at, c.model, + m.id, m.conversation_id, m.role, m.content, m.created_at + FROM conversations c + LEFT JOIN messages m ON m.conversation_id = c.id + WHERE c.id = $1 + ORDER BY m.created_at ASC NULLS LAST`, + id, + ) + if err != nil { + return nil, fmt.Errorf("store: get conversation query: %w", err) + } + defer rows.Close() + + var conv *Conversation + + for rows.Next() { + var ( + cID string + cStartedAt time.Time + cModel string + // Message fields — nullable because LEFT JOIN may produce NULLs + // when the conversation has no messages. + mID *string + mConversationID *string + mRole *string + mContent *string + mCreatedAt *time.Time + ) + + if err := rows.Scan( + &cID, &cStartedAt, &cModel, + &mID, &mConversationID, &mRole, &mContent, &mCreatedAt, + ); err != nil { + return nil, fmt.Errorf("store: get conversation scan: %w", err) + } + + if conv == nil { + conv = &Conversation{ + ID: cID, + Model: cModel, + StartedAt: cStartedAt, + Messages: []Message{}, + } + } + + if mID != nil { + conv.Messages = append(conv.Messages, Message{ + ID: *mID, + ConversationID: *mConversationID, + Role: *mRole, + Content: *mContent, + CreatedAt: *mCreatedAt, + }) + } + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("store: get conversation rows: %w", err) + } + + if conv == nil { + return nil, ErrNotFound + } + + return conv, nil +} + +// ListConversations returns all conversations ordered by started_at DESC. +// A soft LIMIT 100 guards against unbounded growth on a busy homelab instance (T-06-01-04). +func (s *Store) ListConversations(ctx context.Context) ([]ConversationSummary, error) { + rows, err := s.pool.Query(ctx, + `SELECT c.id, c.started_at, c.model, COUNT(m.id) AS message_count + FROM conversations c + LEFT JOIN messages m ON m.conversation_id = c.id + GROUP BY c.id + ORDER BY c.started_at DESC + LIMIT 100`, + ) + if err != nil { + return nil, fmt.Errorf("store: list conversations query: %w", err) + } + defer rows.Close() + + var summaries []ConversationSummary + for rows.Next() { + var s ConversationSummary + if err := rows.Scan(&s.ID, &s.StartedAt, &s.Model, &s.MessageCount); err != nil { + return nil, fmt.Errorf("store: list conversations scan: %w", err) + } + summaries = append(summaries, s) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("store: list conversations rows: %w", err) + } + + if summaries == nil { + summaries = []ConversationSummary{} + } + + return summaries, nil +} diff --git a/internal/store/store_test.go b/internal/store/store_test.go new file mode 100644 index 0000000..0348d0f --- /dev/null +++ b/internal/store/store_test.go @@ -0,0 +1,270 @@ +//go:build integration + +package store_test + +import ( + "context" + "os" + "testing" + "time" + + . "git.georgsen.dk/hwlab/internal/store" +) + +func testDSN(t *testing.T) string { + t.Helper() + dsn := os.Getenv("HWLAB_DATABASE_URL") + if dsn == "" { + t.Skip("HWLAB_DATABASE_URL not set, skipping integration tests") + } + return dsn +} + +// TestNewStore verifies basic connection lifecycle. +func TestNewStore(t *testing.T) { + ctx := context.Background() + + t.Run("invalid DSN returns error", func(t *testing.T) { + _, err := NewStore(ctx, "invalid-dsn") + if err == nil { + t.Fatal("expected error for invalid DSN, got nil") + } + }) + + t.Run("valid DSN returns store", func(t *testing.T) { + dsn := testDSN(t) + s, err := NewStore(ctx, dsn) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer s.Close() + + if s.Pool() == nil { + t.Fatal("expected non-nil pool") + } + }) + + t.Run("Close does not panic", func(t *testing.T) { + dsn := testDSN(t) + s, err := NewStore(ctx, dsn) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + s.Close() // must not panic + }) +} + +// TestRunMigrations verifies idempotent table creation. +func TestRunMigrations(t *testing.T) { + ctx := context.Background() + dsn := testDSN(t) + + s, err := NewStore(ctx, dsn) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer s.Close() + + t.Run("first run creates tables", func(t *testing.T) { + if err := RunMigrations(ctx, s.Pool()); err != nil { + t.Fatalf("RunMigrations failed: %v", err) + } + }) + + t.Run("second run is idempotent", func(t *testing.T) { + if err := RunMigrations(ctx, s.Pool()); err != nil { + t.Fatalf("RunMigrations second call failed: %v", err) + } + }) +} + +// TestConversationCRUD verifies full CRUD round-trip. +func TestConversationCRUD(t *testing.T) { + ctx := context.Background() + dsn := testDSN(t) + + s, err := NewStore(ctx, dsn) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer s.Close() + + if err := RunMigrations(ctx, s.Pool()); err != nil { + t.Fatalf("RunMigrations failed: %v", err) + } + + t.Run("CreateConversation returns ID", func(t *testing.T) { + convID, err := s.CreateConversation(ctx, "gemma-4-e4b") + if err != nil { + t.Fatalf("CreateConversation failed: %v", err) + } + if convID == "" { + t.Fatal("expected non-empty conversation ID") + } + // Cleanup + defer func() { + _, _ = s.Pool().Exec(ctx, "DELETE FROM conversations WHERE id = $1", convID) + }() + }) + + t.Run("AddMessage returns message ID", func(t *testing.T) { + convID, err := s.CreateConversation(ctx, "gemma-4-e4b") + if err != nil { + t.Fatalf("CreateConversation failed: %v", err) + } + defer func() { + _, _ = s.Pool().Exec(ctx, "DELETE FROM conversations WHERE id = $1", convID) + }() + + msgID, err := s.AddMessage(ctx, convID, "user", "Hello, AI!") + if err != nil { + t.Fatalf("AddMessage failed: %v", err) + } + if msgID == "" { + t.Fatal("expected non-empty message ID") + } + }) + + t.Run("AddMessage with invalid role returns error", func(t *testing.T) { + convID, err := s.CreateConversation(ctx, "test-model") + if err != nil { + t.Fatalf("CreateConversation failed: %v", err) + } + defer func() { + _, _ = s.Pool().Exec(ctx, "DELETE FROM conversations WHERE id = $1", convID) + }() + + _, err = s.AddMessage(ctx, convID, "invalid_role", "content") + if err == nil { + t.Fatal("expected error for invalid role, got nil") + } + }) + + t.Run("GetConversation returns ErrNotFound for unknown ID", func(t *testing.T) { + conv, err := s.GetConversation(ctx, "00000000-0000-0000-0000-000000000000") + if err != ErrNotFound { + t.Fatalf("expected ErrNotFound, got err=%v conv=%v", err, conv) + } + if conv != nil { + t.Fatal("expected nil conversation for unknown ID") + } + }) + + t.Run("full round-trip: create, add messages, get", func(t *testing.T) { + convID, err := s.CreateConversation(ctx, "gemma-4-e4b") + if err != nil { + t.Fatalf("CreateConversation failed: %v", err) + } + defer func() { + _, _ = s.Pool().Exec(ctx, "DELETE FROM conversations WHERE id = $1", convID) + }() + + msg1ID, err := s.AddMessage(ctx, convID, "user", "What is a switch?") + if err != nil { + t.Fatalf("AddMessage (user) failed: %v", err) + } + // Small sleep to ensure ordering by created_at + time.Sleep(5 * time.Millisecond) + + msg2ID, err := s.AddMessage(ctx, convID, "assistant", "A switch is a network device...") + if err != nil { + t.Fatalf("AddMessage (assistant) failed: %v", err) + } + + conv, err := s.GetConversation(ctx, convID) + if err != nil { + t.Fatalf("GetConversation failed: %v", err) + } + if conv.ID != convID { + t.Errorf("expected conv ID %s, got %s", convID, conv.ID) + } + if conv.Model != "gemma-4-e4b" { + t.Errorf("expected model gemma-4-e4b, got %s", conv.Model) + } + if len(conv.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(conv.Messages)) + } + if conv.Messages[0].ID != msg1ID { + t.Errorf("expected first message ID %s, got %s", msg1ID, conv.Messages[0].ID) + } + if conv.Messages[1].ID != msg2ID { + t.Errorf("expected second message ID %s, got %s", msg2ID, conv.Messages[1].ID) + } + if conv.Messages[0].Role != "user" { + t.Errorf("expected first message role 'user', got %s", conv.Messages[0].Role) + } + if conv.Messages[1].Role != "assistant" { + t.Errorf("expected second message role 'assistant', got %s", conv.Messages[1].Role) + } + }) + + t.Run("ListConversations returns summaries ordered by started_at DESC", func(t *testing.T) { + // Create two conversations + conv1ID, err := s.CreateConversation(ctx, "model-a") + if err != nil { + t.Fatalf("CreateConversation 1 failed: %v", err) + } + defer func() { + _, _ = s.Pool().Exec(ctx, "DELETE FROM conversations WHERE id = $1", conv1ID) + }() + + time.Sleep(5 * time.Millisecond) + + conv2ID, err := s.CreateConversation(ctx, "model-b") + if err != nil { + t.Fatalf("CreateConversation 2 failed: %v", err) + } + defer func() { + _, _ = s.Pool().Exec(ctx, "DELETE FROM conversations WHERE id = $1", conv2ID) + }() + + // Add a message to conv1 + _, err = s.AddMessage(ctx, conv1ID, "user", "hello") + if err != nil { + t.Fatalf("AddMessage failed: %v", err) + } + + summaries, err := s.ListConversations(ctx) + if err != nil { + t.Fatalf("ListConversations failed: %v", err) + } + + // Find our two conversations in the list + var found1, found2 *ConversationSummary + for i := range summaries { + switch summaries[i].ID { + case conv1ID: + found1 = &summaries[i] + case conv2ID: + found2 = &summaries[i] + } + } + + if found1 == nil { + t.Fatal("conv1 not found in ListConversations") + } + if found2 == nil { + t.Fatal("conv2 not found in ListConversations") + } + if found1.MessageCount != 1 { + t.Errorf("expected conv1 message_count=1, got %d", found1.MessageCount) + } + if found2.MessageCount != 0 { + t.Errorf("expected conv2 message_count=0, got %d", found2.MessageCount) + } + // conv2 was created after conv1, so it should appear earlier in DESC order + for i, s := range summaries { + if s.ID == conv2ID { + for j, s2 := range summaries { + if s2.ID == conv1ID { + if i > j { + t.Errorf("expected conv2 (newer) before conv1 (older) in DESC order, but conv2 at index %d, conv1 at index %d", i, j) + } + break + } + } + break + } + } + }) +}