package store import ( "context" "database/sql" "encoding/json" "errors" "fmt" "net/url" "path/filepath" "sync" "time" _ "modernc.org/sqlite" "banger/internal/model" ) type Store struct { db *sql.DB writeMu sync.Mutex } func Open(path string) (*Store, error) { dsn, err := sqliteDSN(path) if err != nil { return nil, err } db, err := sql.Open("sqlite", dsn) if err != nil { return nil, err } store := &Store{db: db} if err := store.migrate(); err != nil { _ = db.Close() return nil, err } return store, nil } func (s *Store) Close() error { return s.db.Close() } func sqliteDSN(path string) (string, error) { absPath, err := filepath.Abs(path) if err != nil { return "", fmt.Errorf("resolve sqlite path: %w", err) } query := url.Values{} for _, pragma := range []string{ "journal_mode(WAL)", "synchronous(NORMAL)", "foreign_keys(1)", "busy_timeout(5000)", "temp_store(MEMORY)", "wal_autocheckpoint(1000)", "journal_size_limit(67108864)", } { query.Add("_pragma", pragma) } return (&url.URL{ Scheme: "file", Path: filepath.ToSlash(absPath), RawQuery: query.Encode(), }).String(), nil } func (s *Store) migrate() error { stmts := []string{ `CREATE TABLE IF NOT EXISTS images ( id TEXT PRIMARY KEY, name TEXT NOT NULL UNIQUE, managed INTEGER NOT NULL DEFAULT 0, artifact_dir TEXT, rootfs_path TEXT NOT NULL, work_seed_path TEXT, kernel_path TEXT NOT NULL, initrd_path TEXT, modules_dir TEXT, packages_path TEXT, build_size TEXT, seeded_ssh_public_key_fingerprint TEXT, docker INTEGER NOT NULL DEFAULT 0, created_at TEXT NOT NULL, updated_at TEXT NOT NULL );`, `CREATE TABLE IF NOT EXISTS vms ( id TEXT PRIMARY KEY, name TEXT NOT NULL UNIQUE, image_id TEXT NOT NULL, guest_ip TEXT NOT NULL UNIQUE, state TEXT NOT NULL, created_at TEXT NOT NULL, updated_at TEXT NOT NULL, last_touched_at TEXT NOT NULL, spec_json TEXT NOT NULL, runtime_json TEXT NOT NULL, stats_json TEXT NOT NULL DEFAULT '{}', FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE RESTRICT );`, `CREATE TABLE IF NOT EXISTS guest_sessions ( id TEXT PRIMARY KEY, vm_id TEXT NOT NULL, name TEXT NOT NULL, backend TEXT NOT NULL, command TEXT NOT NULL, args_json TEXT NOT NULL DEFAULT '[]', cwd TEXT, env_json TEXT NOT NULL DEFAULT '{}', stdin_mode TEXT NOT NULL, status TEXT NOT NULL, exit_code INTEGER, guest_pid INTEGER NOT NULL DEFAULT 0, guest_state_dir TEXT, stdout_log_path TEXT, stderr_log_path TEXT, tags_json TEXT NOT NULL DEFAULT '{}', last_error TEXT, attachable INTEGER NOT NULL DEFAULT 0, created_at TEXT NOT NULL, started_at TEXT, updated_at TEXT NOT NULL, ended_at TEXT, UNIQUE(vm_id, name), FOREIGN KEY(vm_id) REFERENCES vms(id) ON DELETE CASCADE );`, } for _, stmt := range stmts { if _, err := s.db.Exec(stmt); err != nil { return err } } if err := ensureColumnExists(s.db, "images", "work_seed_path", "TEXT"); err != nil { return err } if err := ensureColumnExists(s.db, "images", "seeded_ssh_public_key_fingerprint", "TEXT"); err != nil { return err } for _, spec := range []struct{ table, column, typ string }{ {"guest_sessions", "attach_backend", "TEXT"}, {"guest_sessions", "attach_mode", "TEXT"}, {"guest_sessions", "reattachable", "INTEGER NOT NULL DEFAULT 0"}, {"guest_sessions", "launch_stage", "TEXT"}, {"guest_sessions", "launch_message", "TEXT"}, {"guest_sessions", "launch_raw_log", "TEXT"}, } { if err := ensureColumnExists(s.db, spec.table, spec.column, spec.typ); err != nil { return err } } return nil } func (s *Store) UpsertImage(ctx context.Context, image model.Image) error { s.writeMu.Lock() defer s.writeMu.Unlock() const query = ` INSERT INTO images ( id, name, managed, artifact_dir, rootfs_path, work_seed_path, kernel_path, initrd_path, modules_dir, build_size, seeded_ssh_public_key_fingerprint, docker, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET name=excluded.name, managed=excluded.managed, artifact_dir=excluded.artifact_dir, rootfs_path=excluded.rootfs_path, work_seed_path=excluded.work_seed_path, kernel_path=excluded.kernel_path, initrd_path=excluded.initrd_path, modules_dir=excluded.modules_dir, build_size=excluded.build_size, seeded_ssh_public_key_fingerprint=excluded.seeded_ssh_public_key_fingerprint, docker=excluded.docker, updated_at=excluded.updated_at` _, err := s.db.ExecContext(ctx, query, image.ID, image.Name, boolToInt(image.Managed), image.ArtifactDir, image.RootfsPath, image.WorkSeedPath, image.KernelPath, image.InitrdPath, image.ModulesDir, image.BuildSize, image.SeededSSHPublicKeyFingerprint, boolToInt(image.Docker), image.CreatedAt.Format(time.RFC3339), image.UpdatedAt.Format(time.RFC3339), ) return err } func (s *Store) GetImageByName(ctx context.Context, name string) (model.Image, error) { return s.getImage(ctx, "SELECT id, name, managed, artifact_dir, rootfs_path, work_seed_path, kernel_path, initrd_path, modules_dir, build_size, seeded_ssh_public_key_fingerprint, docker, created_at, updated_at FROM images WHERE name = ?", name) } func (s *Store) GetImageByID(ctx context.Context, id string) (model.Image, error) { return s.getImage(ctx, "SELECT id, name, managed, artifact_dir, rootfs_path, work_seed_path, kernel_path, initrd_path, modules_dir, build_size, seeded_ssh_public_key_fingerprint, docker, created_at, updated_at FROM images WHERE id = ?", id) } func (s *Store) ListImages(ctx context.Context) ([]model.Image, error) { rows, err := s.db.QueryContext(ctx, "SELECT id, name, managed, artifact_dir, rootfs_path, work_seed_path, kernel_path, initrd_path, modules_dir, build_size, seeded_ssh_public_key_fingerprint, docker, created_at, updated_at FROM images ORDER BY created_at ASC") if err != nil { return nil, err } defer rows.Close() var images []model.Image for rows.Next() { image, err := scanImage(rows) if err != nil { return nil, err } images = append(images, image) } return images, rows.Err() } func (s *Store) DeleteImage(ctx context.Context, id string) error { s.writeMu.Lock() defer s.writeMu.Unlock() _, err := s.db.ExecContext(ctx, "DELETE FROM images WHERE id = ?", id) return err } func (s *Store) UpsertVM(ctx context.Context, vm model.VMRecord) error { s.writeMu.Lock() defer s.writeMu.Unlock() specJSON, err := json.Marshal(vm.Spec) if err != nil { return err } runtimeJSON, err := json.Marshal(vm.Runtime) if err != nil { return err } statsJSON, err := json.Marshal(vm.Stats) if err != nil { return err } const query = ` INSERT INTO vms ( id, name, image_id, guest_ip, state, created_at, updated_at, last_touched_at, spec_json, runtime_json, stats_json ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET name=excluded.name, image_id=excluded.image_id, guest_ip=excluded.guest_ip, state=excluded.state, updated_at=excluded.updated_at, last_touched_at=excluded.last_touched_at, spec_json=excluded.spec_json, runtime_json=excluded.runtime_json, stats_json=excluded.stats_json` _, err = s.db.ExecContext(ctx, query, vm.ID, vm.Name, vm.ImageID, vm.Runtime.GuestIP, string(vm.State), vm.CreatedAt.Format(time.RFC3339), vm.UpdatedAt.Format(time.RFC3339), vm.LastTouchedAt.Format(time.RFC3339), string(specJSON), string(runtimeJSON), string(statsJSON), ) return err } func (s *Store) GetVM(ctx context.Context, idOrName string) (model.VMRecord, error) { const query = ` SELECT id, name, image_id, guest_ip, state, created_at, updated_at, last_touched_at, spec_json, runtime_json, stats_json FROM vms WHERE id = ? OR name = ? ` row := s.db.QueryRowContext(ctx, query, idOrName, idOrName) return scanVMRow(row) } func (s *Store) GetVMByID(ctx context.Context, id string) (model.VMRecord, error) { row := s.db.QueryRowContext(ctx, ` SELECT id, name, image_id, guest_ip, state, created_at, updated_at, last_touched_at, spec_json, runtime_json, stats_json FROM vms WHERE id = ?`, id) return scanVMRow(row) } func (s *Store) ListVMs(ctx context.Context) ([]model.VMRecord, error) { rows, err := s.db.QueryContext(ctx, ` SELECT id, name, image_id, guest_ip, state, created_at, updated_at, last_touched_at, spec_json, runtime_json, stats_json FROM vms ORDER BY created_at ASC`) if err != nil { return nil, err } defer rows.Close() var vms []model.VMRecord for rows.Next() { vm, err := scanVMRows(rows) if err != nil { return nil, err } vms = append(vms, vm) } return vms, rows.Err() } func (s *Store) DeleteVM(ctx context.Context, id string) error { s.writeMu.Lock() defer s.writeMu.Unlock() _, err := s.db.ExecContext(ctx, "DELETE FROM vms WHERE id = ?", id) return err } func (s *Store) FindVMsUsingImage(ctx context.Context, imageID string) ([]model.VMRecord, error) { rows, err := s.db.QueryContext(ctx, ` SELECT id, name, image_id, guest_ip, state, created_at, updated_at, last_touched_at, spec_json, runtime_json, stats_json FROM vms WHERE image_id = ?`, imageID) if err != nil { return nil, err } defer rows.Close() var vms []model.VMRecord for rows.Next() { vm, err := scanVMRows(rows) if err != nil { return nil, err } vms = append(vms, vm) } return vms, rows.Err() } func (s *Store) UpsertGuestSession(ctx context.Context, session model.GuestSession) error { s.writeMu.Lock() defer s.writeMu.Unlock() argsJSON, err := json.Marshal(session.Args) if err != nil { return err } envJSON, err := json.Marshal(session.Env) if err != nil { return err } tagsJSON, err := json.Marshal(session.Tags) if err != nil { return err } const query = ` INSERT INTO guest_sessions ( id, vm_id, name, backend, attach_backend, attach_mode, command, args_json, cwd, env_json, stdin_mode, status, exit_code, guest_pid, guest_state_dir, stdout_log_path, stderr_log_path, tags_json, last_error, attachable, reattachable, launch_stage, launch_message, launch_raw_log, created_at, started_at, updated_at, ended_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET vm_id=excluded.vm_id, name=excluded.name, backend=excluded.backend, attach_backend=excluded.attach_backend, attach_mode=excluded.attach_mode, command=excluded.command, args_json=excluded.args_json, cwd=excluded.cwd, env_json=excluded.env_json, stdin_mode=excluded.stdin_mode, status=excluded.status, exit_code=excluded.exit_code, guest_pid=excluded.guest_pid, guest_state_dir=excluded.guest_state_dir, stdout_log_path=excluded.stdout_log_path, stderr_log_path=excluded.stderr_log_path, tags_json=excluded.tags_json, last_error=excluded.last_error, attachable=excluded.attachable, reattachable=excluded.reattachable, launch_stage=excluded.launch_stage, launch_message=excluded.launch_message, launch_raw_log=excluded.launch_raw_log, started_at=excluded.started_at, updated_at=excluded.updated_at, ended_at=excluded.ended_at` _, err = s.db.ExecContext(ctx, query, session.ID, session.VMID, session.Name, session.Backend, session.AttachBackend, session.AttachMode, session.Command, string(argsJSON), session.CWD, string(envJSON), string(session.StdinMode), string(session.Status), nullableInt(session.ExitCode), session.GuestPID, session.GuestStateDir, session.StdoutLogPath, session.StderrLogPath, string(tagsJSON), session.LastError, boolToInt(session.Attachable), boolToInt(session.Reattachable), session.LaunchStage, session.LaunchMessage, session.LaunchRawLog, session.CreatedAt.Format(time.RFC3339), nullableTimeString(session.StartedAt), session.UpdatedAt.Format(time.RFC3339), nullableTimeString(session.EndedAt), ) return err } func (s *Store) GetGuestSessionByID(ctx context.Context, id string) (model.GuestSession, error) { row := s.db.QueryRowContext(ctx, guestSessionSelectSQL+" WHERE id = ?", id) return scanGuestSessionRow(row) } func (s *Store) GetGuestSession(ctx context.Context, vmID, idOrName string) (model.GuestSession, error) { row := s.db.QueryRowContext(ctx, guestSessionSelectSQL+" WHERE vm_id = ? AND (id = ? OR name = ?)", vmID, idOrName, idOrName) return scanGuestSessionRow(row) } func (s *Store) ListGuestSessionsByVM(ctx context.Context, vmID string) ([]model.GuestSession, error) { rows, err := s.db.QueryContext(ctx, guestSessionSelectSQL+" WHERE vm_id = ? ORDER BY created_at ASC", vmID) if err != nil { return nil, err } defer rows.Close() var sessions []model.GuestSession for rows.Next() { session, err := scanGuestSession(rows) if err != nil { return nil, err } sessions = append(sessions, session) } return sessions, rows.Err() } func (s *Store) DeleteGuestSession(ctx context.Context, id string) error { s.writeMu.Lock() defer s.writeMu.Unlock() _, err := s.db.ExecContext(ctx, "DELETE FROM guest_sessions WHERE id = ?", id) return err } func (s *Store) NextGuestIP(ctx context.Context, bridgeIPPrefix string) (string, error) { used := map[string]struct{}{} rows, err := s.db.QueryContext(ctx, "SELECT guest_ip FROM vms") if err != nil { return "", err } defer rows.Close() for rows.Next() { var ip string if err := rows.Scan(&ip); err != nil { return "", err } used[ip] = struct{}{} } if err := rows.Err(); err != nil { return "", err } for i := 2; i < 255; i++ { candidate := fmt.Sprintf("%s.%d", bridgeIPPrefix, i) if _, exists := used[candidate]; !exists { return candidate, nil } } return "", errors.New("no guest IPs available") } func (s *Store) getImage(ctx context.Context, query string, arg string) (model.Image, error) { row := s.db.QueryRowContext(ctx, query, arg) return scanImageRow(row) } func scanImage(rows scanner) (model.Image, error) { return scanImageRow(rows) } type scanner interface { Scan(dest ...any) error } func scanImageRow(row scanner) (model.Image, error) { var image model.Image var managed, docker int var workSeedPath sql.NullString var seededSSHPublicKeyFingerprint sql.NullString var createdAt, updatedAt string err := row.Scan( &image.ID, &image.Name, &managed, &image.ArtifactDir, &image.RootfsPath, &workSeedPath, &image.KernelPath, &image.InitrdPath, &image.ModulesDir, &image.BuildSize, &seededSSHPublicKeyFingerprint, &docker, &createdAt, &updatedAt, ) if err != nil { return image, err } image.Managed = managed == 1 image.Docker = docker == 1 image.WorkSeedPath = workSeedPath.String image.SeededSSHPublicKeyFingerprint = seededSSHPublicKeyFingerprint.String image.CreatedAt, err = time.Parse(time.RFC3339, createdAt) if err != nil { return image, err } image.UpdatedAt, err = time.Parse(time.RFC3339, updatedAt) if err != nil { return image, err } return image, nil } func scanVMRow(row scanner) (model.VMRecord, error) { return scanVMInto(row) } func scanVMRows(rows scanner) (model.VMRecord, error) { return scanVMInto(rows) } func scanVMInto(row scanner) (model.VMRecord, error) { var vm model.VMRecord var state, createdAt, updatedAt, touchedAt, specJSON, runtimeJSON, statsJSON string err := row.Scan( &vm.ID, &vm.Name, &vm.ImageID, &vm.Runtime.GuestIP, &state, &createdAt, &updatedAt, &touchedAt, &specJSON, &runtimeJSON, &statsJSON, ) if err != nil { return vm, err } vm.State = model.VMState(state) if err := json.Unmarshal([]byte(specJSON), &vm.Spec); err != nil { return vm, err } if err := json.Unmarshal([]byte(runtimeJSON), &vm.Runtime); err != nil { return vm, err } if statsJSON != "" { if err := json.Unmarshal([]byte(statsJSON), &vm.Stats); err != nil { return vm, err } } var parseErr error vm.CreatedAt, parseErr = time.Parse(time.RFC3339, createdAt) if parseErr != nil { return vm, parseErr } vm.UpdatedAt, parseErr = time.Parse(time.RFC3339, updatedAt) if parseErr != nil { return vm, parseErr } vm.LastTouchedAt, parseErr = time.Parse(time.RFC3339, touchedAt) if parseErr != nil { return vm, parseErr } return vm, nil } func ensureColumnExists(db *sql.DB, table, column, columnType string) error { rows, err := db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table)) if err != nil { return err } defer rows.Close() for rows.Next() { var ( cid int name string valueType string notNull int defaultV sql.NullString pk int ) if err := rows.Scan(&cid, &name, &valueType, ¬Null, &defaultV, &pk); err != nil { return err } if name == column { return nil } } if err := rows.Err(); err != nil { return err } _, err = db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, columnType)) return err } func boolToInt(value bool) int { if value { return 1 } return 0 } const guestSessionSelectSQL = ` SELECT id, vm_id, name, backend, attach_backend, attach_mode, command, args_json, cwd, env_json, stdin_mode, status, exit_code, guest_pid, guest_state_dir, stdout_log_path, stderr_log_path, tags_json, last_error, attachable, reattachable, launch_stage, launch_message, launch_raw_log, created_at, started_at, updated_at, ended_at FROM guest_sessions` func scanGuestSession(rows scanner) (model.GuestSession, error) { return scanGuestSessionRow(rows) } func scanGuestSessionRow(row scanner) (model.GuestSession, error) { var session model.GuestSession var ( argsJSON string envJSON string tagsJSON string stdinMode string status string exitCode sql.NullInt64 startedAt sql.NullString endedAt sql.NullString attachable int reattachable int createdRaw string updatedRaw string ) err := row.Scan( &session.ID, &session.VMID, &session.Name, &session.Backend, &session.AttachBackend, &session.AttachMode, &session.Command, &argsJSON, &session.CWD, &envJSON, &stdinMode, &status, &exitCode, &session.GuestPID, &session.GuestStateDir, &session.StdoutLogPath, &session.StderrLogPath, &tagsJSON, &session.LastError, &attachable, &reattachable, &session.LaunchStage, &session.LaunchMessage, &session.LaunchRawLog, &createdRaw, &startedAt, &updatedRaw, &endedAt, ) if err != nil { return session, err } session.StdinMode = model.GuestSessionStdinMode(stdinMode) session.Status = model.GuestSessionStatus(status) session.Attachable = attachable == 1 session.Reattachable = reattachable == 1 if argsJSON != "" { if err := json.Unmarshal([]byte(argsJSON), &session.Args); err != nil { return session, err } } if envJSON != "" { if err := json.Unmarshal([]byte(envJSON), &session.Env); err != nil { return session, err } } if tagsJSON != "" { if err := json.Unmarshal([]byte(tagsJSON), &session.Tags); err != nil { return session, err } } if exitCode.Valid { value := int(exitCode.Int64) session.ExitCode = &value } var parseErr error session.CreatedAt, parseErr = time.Parse(time.RFC3339, createdRaw) if parseErr != nil { return session, parseErr } session.UpdatedAt, parseErr = time.Parse(time.RFC3339, updatedRaw) if parseErr != nil { return session, parseErr } if startedAt.Valid && startedAt.String != "" { session.StartedAt, parseErr = time.Parse(time.RFC3339, startedAt.String) if parseErr != nil { return session, parseErr } } if endedAt.Valid && endedAt.String != "" { session.EndedAt, parseErr = time.Parse(time.RFC3339, endedAt.String) if parseErr != nil { return session, parseErr } } return session, nil } func nullableTimeString(value time.Time) any { if value.IsZero() { return nil } return value.Format(time.RFC3339) } func nullableInt(value *int) any { if value == nil { return nil } return *value }