feat(06-01): add conversation and message CRUD methods with integration tests
- CreateConversation, AddMessage, GetConversation, ListConversations on *Store - ErrNotFound sentinel for unknown conversation IDs - Message, Conversation, ConversationSummary types - LIMIT 100 soft guard on ListConversations (T-06-01-04) - Integration tests cover full round-trip, invalid role, ErrNotFound, ordering
This commit is contained in:
parent
4bc22dc7b9
commit
623cff0d76
2 changed files with 442 additions and 0 deletions
172
internal/store/conversations.go
Normal file
172
internal/store/conversations.go
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrNotFound is returned when a requested resource does not exist.
|
||||
var ErrNotFound = errors.New("store: not found")
|
||||
|
||||
// Message represents a single chat message stored in the messages table.
|
||||
type Message struct {
|
||||
ID string
|
||||
ConversationID string
|
||||
Role string
|
||||
Content string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// Conversation is a full conversation with all its messages.
|
||||
type Conversation struct {
|
||||
ID string
|
||||
Model string
|
||||
StartedAt time.Time
|
||||
Messages []Message
|
||||
}
|
||||
|
||||
// ConversationSummary is a lightweight view of a conversation used in list responses.
|
||||
// It omits individual messages to reduce query cost (T-06-01-04: LIMIT 100 guard).
|
||||
type ConversationSummary struct {
|
||||
ID string
|
||||
Model string
|
||||
StartedAt time.Time
|
||||
MessageCount int
|
||||
}
|
||||
|
||||
// CreateConversation inserts a new conversation row and returns its UUID.
|
||||
// All parameters are passed as $N placeholders (T-06-01-01).
|
||||
func (s *Store) CreateConversation(ctx context.Context, model string) (string, error) {
|
||||
var id string
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`INSERT INTO conversations (model) VALUES ($1) RETURNING id`,
|
||||
model,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("store: create conversation: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// AddMessage appends a message to an existing conversation.
|
||||
// An invalid role value will cause a CHECK constraint violation from the DB (T-06-01-03).
|
||||
func (s *Store) AddMessage(ctx context.Context, conversationID, role, content string) (string, error) {
|
||||
var id string
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`INSERT INTO messages (conversation_id, role, content) VALUES ($1, $2, $3) RETURNING id`,
|
||||
conversationID, role, content,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("store: add message: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// GetConversation fetches a conversation and all its messages ordered by created_at ASC.
|
||||
// Returns nil, ErrNotFound if the conversation ID does not exist.
|
||||
func (s *Store) GetConversation(ctx context.Context, id string) (*Conversation, error) {
|
||||
rows, err := s.pool.Query(ctx,
|
||||
`SELECT c.id, c.started_at, c.model,
|
||||
m.id, m.conversation_id, m.role, m.content, m.created_at
|
||||
FROM conversations c
|
||||
LEFT JOIN messages m ON m.conversation_id = c.id
|
||||
WHERE c.id = $1
|
||||
ORDER BY m.created_at ASC NULLS LAST`,
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: get conversation query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var conv *Conversation
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
cID string
|
||||
cStartedAt time.Time
|
||||
cModel string
|
||||
// Message fields — nullable because LEFT JOIN may produce NULLs
|
||||
// when the conversation has no messages.
|
||||
mID *string
|
||||
mConversationID *string
|
||||
mRole *string
|
||||
mContent *string
|
||||
mCreatedAt *time.Time
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&cID, &cStartedAt, &cModel,
|
||||
&mID, &mConversationID, &mRole, &mContent, &mCreatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("store: get conversation scan: %w", err)
|
||||
}
|
||||
|
||||
if conv == nil {
|
||||
conv = &Conversation{
|
||||
ID: cID,
|
||||
Model: cModel,
|
||||
StartedAt: cStartedAt,
|
||||
Messages: []Message{},
|
||||
}
|
||||
}
|
||||
|
||||
if mID != nil {
|
||||
conv.Messages = append(conv.Messages, Message{
|
||||
ID: *mID,
|
||||
ConversationID: *mConversationID,
|
||||
Role: *mRole,
|
||||
Content: *mContent,
|
||||
CreatedAt: *mCreatedAt,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("store: get conversation rows: %w", err)
|
||||
}
|
||||
|
||||
if conv == nil {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
// ListConversations returns all conversations ordered by started_at DESC.
|
||||
// A soft LIMIT 100 guards against unbounded growth on a busy homelab instance (T-06-01-04).
|
||||
func (s *Store) ListConversations(ctx context.Context) ([]ConversationSummary, error) {
|
||||
rows, err := s.pool.Query(ctx,
|
||||
`SELECT c.id, c.started_at, c.model, COUNT(m.id) AS message_count
|
||||
FROM conversations c
|
||||
LEFT JOIN messages m ON m.conversation_id = c.id
|
||||
GROUP BY c.id
|
||||
ORDER BY c.started_at DESC
|
||||
LIMIT 100`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: list conversations query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var summaries []ConversationSummary
|
||||
for rows.Next() {
|
||||
var s ConversationSummary
|
||||
if err := rows.Scan(&s.ID, &s.StartedAt, &s.Model, &s.MessageCount); err != nil {
|
||||
return nil, fmt.Errorf("store: list conversations scan: %w", err)
|
||||
}
|
||||
summaries = append(summaries, s)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("store: list conversations rows: %w", err)
|
||||
}
|
||||
|
||||
if summaries == nil {
|
||||
summaries = []ConversationSummary{}
|
||||
}
|
||||
|
||||
return summaries, nil
|
||||
}
|
||||
270
internal/store/store_test.go
Normal file
270
internal/store/store_test.go
Normal file
|
|
@ -0,0 +1,270 @@
|
|||
//go:build integration
|
||||
|
||||
package store_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "git.georgsen.dk/hwlab/internal/store"
|
||||
)
|
||||
|
||||
func testDSN(t *testing.T) string {
|
||||
t.Helper()
|
||||
dsn := os.Getenv("HWLAB_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
t.Skip("HWLAB_DATABASE_URL not set, skipping integration tests")
|
||||
}
|
||||
return dsn
|
||||
}
|
||||
|
||||
// TestNewStore verifies basic connection lifecycle.
|
||||
func TestNewStore(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("invalid DSN returns error", func(t *testing.T) {
|
||||
_, err := NewStore(ctx, "invalid-dsn")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid DSN, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid DSN returns store", func(t *testing.T) {
|
||||
dsn := testDSN(t)
|
||||
s, err := NewStore(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
if s.Pool() == nil {
|
||||
t.Fatal("expected non-nil pool")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Close does not panic", func(t *testing.T) {
|
||||
dsn := testDSN(t)
|
||||
s, err := NewStore(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
s.Close() // must not panic
|
||||
})
|
||||
}
|
||||
|
||||
// TestRunMigrations verifies idempotent table creation.
|
||||
func TestRunMigrations(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dsn := testDSN(t)
|
||||
|
||||
s, err := NewStore(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
t.Run("first run creates tables", func(t *testing.T) {
|
||||
if err := RunMigrations(ctx, s.Pool()); err != nil {
|
||||
t.Fatalf("RunMigrations failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("second run is idempotent", func(t *testing.T) {
|
||||
if err := RunMigrations(ctx, s.Pool()); err != nil {
|
||||
t.Fatalf("RunMigrations second call failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConversationCRUD verifies full CRUD round-trip.
|
||||
func TestConversationCRUD(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dsn := testDSN(t)
|
||||
|
||||
s, err := NewStore(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
if err := RunMigrations(ctx, s.Pool()); err != nil {
|
||||
t.Fatalf("RunMigrations failed: %v", err)
|
||||
}
|
||||
|
||||
t.Run("CreateConversation returns ID", func(t *testing.T) {
|
||||
convID, err := s.CreateConversation(ctx, "gemma-4-e4b")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateConversation failed: %v", err)
|
||||
}
|
||||
if convID == "" {
|
||||
t.Fatal("expected non-empty conversation ID")
|
||||
}
|
||||
// Cleanup
|
||||
defer func() {
|
||||
_, _ = s.Pool().Exec(ctx, "DELETE FROM conversations WHERE id = $1", convID)
|
||||
}()
|
||||
})
|
||||
|
||||
t.Run("AddMessage returns message ID", func(t *testing.T) {
|
||||
convID, err := s.CreateConversation(ctx, "gemma-4-e4b")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateConversation failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _ = s.Pool().Exec(ctx, "DELETE FROM conversations WHERE id = $1", convID)
|
||||
}()
|
||||
|
||||
msgID, err := s.AddMessage(ctx, convID, "user", "Hello, AI!")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage failed: %v", err)
|
||||
}
|
||||
if msgID == "" {
|
||||
t.Fatal("expected non-empty message ID")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AddMessage with invalid role returns error", func(t *testing.T) {
|
||||
convID, err := s.CreateConversation(ctx, "test-model")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateConversation failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _ = s.Pool().Exec(ctx, "DELETE FROM conversations WHERE id = $1", convID)
|
||||
}()
|
||||
|
||||
_, err = s.AddMessage(ctx, convID, "invalid_role", "content")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid role, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetConversation returns ErrNotFound for unknown ID", func(t *testing.T) {
|
||||
conv, err := s.GetConversation(ctx, "00000000-0000-0000-0000-000000000000")
|
||||
if err != ErrNotFound {
|
||||
t.Fatalf("expected ErrNotFound, got err=%v conv=%v", err, conv)
|
||||
}
|
||||
if conv != nil {
|
||||
t.Fatal("expected nil conversation for unknown ID")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("full round-trip: create, add messages, get", func(t *testing.T) {
|
||||
convID, err := s.CreateConversation(ctx, "gemma-4-e4b")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateConversation failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _ = s.Pool().Exec(ctx, "DELETE FROM conversations WHERE id = $1", convID)
|
||||
}()
|
||||
|
||||
msg1ID, err := s.AddMessage(ctx, convID, "user", "What is a switch?")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage (user) failed: %v", err)
|
||||
}
|
||||
// Small sleep to ensure ordering by created_at
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
msg2ID, err := s.AddMessage(ctx, convID, "assistant", "A switch is a network device...")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage (assistant) failed: %v", err)
|
||||
}
|
||||
|
||||
conv, err := s.GetConversation(ctx, convID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConversation failed: %v", err)
|
||||
}
|
||||
if conv.ID != convID {
|
||||
t.Errorf("expected conv ID %s, got %s", convID, conv.ID)
|
||||
}
|
||||
if conv.Model != "gemma-4-e4b" {
|
||||
t.Errorf("expected model gemma-4-e4b, got %s", conv.Model)
|
||||
}
|
||||
if len(conv.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(conv.Messages))
|
||||
}
|
||||
if conv.Messages[0].ID != msg1ID {
|
||||
t.Errorf("expected first message ID %s, got %s", msg1ID, conv.Messages[0].ID)
|
||||
}
|
||||
if conv.Messages[1].ID != msg2ID {
|
||||
t.Errorf("expected second message ID %s, got %s", msg2ID, conv.Messages[1].ID)
|
||||
}
|
||||
if conv.Messages[0].Role != "user" {
|
||||
t.Errorf("expected first message role 'user', got %s", conv.Messages[0].Role)
|
||||
}
|
||||
if conv.Messages[1].Role != "assistant" {
|
||||
t.Errorf("expected second message role 'assistant', got %s", conv.Messages[1].Role)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ListConversations returns summaries ordered by started_at DESC", func(t *testing.T) {
|
||||
// Create two conversations
|
||||
conv1ID, err := s.CreateConversation(ctx, "model-a")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateConversation 1 failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _ = s.Pool().Exec(ctx, "DELETE FROM conversations WHERE id = $1", conv1ID)
|
||||
}()
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
conv2ID, err := s.CreateConversation(ctx, "model-b")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateConversation 2 failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _ = s.Pool().Exec(ctx, "DELETE FROM conversations WHERE id = $1", conv2ID)
|
||||
}()
|
||||
|
||||
// Add a message to conv1
|
||||
_, err = s.AddMessage(ctx, conv1ID, "user", "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage failed: %v", err)
|
||||
}
|
||||
|
||||
summaries, err := s.ListConversations(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListConversations failed: %v", err)
|
||||
}
|
||||
|
||||
// Find our two conversations in the list
|
||||
var found1, found2 *ConversationSummary
|
||||
for i := range summaries {
|
||||
switch summaries[i].ID {
|
||||
case conv1ID:
|
||||
found1 = &summaries[i]
|
||||
case conv2ID:
|
||||
found2 = &summaries[i]
|
||||
}
|
||||
}
|
||||
|
||||
if found1 == nil {
|
||||
t.Fatal("conv1 not found in ListConversations")
|
||||
}
|
||||
if found2 == nil {
|
||||
t.Fatal("conv2 not found in ListConversations")
|
||||
}
|
||||
if found1.MessageCount != 1 {
|
||||
t.Errorf("expected conv1 message_count=1, got %d", found1.MessageCount)
|
||||
}
|
||||
if found2.MessageCount != 0 {
|
||||
t.Errorf("expected conv2 message_count=0, got %d", found2.MessageCount)
|
||||
}
|
||||
// conv2 was created after conv1, so it should appear earlier in DESC order
|
||||
for i, s := range summaries {
|
||||
if s.ID == conv2ID {
|
||||
for j, s2 := range summaries {
|
||||
if s2.ID == conv1ID {
|
||||
if i > j {
|
||||
t.Errorf("expected conv2 (newer) before conv1 (older) in DESC order, but conv2 at index %d, conv1 at index %d", i, j)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue