coverage: medium batch — hostnat runner, store guest-sessions, daemon helpers
Reuses existing fixtures (CommandRunner fakes, SQLite tempfile store, pure-Go seams). No new infra needed. hostnat 50% -> 98% (iptables orchestration via fake runner) store 78% -> 91% (guest_sessions CRUD roundtrip) daemon/session 57% -> 95% (script gen, state parse, snapshot apply) daemon/opstate 67% -> 100% (Registry Insert/Get/Prune) daemon (firstNonEmpty) slight bump Total 54.0% -> 56.5%.
This commit is contained in:
parent
f8979de58a
commit
346eaba673
5 changed files with 1010 additions and 0 deletions
24
internal/daemon/images_helpers_test.go
Normal file
24
internal/daemon/images_helpers_test.go
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
package daemon
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFirstNonEmpty(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
values []string
|
||||
want string
|
||||
}{
|
||||
{"all empty", []string{"", " ", "\t"}, ""},
|
||||
{"first wins", []string{"a", "b"}, "a"},
|
||||
{"skips blanks", []string{"", " ", "first", "second"}, "first"},
|
||||
{"nil input", nil, ""},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := firstNonEmpty(tc.values...)
|
||||
if got != tc.want {
|
||||
t.Errorf("firstNonEmpty(%v) = %q, want %q", tc.values, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
74
internal/daemon/opstate/registry_test.go
Normal file
74
internal/daemon/opstate/registry_test.go
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
package opstate
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fakeOp struct {
|
||||
id string
|
||||
done atomic.Bool
|
||||
updatedAt time.Time
|
||||
canceled atomic.Bool
|
||||
}
|
||||
|
||||
func (f *fakeOp) ID() string { return f.id }
|
||||
func (f *fakeOp) IsDone() bool { return f.done.Load() }
|
||||
func (f *fakeOp) UpdatedAt() time.Time { return f.updatedAt }
|
||||
func (f *fakeOp) Cancel() { f.canceled.Store(true) }
|
||||
|
||||
func TestRegistryInsertAndGet(t *testing.T) {
|
||||
var r Registry[*fakeOp]
|
||||
op := &fakeOp{id: "op-1", updatedAt: time.Now()}
|
||||
r.Insert(op)
|
||||
got, ok := r.Get("op-1")
|
||||
if !ok {
|
||||
t.Fatal("Get after Insert missed")
|
||||
}
|
||||
if got.ID() != "op-1" {
|
||||
t.Fatalf("Get().ID = %q", got.ID())
|
||||
}
|
||||
|
||||
_, ok = r.Get("missing")
|
||||
if ok {
|
||||
t.Fatal("Get on missing key should miss")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPruneDropsCompletedOldOps(t *testing.T) {
|
||||
var r Registry[*fakeOp]
|
||||
now := time.Now()
|
||||
|
||||
recent := &fakeOp{id: "recent", updatedAt: now}
|
||||
recent.done.Store(true)
|
||||
|
||||
stale := &fakeOp{id: "stale", updatedAt: now.Add(-time.Hour)}
|
||||
stale.done.Store(true)
|
||||
|
||||
pending := &fakeOp{id: "pending", updatedAt: now.Add(-time.Hour)}
|
||||
// NOT done → stays even though old.
|
||||
|
||||
r.Insert(recent)
|
||||
r.Insert(stale)
|
||||
r.Insert(pending)
|
||||
|
||||
cutoff := now.Add(-time.Minute)
|
||||
r.Prune(cutoff)
|
||||
|
||||
if _, ok := r.Get("stale"); ok {
|
||||
t.Error("stale op should have been pruned")
|
||||
}
|
||||
if _, ok := r.Get("recent"); !ok {
|
||||
t.Error("recent op should survive (newer than cutoff)")
|
||||
}
|
||||
if _, ok := r.Get("pending"); !ok {
|
||||
t.Error("pending op should survive (not done)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPruneNoOpOnEmpty(t *testing.T) {
|
||||
var r Registry[*fakeOp]
|
||||
// Just shouldn't panic.
|
||||
r.Prune(time.Now())
|
||||
}
|
||||
440
internal/daemon/session/session_test.go
Normal file
440
internal/daemon/session/session_test.go
Normal file
|
|
@ -0,0 +1,440 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"banger/internal/model"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestRelativeStateDir(t *testing.T) {
|
||||
got := RelativeStateDir("abc")
|
||||
if strings.HasPrefix(got, "/root/") {
|
||||
t.Fatalf("RelativeStateDir(%q) = %q, should strip /root/ prefix", "abc", got)
|
||||
}
|
||||
if !strings.Contains(got, "abc") {
|
||||
t.Fatalf("missing session id in %q", got)
|
||||
}
|
||||
absolute := StateDir("abc")
|
||||
if got != strings.TrimPrefix(absolute, "/root/") {
|
||||
t.Fatalf("relative = %q, want %q", got, strings.TrimPrefix(absolute, "/root/"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCWD(t *testing.T) {
|
||||
if DefaultCWD("") != "/root" {
|
||||
t.Error("empty should return /root")
|
||||
}
|
||||
if DefaultCWD(" ") != "/root" {
|
||||
t.Error("whitespace should return /root")
|
||||
}
|
||||
if DefaultCWD("/work") != "/work" {
|
||||
t.Error("explicit should pass through")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellQuote(t *testing.T) {
|
||||
if got := ShellQuote(""); got != "''" {
|
||||
t.Errorf("empty: got %q, want ''", got)
|
||||
}
|
||||
if got := ShellQuote("x"); got != "'x'" {
|
||||
t.Errorf("plain: got %q", got)
|
||||
}
|
||||
if got := ShellQuote("it's"); got != `'it'"'"'s'` {
|
||||
t.Errorf("apostrophe: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitCode(t *testing.T) {
|
||||
if code, ok := ExitCode(nil); !ok || code != 0 {
|
||||
t.Errorf("nil err: got (%d, %v), want (0, true)", code, ok)
|
||||
}
|
||||
// Build an ssh.ExitError using its real type — can't hand-construct,
|
||||
// so wrap via errors.As check with a stub.
|
||||
raw := &ssh.ExitError{}
|
||||
if _, ok := ExitCode(raw); !ok {
|
||||
t.Error("ssh.ExitError: ok should be true")
|
||||
}
|
||||
if _, ok := ExitCode(errors.New("bare error")); ok {
|
||||
t.Error("bare error: ok should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneStringMap(t *testing.T) {
|
||||
if CloneStringMap(nil) != nil {
|
||||
t.Error("nil in → nil out")
|
||||
}
|
||||
if CloneStringMap(map[string]string{}) != nil {
|
||||
t.Error("empty in → nil out")
|
||||
}
|
||||
src := map[string]string{"a": "1", "b": "2"}
|
||||
cloned := CloneStringMap(src)
|
||||
if len(cloned) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(cloned))
|
||||
}
|
||||
cloned["a"] = "changed"
|
||||
if src["a"] != "1" {
|
||||
t.Error("mutating clone leaked back to source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTailFileContent(t *testing.T) {
|
||||
// Missing file → empty, no error.
|
||||
got, err := TailFileContent(filepath.Join(t.TempDir(), "missing"), 10)
|
||||
if err != nil || got != "" {
|
||||
t.Errorf("missing: got (%q, %v), want ('', nil)", got, err)
|
||||
}
|
||||
|
||||
path := filepath.Join(t.TempDir(), "log")
|
||||
lines := "one\ntwo\nthree\nfour\nfive"
|
||||
if err := os.WriteFile(path, []byte(lines), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
|
||||
full, err := TailFileContent(path, 0)
|
||||
if err != nil || full != lines {
|
||||
t.Errorf("0 lines: got (%q, %v), want (%q, nil)", full, err, lines)
|
||||
}
|
||||
|
||||
// Request more lines than exist → full content.
|
||||
all, err := TailFileContent(path, 999)
|
||||
if err != nil || all != lines {
|
||||
t.Errorf("999 lines: got %q", all)
|
||||
}
|
||||
|
||||
last2, err := TailFileContent(path, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("2 lines: %v", err)
|
||||
}
|
||||
if !strings.Contains(last2, "five") {
|
||||
t.Errorf("2 lines missing last line: %q", last2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessAlive(t *testing.T) {
|
||||
if ProcessAlive(0) {
|
||||
t.Error("pid 0 should not be alive")
|
||||
}
|
||||
if ProcessAlive(-1) {
|
||||
t.Error("negative pid should not be alive")
|
||||
}
|
||||
// Swap the syscall seam.
|
||||
original := syscallKill
|
||||
t.Cleanup(func() { syscallKill = original })
|
||||
|
||||
syscallKill = func(pid int, signal os.Signal) error { return nil }
|
||||
if !ProcessAlive(42) {
|
||||
t.Error("syscallKill=nil should report alive")
|
||||
}
|
||||
|
||||
syscallKill = func(pid int, signal os.Signal) error { return fmt.Errorf("no such process") }
|
||||
if ProcessAlive(42) {
|
||||
t.Error("syscallKill error should report dead")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatStepError(t *testing.T) {
|
||||
base := errors.New("boom")
|
||||
err := FormatStepError("prepare", base, "")
|
||||
if !errors.Is(err, base) {
|
||||
t.Error("FormatStepError should wrap the base error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "prepare") {
|
||||
t.Errorf("missing action: %v", err)
|
||||
}
|
||||
|
||||
errWithLog := FormatStepError("prepare", base, " log line\n")
|
||||
if !strings.Contains(errWithLog.Error(), "log line") {
|
||||
t.Errorf("missing log: %v", errWithLog)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStateHappyPath(t *testing.T) {
|
||||
raw := `status=running
|
||||
pid=123
|
||||
exit=
|
||||
alive=true
|
||||
error=
|
||||
`
|
||||
snap, err := ParseState(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseState: %v", err)
|
||||
}
|
||||
if snap.Status != "running" {
|
||||
t.Errorf("Status = %q", snap.Status)
|
||||
}
|
||||
if snap.GuestPID != 123 {
|
||||
t.Errorf("GuestPID = %d", snap.GuestPID)
|
||||
}
|
||||
if snap.ExitCode != nil {
|
||||
t.Errorf("ExitCode should be nil when empty, got %v", snap.ExitCode)
|
||||
}
|
||||
if !snap.Alive {
|
||||
t.Error("Alive should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStateWithExit(t *testing.T) {
|
||||
raw := `status=exited
|
||||
pid=123
|
||||
exit=7
|
||||
alive=false
|
||||
error=something bad
|
||||
`
|
||||
snap, err := ParseState(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseState: %v", err)
|
||||
}
|
||||
if snap.ExitCode == nil || *snap.ExitCode != 7 {
|
||||
t.Errorf("ExitCode = %v, want 7", snap.ExitCode)
|
||||
}
|
||||
if snap.LastError != "something bad" {
|
||||
t.Errorf("LastError = %q", snap.LastError)
|
||||
}
|
||||
if snap.Alive {
|
||||
t.Error("Alive should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStateIgnoresMalformedLines(t *testing.T) {
|
||||
raw := "no-equals-here\nstatus=ok\n"
|
||||
snap, err := ParseState(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseState: %v", err)
|
||||
}
|
||||
if snap.Status != "ok" {
|
||||
t.Errorf("Status = %q, want ok", snap.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectStateFromDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeFile := func(name, content string) {
|
||||
if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(%s): %v", name, err)
|
||||
}
|
||||
}
|
||||
writeFile("status", "running\n")
|
||||
writeFile("pid", "42\n")
|
||||
writeFile("exit_code", "0\n")
|
||||
writeFile("error", "\n")
|
||||
|
||||
original := syscallKill
|
||||
t.Cleanup(func() { syscallKill = original })
|
||||
syscallKill = func(pid int, signal os.Signal) error { return nil }
|
||||
|
||||
snap, err := InspectStateFromDir(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("InspectStateFromDir: %v", err)
|
||||
}
|
||||
if snap.Status != "running" {
|
||||
t.Errorf("Status = %q", snap.Status)
|
||||
}
|
||||
if snap.GuestPID != 42 {
|
||||
t.Errorf("GuestPID = %d", snap.GuestPID)
|
||||
}
|
||||
if snap.ExitCode == nil || *snap.ExitCode != 0 {
|
||||
t.Errorf("ExitCode = %v, want 0", snap.ExitCode)
|
||||
}
|
||||
if !snap.Alive {
|
||||
t.Error("Alive should reflect syscallKill result (true)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectStateFromDirMissingFiles(t *testing.T) {
|
||||
snap, err := InspectStateFromDir(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("InspectStateFromDir (empty): %v", err)
|
||||
}
|
||||
if snap.Status != "" || snap.GuestPID != 0 || snap.ExitCode != nil {
|
||||
t.Errorf("empty dir: snap = %+v", snap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotNilReceiver(t *testing.T) {
|
||||
ApplyStateSnapshot(nil, StateSnapshot{}, true) // should not panic
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotExitedSuccess(t *testing.T) {
|
||||
exit := 0
|
||||
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning, Attachable: true, Reattachable: true}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true)
|
||||
if sess.Status != model.GuestSessionStatusExited {
|
||||
t.Errorf("Status = %q, want exited", sess.Status)
|
||||
}
|
||||
if sess.Attachable || sess.Reattachable {
|
||||
t.Error("attach flags should be cleared on exit")
|
||||
}
|
||||
if sess.EndedAt.IsZero() {
|
||||
t.Error("EndedAt should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotExitedFailure(t *testing.T) {
|
||||
exit := 2
|
||||
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true)
|
||||
if sess.Status != model.GuestSessionStatusFailed {
|
||||
t.Errorf("Status = %q, want failed", sess.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotVMGone(t *testing.T) {
|
||||
sess := &model.GuestSession{Status: model.GuestSessionStatusRunning}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{Alive: false}, false)
|
||||
if sess.Status != model.GuestSessionStatusFailed {
|
||||
t.Errorf("Status = %q, want failed", sess.Status)
|
||||
}
|
||||
if sess.LastError == "" {
|
||||
t.Error("LastError should be populated when VM is gone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotRunningStatusSetsAttachableForPipe(t *testing.T) {
|
||||
// When the guest-side status file reports "running" (Alive=false from
|
||||
// kill -0 may still fail transiently), ApplyStateSnapshot transitions
|
||||
// the session to running and sets attach flags for pipe-mode.
|
||||
sess := &model.GuestSession{
|
||||
Status: model.GuestSessionStatusStarting,
|
||||
StdinMode: model.GuestSessionStdinPipe,
|
||||
}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{Status: string(model.GuestSessionStatusRunning), GuestPID: 11}, true)
|
||||
if sess.Status != model.GuestSessionStatusRunning {
|
||||
t.Errorf("Status = %q, want running", sess.Status)
|
||||
}
|
||||
if !sess.Attachable || !sess.Reattachable {
|
||||
t.Error("pipe-mode running session should be attachable + reattachable")
|
||||
}
|
||||
if sess.AttachBackend != AttachBackendSSHBridge {
|
||||
t.Errorf("AttachBackend = %q, want %q", sess.AttachBackend, AttachBackendSSHBridge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStateSnapshotAliveEarlyReturn(t *testing.T) {
|
||||
// Alive-true returns immediately after setting status; no attach
|
||||
// flags set on this path (by design — attach metadata only attaches
|
||||
// to status-driven transitions).
|
||||
sess := &model.GuestSession{
|
||||
Status: model.GuestSessionStatusStarting,
|
||||
StdinMode: model.GuestSessionStdinPipe,
|
||||
}
|
||||
ApplyStateSnapshot(sess, StateSnapshot{Alive: true, GuestPID: 11}, true)
|
||||
if sess.Status != model.GuestSessionStatusRunning {
|
||||
t.Errorf("Status = %q, want running", sess.Status)
|
||||
}
|
||||
if sess.StartedAt.IsZero() {
|
||||
t.Error("StartedAt should have been set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateChanged(t *testing.T) {
|
||||
base := model.GuestSession{Status: model.GuestSessionStatusRunning, GuestPID: 10}
|
||||
|
||||
// Identical → no change.
|
||||
if StateChanged(base, base) {
|
||||
t.Error("identical states should not be considered changed")
|
||||
}
|
||||
|
||||
// Status change.
|
||||
changed := base
|
||||
changed.Status = model.GuestSessionStatusExited
|
||||
if !StateChanged(base, changed) {
|
||||
t.Error("status change should be detected")
|
||||
}
|
||||
|
||||
// ExitCode change from nil → value.
|
||||
exit := 3
|
||||
changed = base
|
||||
changed.ExitCode = &exit
|
||||
if !StateChanged(base, changed) {
|
||||
t.Error("exit-code appearing should be detected")
|
||||
}
|
||||
|
||||
// Both have the same exit code → no change.
|
||||
a := base
|
||||
a.ExitCode = &exit
|
||||
b := base
|
||||
b.ExitCode = &exit
|
||||
if StateChanged(a, b) {
|
||||
t.Error("matching exit codes should not trigger change")
|
||||
}
|
||||
|
||||
// Different exit codes.
|
||||
other := 5
|
||||
b.ExitCode = &other
|
||||
if !StateChanged(a, b) {
|
||||
t.Error("differing exit codes should be detected")
|
||||
}
|
||||
|
||||
// Timestamp change.
|
||||
changed = base
|
||||
changed.StartedAt = time.Now()
|
||||
if !StateChanged(base, changed) {
|
||||
t.Error("StartedAt change should be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailLaunch(t *testing.T) {
|
||||
in := model.GuestSession{Status: model.GuestSessionStatusStarting, Attachable: true}
|
||||
out := FailLaunch(in, "provision", " ssh did not come up ", " raw output\n")
|
||||
if out.Status != model.GuestSessionStatusFailed {
|
||||
t.Errorf("Status = %q, want failed", out.Status)
|
||||
}
|
||||
if out.LastError != "ssh did not come up" {
|
||||
t.Errorf("LastError = %q (not trimmed?)", out.LastError)
|
||||
}
|
||||
if out.LaunchStage != "provision" || out.LaunchMessage != "ssh did not come up" {
|
||||
t.Errorf("launch fields not set: %+v", out)
|
||||
}
|
||||
if out.LaunchRawLog != "raw output" {
|
||||
t.Errorf("rawLog = %q (not trimmed?)", out.LaunchRawLog)
|
||||
}
|
||||
if out.Attachable {
|
||||
t.Error("Attachable should be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeRequiredCommands(t *testing.T) {
|
||||
got := NormalizeRequiredCommands("pi", []string{"pi", "git", "", "git", " ", "make"})
|
||||
want := []string{"pi", "git", "make"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("len = %d, want %d (%v)", len(got), len(want), got)
|
||||
}
|
||||
for i, v := range want {
|
||||
if got[i] != v {
|
||||
t.Errorf("position %d: got %q, want %q", i, got[i], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectScriptContainsAllStateFiles(t *testing.T) {
|
||||
script := InspectScript("sess-abc")
|
||||
for _, key := range []string{"status", "pid", "exit_code", "error", "alive"} {
|
||||
if !strings.Contains(script, key) {
|
||||
t.Errorf("script missing %q:\n%s", key, script)
|
||||
}
|
||||
}
|
||||
if !strings.Contains(script, "sess-abc") {
|
||||
t.Error("script missing session id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalScriptIncludesSignalAndDirPaths(t *testing.T) {
|
||||
script := SignalScript("sess-x", "TERM")
|
||||
if !strings.Contains(script, "TERM") {
|
||||
t.Error("missing signal")
|
||||
}
|
||||
if !strings.Contains(script, "sess-x") {
|
||||
t.Error("missing session id")
|
||||
}
|
||||
if !strings.Contains(script, "monitor_pid") || !strings.Contains(script, "stdin_keepalive") {
|
||||
t.Errorf("expected both monitor + stdin_keepalive kills, got:\n%s", script)
|
||||
}
|
||||
}
|
||||
258
internal/hostnat/runner_test.go
Normal file
258
internal/hostnat/runner_test.go
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
package hostnat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type call struct {
|
||||
sudo bool
|
||||
name string
|
||||
args []string
|
||||
}
|
||||
|
||||
type fakeRunner struct {
|
||||
calls []call
|
||||
// runResp maps "name arg0 arg1 ..." (Run, no sudo) to a scripted
|
||||
// (stdout, err) pair. Missing entries return error.
|
||||
runResp map[string]callResp
|
||||
// sudoMatcher decides whether a RunSudo call succeeds. If nil, all
|
||||
// RunSudo calls succeed with empty stdout.
|
||||
sudoMatcher func(args []string) ([]byte, error)
|
||||
}
|
||||
|
||||
type callResp struct {
|
||||
out []byte
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *fakeRunner) Run(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||
c := call{name: name, args: append([]string(nil), args...)}
|
||||
r.calls = append(r.calls, c)
|
||||
key := name + " " + strings.Join(args, " ")
|
||||
if resp, ok := r.runResp[key]; ok {
|
||||
return resp.out, resp.err
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected Run: %s", key)
|
||||
}
|
||||
|
||||
func (r *fakeRunner) RunSudo(ctx context.Context, args ...string) ([]byte, error) {
|
||||
c := call{sudo: true, args: append([]string(nil), args...)}
|
||||
r.calls = append(r.calls, c)
|
||||
if r.sudoMatcher != nil {
|
||||
return r.sudoMatcher(args)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestDefaultUplink(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &fakeRunner{
|
||||
runResp: map[string]callResp{
|
||||
"ip route show default": {out: []byte("default via 10.0.0.1 dev wlan0 proto dhcp\n")},
|
||||
},
|
||||
}
|
||||
got, err := DefaultUplink(context.Background(), r)
|
||||
if err != nil {
|
||||
t.Fatalf("DefaultUplink: %v", err)
|
||||
}
|
||||
if got != "wlan0" {
|
||||
t.Fatalf("got %q, want wlan0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultUplinkPropagatesRunError(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &fakeRunner{}
|
||||
_, err := DefaultUplink(context.Background(), r)
|
||||
if err == nil {
|
||||
t.Fatal("expected error from DefaultUplink when Run fails")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleKey(t *testing.T) {
|
||||
rule := Rule{Table: "nat", Chain: "POSTROUTING", Args: []string{"-s", "172.16.0.5/32"}}
|
||||
key := RuleKey(rule)
|
||||
if !strings.Contains(key, "nat") || !strings.Contains(key, "POSTROUTING") || !strings.Contains(key, "172.16.0.5/32") {
|
||||
t.Fatalf("key missing expected parts: %q", key)
|
||||
}
|
||||
|
||||
// Different args → different key.
|
||||
other := Rule{Table: "nat", Chain: "POSTROUTING", Args: []string{"-s", "10.0.0.5/32"}}
|
||||
if RuleKey(rule) == RuleKey(other) {
|
||||
t.Fatal("RuleKey should differ for different args")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureEnableInstallsRules(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &fakeRunner{
|
||||
runResp: map[string]callResp{
|
||||
"ip route show default": {out: []byte("default via 10.0.0.1 dev eth0\n")},
|
||||
},
|
||||
sudoMatcher: func(args []string) ([]byte, error) {
|
||||
// The first sudo call is sysctl; every subsequent call is
|
||||
// `iptables -C ...` (probe) followed by `iptables -A ...`
|
||||
// because the probe should report the rule is NOT present.
|
||||
if args[0] == "sysctl" {
|
||||
return nil, nil
|
||||
}
|
||||
if args[0] != "iptables" {
|
||||
return nil, fmt.Errorf("unexpected sudo prefix: %v", args)
|
||||
}
|
||||
// Fail -C (rule absent) so Ensure issues -A.
|
||||
for _, a := range args {
|
||||
if a == "-C" {
|
||||
return nil, errors.New("rule absent")
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", true); err != nil {
|
||||
t.Fatalf("Ensure: %v", err)
|
||||
}
|
||||
|
||||
// Expect at least: 1 ip route, 1 sysctl, and for 3 rules: -C + -A = 6 iptables calls.
|
||||
if len(r.calls) < 8 {
|
||||
t.Fatalf("call count = %d, want >= 8; calls=%+v", len(r.calls), r.calls)
|
||||
}
|
||||
// First call is ip route; second is sysctl.
|
||||
if r.calls[0].name != "ip" {
|
||||
t.Errorf("calls[0] = %+v, want ip route", r.calls[0])
|
||||
}
|
||||
if !r.calls[1].sudo || r.calls[1].args[0] != "sysctl" {
|
||||
t.Errorf("calls[1] = %+v, want sudo sysctl", r.calls[1])
|
||||
}
|
||||
// Somewhere we must have an iptables -A POSTROUTING call.
|
||||
var sawAppend bool
|
||||
for _, c := range r.calls {
|
||||
if c.sudo && len(c.args) >= 3 && c.args[0] == "iptables" && contains(c.args, "-A") && contains(c.args, "POSTROUTING") {
|
||||
sawAppend = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sawAppend {
|
||||
t.Fatal("no iptables -A POSTROUTING call observed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureEnableSkipsAppendWhenRulePresent(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &fakeRunner{
|
||||
runResp: map[string]callResp{
|
||||
"ip route show default": {out: []byte("default via 10.0.0.1 dev eth0\n")},
|
||||
},
|
||||
sudoMatcher: func(args []string) ([]byte, error) {
|
||||
// Probe succeeds → Ensure should NOT follow up with -A.
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", true); err != nil {
|
||||
t.Fatalf("Ensure: %v", err)
|
||||
}
|
||||
|
||||
// No -A iptables calls should have been issued.
|
||||
for _, c := range r.calls {
|
||||
if c.sudo && contains(c.args, "iptables") && contains(c.args, "-A") {
|
||||
t.Fatalf("unexpected -A call with probe success: %+v", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureDisableRemovesRulesWhenPresent(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &fakeRunner{
|
||||
runResp: map[string]callResp{
|
||||
"ip route show default": {out: []byte("default via 10.0.0.1 dev eth0\n")},
|
||||
},
|
||||
sudoMatcher: func(args []string) ([]byte, error) {
|
||||
// Every probe succeeds → rule is present → -D is issued.
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", false); err != nil {
|
||||
t.Fatalf("Ensure(disable): %v", err)
|
||||
}
|
||||
var sawDelete bool
|
||||
for _, c := range r.calls {
|
||||
if c.sudo && contains(c.args, "iptables") && contains(c.args, "-D") {
|
||||
sawDelete = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sawDelete {
|
||||
t.Fatal("expected at least one iptables -D call")
|
||||
}
|
||||
// No sysctl on disable path.
|
||||
for _, c := range r.calls {
|
||||
if c.sudo && len(c.args) > 0 && c.args[0] == "sysctl" {
|
||||
t.Fatal("sysctl should not run on disable path")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureDisableSkipsRemovalWhenAbsent(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &fakeRunner{
|
||||
runResp: map[string]callResp{
|
||||
"ip route show default": {out: []byte("default via 10.0.0.1 dev eth0\n")},
|
||||
},
|
||||
sudoMatcher: func(args []string) ([]byte, error) {
|
||||
return nil, errors.New("rule not present")
|
||||
},
|
||||
}
|
||||
if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", false); err != nil {
|
||||
t.Fatalf("Ensure(disable, absent): %v", err)
|
||||
}
|
||||
for _, c := range r.calls {
|
||||
if c.sudo && contains(c.args, "iptables") && contains(c.args, "-D") {
|
||||
t.Fatalf("unexpected -D with absent rule: %+v", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePropagatesUplinkError(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &fakeRunner{} // no runResp → ip route fails
|
||||
err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", true)
|
||||
if err == nil {
|
||||
t.Fatal("expected uplink error to propagate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureValidatesInputs(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &fakeRunner{
|
||||
runResp: map[string]callResp{
|
||||
"ip route show default": {out: []byte("default via 10.0.0.1 dev eth0\n")},
|
||||
},
|
||||
}
|
||||
if err := Ensure(context.Background(), r, "", "tap-x", true); err == nil {
|
||||
t.Fatal("expected error for empty guestIP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleArgsWithoutTable(t *testing.T) {
|
||||
// Sanity: RuleArgs should only prepend -t when Table is set.
|
||||
bare := Rule{Chain: "FORWARD", Args: []string{"-i", "eth0"}}
|
||||
got := RuleArgs("-A", bare)
|
||||
want := []string{"-A", "FORWARD", "-i", "eth0"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(xs []string, target string) bool {
|
||||
for _, x := range xs {
|
||||
if x == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
214
internal/store/guest_session_test.go
Normal file
214
internal/store/guest_session_test.go
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"banger/internal/model"
|
||||
)
|
||||
|
||||
func sampleGuestSession(id, vmID, name string) model.GuestSession {
|
||||
now := fixedTime()
|
||||
exit := 7
|
||||
return model.GuestSession{
|
||||
ID: id,
|
||||
VMID: vmID,
|
||||
Name: name,
|
||||
Backend: "ssh",
|
||||
AttachBackend: "vsock",
|
||||
AttachMode: "rpc",
|
||||
Command: "pi",
|
||||
Args: []string{"--mode", "rpc"},
|
||||
CWD: "/root/repo",
|
||||
Env: map[string]string{"FOO": "bar"},
|
||||
StdinMode: model.GuestSessionStdinMode("pipe"),
|
||||
Status: model.GuestSessionStatus("exited"),
|
||||
ExitCode: &exit,
|
||||
GuestPID: 1234,
|
||||
GuestStateDir: "/tmp/guest-" + id,
|
||||
StdoutLogPath: "/tmp/" + id + ".stdout",
|
||||
StderrLogPath: "/tmp/" + id + ".stderr",
|
||||
Tags: map[string]string{"role": "planner"},
|
||||
LastError: "",
|
||||
Attachable: true,
|
||||
Reattachable: true,
|
||||
LaunchStage: "started",
|
||||
LaunchMessage: "ok",
|
||||
LaunchRawLog: "boot log...",
|
||||
CreatedAt: now,
|
||||
StartedAt: now,
|
||||
UpdatedAt: now,
|
||||
EndedAt: now.Add(time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
// openTestStoreWithVMs opens a fresh store seeded with the given VM IDs so
|
||||
// guest_sessions FK constraints are satisfied. Each VM gets a minimal
|
||||
// image it references.
|
||||
func openTestStoreWithVMs(t *testing.T, vmIDs ...string) *Store {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
store := openTestStore(t)
|
||||
|
||||
image := sampleImage("stub-image")
|
||||
if err := store.UpsertImage(ctx, image); err != nil {
|
||||
t.Fatalf("UpsertImage: %v", err)
|
||||
}
|
||||
for i, id := range vmIDs {
|
||||
vm := sampleVM(id, image.ID, fmt.Sprintf("172.16.0.%d", i+2))
|
||||
vm.ID = id
|
||||
if err := store.UpsertVM(ctx, vm); err != nil {
|
||||
t.Fatalf("UpsertVM(%s): %v", id, err)
|
||||
}
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func TestGuestSessionUpsertAndGetByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
store := openTestStoreWithVMs(t, "vm-1")
|
||||
|
||||
session := sampleGuestSession("sess-1", "vm-1", "planner")
|
||||
if err := store.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
got, err := store.GetGuestSessionByID(ctx, "sess-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetGuestSessionByID: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, session) {
|
||||
t.Fatalf("round-trip mismatch:\n got %+v\n want %+v", got, session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuestSessionUpsertIsIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
store := openTestStoreWithVMs(t, "vm-1")
|
||||
|
||||
session := sampleGuestSession("sess-1", "vm-1", "planner")
|
||||
if err := store.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession (first): %v", err)
|
||||
}
|
||||
|
||||
// Mutate + re-upsert → existing row updated.
|
||||
session.Command = "pi --other"
|
||||
session.Status = model.GuestSessionStatus("running")
|
||||
session.ExitCode = nil
|
||||
if err := store.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession (second): %v", err)
|
||||
}
|
||||
|
||||
got, err := store.GetGuestSessionByID(ctx, "sess-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetGuestSessionByID: %v", err)
|
||||
}
|
||||
if got.Command != "pi --other" {
|
||||
t.Errorf("command = %q, want 'pi --other'", got.Command)
|
||||
}
|
||||
if got.Status != model.GuestSessionStatus("running") {
|
||||
t.Errorf("status = %q, want running", got.Status)
|
||||
}
|
||||
if got.ExitCode != nil {
|
||||
t.Errorf("ExitCode = %v, want nil after clearing", got.ExitCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGuestSessionByIDOrName(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
store := openTestStoreWithVMs(t, "vm-1")
|
||||
|
||||
session := sampleGuestSession("sess-1", "vm-1", "planner")
|
||||
if err := store.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
|
||||
byID, err := store.GetGuestSession(ctx, "vm-1", "sess-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetGuestSession by ID: %v", err)
|
||||
}
|
||||
if byID.ID != "sess-1" {
|
||||
t.Errorf("by-ID: got %q, want sess-1", byID.ID)
|
||||
}
|
||||
|
||||
byName, err := store.GetGuestSession(ctx, "vm-1", "planner")
|
||||
if err != nil {
|
||||
t.Fatalf("GetGuestSession by name: %v", err)
|
||||
}
|
||||
if byName.Name != "planner" {
|
||||
t.Errorf("by-name: got %q, want planner", byName.Name)
|
||||
}
|
||||
|
||||
// Scoped to the VM.
|
||||
if _, err := store.GetGuestSession(ctx, "vm-unknown", "sess-1"); !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Errorf("wrong-vm lookup = %v, want sql.ErrNoRows", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListGuestSessionsByVMOrdersByCreatedAt(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
store := openTestStoreWithVMs(t, "vm-1", "vm-2")
|
||||
|
||||
base := fixedTime()
|
||||
first := sampleGuestSession("sess-early", "vm-1", "first")
|
||||
first.CreatedAt = base
|
||||
second := sampleGuestSession("sess-late", "vm-1", "second")
|
||||
second.CreatedAt = base.Add(time.Hour)
|
||||
other := sampleGuestSession("sess-other", "vm-2", "other")
|
||||
|
||||
for _, s := range []model.GuestSession{second, first, other} {
|
||||
if err := store.UpsertGuestSession(ctx, s); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
sessions, err := store.ListGuestSessionsByVM(ctx, "vm-1")
|
||||
if err != nil {
|
||||
t.Fatalf("ListGuestSessionsByVM: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
t.Fatalf("len = %d, want 2 (vm-1 only)", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != "sess-early" || sessions[1].ID != "sess-late" {
|
||||
t.Fatalf("order: got %q, %q; want sess-early, sess-late", sessions[0].ID, sessions[1].ID)
|
||||
}
|
||||
|
||||
empty, err := store.ListGuestSessionsByVM(ctx, "vm-unknown")
|
||||
if err != nil {
|
||||
t.Fatalf("ListGuestSessionsByVM (unknown vm): %v", err)
|
||||
}
|
||||
if len(empty) != 0 {
|
||||
t.Fatalf("unknown vm sessions = %+v, want empty", empty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteGuestSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
store := openTestStoreWithVMs(t, "vm-1")
|
||||
|
||||
session := sampleGuestSession("sess-1", "vm-1", "planner")
|
||||
if err := store.UpsertGuestSession(ctx, session); err != nil {
|
||||
t.Fatalf("UpsertGuestSession: %v", err)
|
||||
}
|
||||
if err := store.DeleteGuestSession(ctx, "sess-1"); err != nil {
|
||||
t.Fatalf("DeleteGuestSession: %v", err)
|
||||
}
|
||||
if _, err := store.GetGuestSessionByID(ctx, "sess-1"); !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("after delete err = %v, want sql.ErrNoRows", err)
|
||||
}
|
||||
|
||||
// Deleting something that doesn't exist is a no-op (matches SQL DELETE semantics).
|
||||
if err := store.DeleteGuestSession(ctx, "sess-nope"); err != nil {
|
||||
t.Fatalf("DeleteGuestSession on missing row: %v", err)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue