diff --git a/internal/daemon/images_helpers_test.go b/internal/daemon/images_helpers_test.go new file mode 100644 index 0000000..0615820 --- /dev/null +++ b/internal/daemon/images_helpers_test.go @@ -0,0 +1,24 @@ +package daemon + +import "testing" + +func TestFirstNonEmpty(t *testing.T) { + cases := []struct { + name string + values []string + want string + }{ + {"all empty", []string{"", " ", "\t"}, ""}, + {"first wins", []string{"a", "b"}, "a"}, + {"skips blanks", []string{"", " ", "first", "second"}, "first"}, + {"nil input", nil, ""}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := firstNonEmpty(tc.values...) + if got != tc.want { + t.Errorf("firstNonEmpty(%v) = %q, want %q", tc.values, got, tc.want) + } + }) + } +} diff --git a/internal/daemon/opstate/registry_test.go b/internal/daemon/opstate/registry_test.go new file mode 100644 index 0000000..2ea56b7 --- /dev/null +++ b/internal/daemon/opstate/registry_test.go @@ -0,0 +1,74 @@ +package opstate + +import ( + "sync/atomic" + "testing" + "time" +) + +type fakeOp struct { + id string + done atomic.Bool + updatedAt time.Time + canceled atomic.Bool +} + +func (f *fakeOp) ID() string { return f.id } +func (f *fakeOp) IsDone() bool { return f.done.Load() } +func (f *fakeOp) UpdatedAt() time.Time { return f.updatedAt } +func (f *fakeOp) Cancel() { f.canceled.Store(true) } + +func TestRegistryInsertAndGet(t *testing.T) { + var r Registry[*fakeOp] + op := &fakeOp{id: "op-1", updatedAt: time.Now()} + r.Insert(op) + got, ok := r.Get("op-1") + if !ok { + t.Fatal("Get after Insert missed") + } + if got.ID() != "op-1" { + t.Fatalf("Get().ID = %q", got.ID()) + } + + _, ok = r.Get("missing") + if ok { + t.Fatal("Get on missing key should miss") + } +} + +func TestRegistryPruneDropsCompletedOldOps(t *testing.T) { + var r Registry[*fakeOp] + now := time.Now() + + recent := &fakeOp{id: "recent", updatedAt: now} + recent.done.Store(true) + + stale := &fakeOp{id: "stale", updatedAt: now.Add(-time.Hour)} + stale.done.Store(true) + + pending := &fakeOp{id: "pending", updatedAt: now.Add(-time.Hour)} + // NOT done → stays even though old. + + r.Insert(recent) + r.Insert(stale) + r.Insert(pending) + + cutoff := now.Add(-time.Minute) + r.Prune(cutoff) + + if _, ok := r.Get("stale"); ok { + t.Error("stale op should have been pruned") + } + if _, ok := r.Get("recent"); !ok { + t.Error("recent op should survive (newer than cutoff)") + } + if _, ok := r.Get("pending"); !ok { + t.Error("pending op should survive (not done)") + } +} + +func TestRegistryPruneNoOpOnEmpty(t *testing.T) { + var r Registry[*fakeOp] + // Just shouldn't panic. + r.Prune(time.Now()) +} diff --git a/internal/daemon/session/session_test.go b/internal/daemon/session/session_test.go new file mode 100644 index 0000000..ec093f2 --- /dev/null +++ b/internal/daemon/session/session_test.go @@ -0,0 +1,440 @@ +package session + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "banger/internal/model" + + "golang.org/x/crypto/ssh" +) + +func TestRelativeStateDir(t *testing.T) { + got := RelativeStateDir("abc") + if strings.HasPrefix(got, "/root/") { + t.Fatalf("RelativeStateDir(%q) = %q, should strip /root/ prefix", "abc", got) + } + if !strings.Contains(got, "abc") { + t.Fatalf("missing session id in %q", got) + } + absolute := StateDir("abc") + if got != strings.TrimPrefix(absolute, "/root/") { + t.Fatalf("relative = %q, want %q", got, strings.TrimPrefix(absolute, "/root/")) + } +} + +func TestDefaultCWD(t *testing.T) { + if DefaultCWD("") != "/root" { + t.Error("empty should return /root") + } + if DefaultCWD(" ") != "/root" { + t.Error("whitespace should return /root") + } + if DefaultCWD("/work") != "/work" { + t.Error("explicit should pass through") + } +} + +func TestShellQuote(t *testing.T) { + if got := ShellQuote(""); got != "''" { + t.Errorf("empty: got %q, want ''", got) + } + if got := ShellQuote("x"); got != "'x'" { + t.Errorf("plain: got %q", got) + } + if got := ShellQuote("it's"); got != `'it'"'"'s'` { + t.Errorf("apostrophe: got %q", got) + } +} + +func TestExitCode(t *testing.T) { + if code, ok := ExitCode(nil); !ok || code != 0 { + t.Errorf("nil err: got (%d, %v), want (0, true)", code, ok) + } + // Build an ssh.ExitError using its real type — can't hand-construct, + // so wrap via errors.As check with a stub. + raw := &ssh.ExitError{} + if _, ok := ExitCode(raw); !ok { + t.Error("ssh.ExitError: ok should be true") + } + if _, ok := ExitCode(errors.New("bare error")); ok { + t.Error("bare error: ok should be false") + } +} + +func TestCloneStringMap(t *testing.T) { + if CloneStringMap(nil) != nil { + t.Error("nil in → nil out") + } + if CloneStringMap(map[string]string{}) != nil { + t.Error("empty in → nil out") + } + src := map[string]string{"a": "1", "b": "2"} + cloned := CloneStringMap(src) + if len(cloned) != 2 { + t.Fatalf("len = %d, want 2", len(cloned)) + } + cloned["a"] = "changed" + if src["a"] != "1" { + t.Error("mutating clone leaked back to source") + } +} + +func TestTailFileContent(t *testing.T) { + // Missing file → empty, no error. + got, err := TailFileContent(filepath.Join(t.TempDir(), "missing"), 10) + if err != nil || got != "" { + t.Errorf("missing: got (%q, %v), want ('', nil)", got, err) + } + + path := filepath.Join(t.TempDir(), "log") + lines := "one\ntwo\nthree\nfour\nfive" + if err := os.WriteFile(path, []byte(lines), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + full, err := TailFileContent(path, 0) + if err != nil || full != lines { + t.Errorf("0 lines: got (%q, %v), want (%q, nil)", full, err, lines) + } + + // Request more lines than exist → full content. + all, err := TailFileContent(path, 999) + if err != nil || all != lines { + t.Errorf("999 lines: got %q", all) + } + + last2, err := TailFileContent(path, 2) + if err != nil { + t.Fatalf("2 lines: %v", err) + } + if !strings.Contains(last2, "five") { + t.Errorf("2 lines missing last line: %q", last2) + } +} + +func TestProcessAlive(t *testing.T) { + if ProcessAlive(0) { + t.Error("pid 0 should not be alive") + } + if ProcessAlive(-1) { + t.Error("negative pid should not be alive") + } + // Swap the syscall seam. + original := syscallKill + t.Cleanup(func() { syscallKill = original }) + + syscallKill = func(pid int, signal os.Signal) error { return nil } + if !ProcessAlive(42) { + t.Error("syscallKill=nil should report alive") + } + + syscallKill = func(pid int, signal os.Signal) error { return fmt.Errorf("no such process") } + if ProcessAlive(42) { + t.Error("syscallKill error should report dead") + } +} + +func TestFormatStepError(t *testing.T) { + base := errors.New("boom") + err := FormatStepError("prepare", base, "") + if !errors.Is(err, base) { + t.Error("FormatStepError should wrap the base error") + } + if !strings.Contains(err.Error(), "prepare") { + t.Errorf("missing action: %v", err) + } + + errWithLog := FormatStepError("prepare", base, " log line\n") + if !strings.Contains(errWithLog.Error(), "log line") { + t.Errorf("missing log: %v", errWithLog) + } +} + +func TestParseStateHappyPath(t *testing.T) { + raw := `status=running +pid=123 +exit= +alive=true +error= +` + snap, err := ParseState(raw) + if err != nil { + t.Fatalf("ParseState: %v", err) + } + if snap.Status != "running" { + t.Errorf("Status = %q", snap.Status) + } + if snap.GuestPID != 123 { + t.Errorf("GuestPID = %d", snap.GuestPID) + } + if snap.ExitCode != nil { + t.Errorf("ExitCode should be nil when empty, got %v", snap.ExitCode) + } + if !snap.Alive { + t.Error("Alive should be true") + } +} + +func TestParseStateWithExit(t *testing.T) { + raw := `status=exited +pid=123 +exit=7 +alive=false +error=something bad +` + snap, err := ParseState(raw) + if err != nil { + t.Fatalf("ParseState: %v", err) + } + if snap.ExitCode == nil || *snap.ExitCode != 7 { + t.Errorf("ExitCode = %v, want 7", snap.ExitCode) + } + if snap.LastError != "something bad" { + t.Errorf("LastError = %q", snap.LastError) + } + if snap.Alive { + t.Error("Alive should be false") + } +} + +func TestParseStateIgnoresMalformedLines(t *testing.T) { + raw := "no-equals-here\nstatus=ok\n" + snap, err := ParseState(raw) + if err != nil { + t.Fatalf("ParseState: %v", err) + } + if snap.Status != "ok" { + t.Errorf("Status = %q, want ok", snap.Status) + } +} + +func TestInspectStateFromDir(t *testing.T) { + dir := t.TempDir() + writeFile := func(name, content string) { + if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o600); err != nil { + t.Fatalf("WriteFile(%s): %v", name, err) + } + } + writeFile("status", "running\n") + writeFile("pid", "42\n") + writeFile("exit_code", "0\n") + writeFile("error", "\n") + + original := syscallKill + t.Cleanup(func() { syscallKill = original }) + syscallKill = func(pid int, signal os.Signal) error { return nil } + + snap, err := InspectStateFromDir(dir) + if err != nil { + t.Fatalf("InspectStateFromDir: %v", err) + } + if snap.Status != "running" { + t.Errorf("Status = %q", snap.Status) + } + if snap.GuestPID != 42 { + t.Errorf("GuestPID = %d", snap.GuestPID) + } + if snap.ExitCode == nil || *snap.ExitCode != 0 { + t.Errorf("ExitCode = %v, want 0", snap.ExitCode) + } + if !snap.Alive { + t.Error("Alive should reflect syscallKill result (true)") + } +} + +func TestInspectStateFromDirMissingFiles(t *testing.T) { + snap, err := InspectStateFromDir(t.TempDir()) + if err != nil { + t.Fatalf("InspectStateFromDir (empty): %v", err) + } + if snap.Status != "" || snap.GuestPID != 0 || snap.ExitCode != nil { + t.Errorf("empty dir: snap = %+v", snap) + } +} + +func TestApplyStateSnapshotNilReceiver(t *testing.T) { + ApplyStateSnapshot(nil, StateSnapshot{}, true) // should not panic +} + +func TestApplyStateSnapshotExitedSuccess(t *testing.T) { + exit := 0 + sess := &model.GuestSession{Status: model.GuestSessionStatusRunning, Attachable: true, Reattachable: true} + ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true) + if sess.Status != model.GuestSessionStatusExited { + t.Errorf("Status = %q, want exited", sess.Status) + } + if sess.Attachable || sess.Reattachable { + t.Error("attach flags should be cleared on exit") + } + if sess.EndedAt.IsZero() { + t.Error("EndedAt should be set") + } +} + +func TestApplyStateSnapshotExitedFailure(t *testing.T) { + exit := 2 + sess := &model.GuestSession{Status: model.GuestSessionStatusRunning} + ApplyStateSnapshot(sess, StateSnapshot{ExitCode: &exit}, true) + if sess.Status != model.GuestSessionStatusFailed { + t.Errorf("Status = %q, want failed", sess.Status) + } +} + +func TestApplyStateSnapshotVMGone(t *testing.T) { + sess := &model.GuestSession{Status: model.GuestSessionStatusRunning} + ApplyStateSnapshot(sess, StateSnapshot{Alive: false}, false) + if sess.Status != model.GuestSessionStatusFailed { + t.Errorf("Status = %q, want failed", sess.Status) + } + if sess.LastError == "" { + t.Error("LastError should be populated when VM is gone") + } +} + +func TestApplyStateSnapshotRunningStatusSetsAttachableForPipe(t *testing.T) { + // When the guest-side status file reports "running" (Alive=false from + // kill -0 may still fail transiently), ApplyStateSnapshot transitions + // the session to running and sets attach flags for pipe-mode. + sess := &model.GuestSession{ + Status: model.GuestSessionStatusStarting, + StdinMode: model.GuestSessionStdinPipe, + } + ApplyStateSnapshot(sess, StateSnapshot{Status: string(model.GuestSessionStatusRunning), GuestPID: 11}, true) + if sess.Status != model.GuestSessionStatusRunning { + t.Errorf("Status = %q, want running", sess.Status) + } + if !sess.Attachable || !sess.Reattachable { + t.Error("pipe-mode running session should be attachable + reattachable") + } + if sess.AttachBackend != AttachBackendSSHBridge { + t.Errorf("AttachBackend = %q, want %q", sess.AttachBackend, AttachBackendSSHBridge) + } +} + +func TestApplyStateSnapshotAliveEarlyReturn(t *testing.T) { + // Alive-true returns immediately after setting status; no attach + // flags set on this path (by design — attach metadata only attaches + // to status-driven transitions). + sess := &model.GuestSession{ + Status: model.GuestSessionStatusStarting, + StdinMode: model.GuestSessionStdinPipe, + } + ApplyStateSnapshot(sess, StateSnapshot{Alive: true, GuestPID: 11}, true) + if sess.Status != model.GuestSessionStatusRunning { + t.Errorf("Status = %q, want running", sess.Status) + } + if sess.StartedAt.IsZero() { + t.Error("StartedAt should have been set") + } +} + +func TestStateChanged(t *testing.T) { + base := model.GuestSession{Status: model.GuestSessionStatusRunning, GuestPID: 10} + + // Identical → no change. + if StateChanged(base, base) { + t.Error("identical states should not be considered changed") + } + + // Status change. + changed := base + changed.Status = model.GuestSessionStatusExited + if !StateChanged(base, changed) { + t.Error("status change should be detected") + } + + // ExitCode change from nil → value. + exit := 3 + changed = base + changed.ExitCode = &exit + if !StateChanged(base, changed) { + t.Error("exit-code appearing should be detected") + } + + // Both have the same exit code → no change. + a := base + a.ExitCode = &exit + b := base + b.ExitCode = &exit + if StateChanged(a, b) { + t.Error("matching exit codes should not trigger change") + } + + // Different exit codes. + other := 5 + b.ExitCode = &other + if !StateChanged(a, b) { + t.Error("differing exit codes should be detected") + } + + // Timestamp change. + changed = base + changed.StartedAt = time.Now() + if !StateChanged(base, changed) { + t.Error("StartedAt change should be detected") + } +} + +func TestFailLaunch(t *testing.T) { + in := model.GuestSession{Status: model.GuestSessionStatusStarting, Attachable: true} + out := FailLaunch(in, "provision", " ssh did not come up ", " raw output\n") + if out.Status != model.GuestSessionStatusFailed { + t.Errorf("Status = %q, want failed", out.Status) + } + if out.LastError != "ssh did not come up" { + t.Errorf("LastError = %q (not trimmed?)", out.LastError) + } + if out.LaunchStage != "provision" || out.LaunchMessage != "ssh did not come up" { + t.Errorf("launch fields not set: %+v", out) + } + if out.LaunchRawLog != "raw output" { + t.Errorf("rawLog = %q (not trimmed?)", out.LaunchRawLog) + } + if out.Attachable { + t.Error("Attachable should be cleared") + } +} + +func TestNormalizeRequiredCommands(t *testing.T) { + got := NormalizeRequiredCommands("pi", []string{"pi", "git", "", "git", " ", "make"}) + want := []string{"pi", "git", "make"} + if len(got) != len(want) { + t.Fatalf("len = %d, want %d (%v)", len(got), len(want), got) + } + for i, v := range want { + if got[i] != v { + t.Errorf("position %d: got %q, want %q", i, got[i], v) + } + } +} + +func TestInspectScriptContainsAllStateFiles(t *testing.T) { + script := InspectScript("sess-abc") + for _, key := range []string{"status", "pid", "exit_code", "error", "alive"} { + if !strings.Contains(script, key) { + t.Errorf("script missing %q:\n%s", key, script) + } + } + if !strings.Contains(script, "sess-abc") { + t.Error("script missing session id") + } +} + +func TestSignalScriptIncludesSignalAndDirPaths(t *testing.T) { + script := SignalScript("sess-x", "TERM") + if !strings.Contains(script, "TERM") { + t.Error("missing signal") + } + if !strings.Contains(script, "sess-x") { + t.Error("missing session id") + } + if !strings.Contains(script, "monitor_pid") || !strings.Contains(script, "stdin_keepalive") { + t.Errorf("expected both monitor + stdin_keepalive kills, got:\n%s", script) + } +} diff --git a/internal/hostnat/runner_test.go b/internal/hostnat/runner_test.go new file mode 100644 index 0000000..7853e53 --- /dev/null +++ b/internal/hostnat/runner_test.go @@ -0,0 +1,258 @@ +package hostnat + +import ( + "context" + "errors" + "fmt" + "reflect" + "strings" + "testing" +) + +type call struct { + sudo bool + name string + args []string +} + +type fakeRunner struct { + calls []call + // runResp maps "name arg0 arg1 ..." (Run, no sudo) to a scripted + // (stdout, err) pair. Missing entries return error. + runResp map[string]callResp + // sudoMatcher decides whether a RunSudo call succeeds. If nil, all + // RunSudo calls succeed with empty stdout. + sudoMatcher func(args []string) ([]byte, error) +} + +type callResp struct { + out []byte + err error +} + +func (r *fakeRunner) Run(ctx context.Context, name string, args ...string) ([]byte, error) { + c := call{name: name, args: append([]string(nil), args...)} + r.calls = append(r.calls, c) + key := name + " " + strings.Join(args, " ") + if resp, ok := r.runResp[key]; ok { + return resp.out, resp.err + } + return nil, fmt.Errorf("unexpected Run: %s", key) +} + +func (r *fakeRunner) RunSudo(ctx context.Context, args ...string) ([]byte, error) { + c := call{sudo: true, args: append([]string(nil), args...)} + r.calls = append(r.calls, c) + if r.sudoMatcher != nil { + return r.sudoMatcher(args) + } + return nil, nil +} + +func TestDefaultUplink(t *testing.T) { + t.Parallel() + r := &fakeRunner{ + runResp: map[string]callResp{ + "ip route show default": {out: []byte("default via 10.0.0.1 dev wlan0 proto dhcp\n")}, + }, + } + got, err := DefaultUplink(context.Background(), r) + if err != nil { + t.Fatalf("DefaultUplink: %v", err) + } + if got != "wlan0" { + t.Fatalf("got %q, want wlan0", got) + } +} + +func TestDefaultUplinkPropagatesRunError(t *testing.T) { + t.Parallel() + r := &fakeRunner{} + _, err := DefaultUplink(context.Background(), r) + if err == nil { + t.Fatal("expected error from DefaultUplink when Run fails") + } +} + +func TestRuleKey(t *testing.T) { + rule := Rule{Table: "nat", Chain: "POSTROUTING", Args: []string{"-s", "172.16.0.5/32"}} + key := RuleKey(rule) + if !strings.Contains(key, "nat") || !strings.Contains(key, "POSTROUTING") || !strings.Contains(key, "172.16.0.5/32") { + t.Fatalf("key missing expected parts: %q", key) + } + + // Different args → different key. + other := Rule{Table: "nat", Chain: "POSTROUTING", Args: []string{"-s", "10.0.0.5/32"}} + if RuleKey(rule) == RuleKey(other) { + t.Fatal("RuleKey should differ for different args") + } +} + +func TestEnsureEnableInstallsRules(t *testing.T) { + t.Parallel() + r := &fakeRunner{ + runResp: map[string]callResp{ + "ip route show default": {out: []byte("default via 10.0.0.1 dev eth0\n")}, + }, + sudoMatcher: func(args []string) ([]byte, error) { + // The first sudo call is sysctl; every subsequent call is + // `iptables -C ...` (probe) followed by `iptables -A ...` + // because the probe should report the rule is NOT present. + if args[0] == "sysctl" { + return nil, nil + } + if args[0] != "iptables" { + return nil, fmt.Errorf("unexpected sudo prefix: %v", args) + } + // Fail -C (rule absent) so Ensure issues -A. + for _, a := range args { + if a == "-C" { + return nil, errors.New("rule absent") + } + } + return nil, nil + }, + } + + if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", true); err != nil { + t.Fatalf("Ensure: %v", err) + } + + // Expect at least: 1 ip route, 1 sysctl, and for 3 rules: -C + -A = 6 iptables calls. + if len(r.calls) < 8 { + t.Fatalf("call count = %d, want >= 8; calls=%+v", len(r.calls), r.calls) + } + // First call is ip route; second is sysctl. + if r.calls[0].name != "ip" { + t.Errorf("calls[0] = %+v, want ip route", r.calls[0]) + } + if !r.calls[1].sudo || r.calls[1].args[0] != "sysctl" { + t.Errorf("calls[1] = %+v, want sudo sysctl", r.calls[1]) + } + // Somewhere we must have an iptables -A POSTROUTING call. + var sawAppend bool + for _, c := range r.calls { + if c.sudo && len(c.args) >= 3 && c.args[0] == "iptables" && contains(c.args, "-A") && contains(c.args, "POSTROUTING") { + sawAppend = true + break + } + } + if !sawAppend { + t.Fatal("no iptables -A POSTROUTING call observed") + } +} + +func TestEnsureEnableSkipsAppendWhenRulePresent(t *testing.T) { + t.Parallel() + r := &fakeRunner{ + runResp: map[string]callResp{ + "ip route show default": {out: []byte("default via 10.0.0.1 dev eth0\n")}, + }, + sudoMatcher: func(args []string) ([]byte, error) { + // Probe succeeds → Ensure should NOT follow up with -A. + return nil, nil + }, + } + if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", true); err != nil { + t.Fatalf("Ensure: %v", err) + } + + // No -A iptables calls should have been issued. + for _, c := range r.calls { + if c.sudo && contains(c.args, "iptables") && contains(c.args, "-A") { + t.Fatalf("unexpected -A call with probe success: %+v", c) + } + } +} + +func TestEnsureDisableRemovesRulesWhenPresent(t *testing.T) { + t.Parallel() + r := &fakeRunner{ + runResp: map[string]callResp{ + "ip route show default": {out: []byte("default via 10.0.0.1 dev eth0\n")}, + }, + sudoMatcher: func(args []string) ([]byte, error) { + // Every probe succeeds → rule is present → -D is issued. + return nil, nil + }, + } + if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", false); err != nil { + t.Fatalf("Ensure(disable): %v", err) + } + var sawDelete bool + for _, c := range r.calls { + if c.sudo && contains(c.args, "iptables") && contains(c.args, "-D") { + sawDelete = true + break + } + } + if !sawDelete { + t.Fatal("expected at least one iptables -D call") + } + // No sysctl on disable path. + for _, c := range r.calls { + if c.sudo && len(c.args) > 0 && c.args[0] == "sysctl" { + t.Fatal("sysctl should not run on disable path") + } + } +} + +func TestEnsureDisableSkipsRemovalWhenAbsent(t *testing.T) { + t.Parallel() + r := &fakeRunner{ + runResp: map[string]callResp{ + "ip route show default": {out: []byte("default via 10.0.0.1 dev eth0\n")}, + }, + sudoMatcher: func(args []string) ([]byte, error) { + return nil, errors.New("rule not present") + }, + } + if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", false); err != nil { + t.Fatalf("Ensure(disable, absent): %v", err) + } + for _, c := range r.calls { + if c.sudo && contains(c.args, "iptables") && contains(c.args, "-D") { + t.Fatalf("unexpected -D with absent rule: %+v", c) + } + } +} + +func TestEnsurePropagatesUplinkError(t *testing.T) { + t.Parallel() + r := &fakeRunner{} // no runResp → ip route fails + err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", true) + if err == nil { + t.Fatal("expected uplink error to propagate") + } +} + +func TestEnsureValidatesInputs(t *testing.T) { + t.Parallel() + r := &fakeRunner{ + runResp: map[string]callResp{ + "ip route show default": {out: []byte("default via 10.0.0.1 dev eth0\n")}, + }, + } + if err := Ensure(context.Background(), r, "", "tap-x", true); err == nil { + t.Fatal("expected error for empty guestIP") + } +} + +func TestRuleArgsWithoutTable(t *testing.T) { + // Sanity: RuleArgs should only prepend -t when Table is set. + bare := Rule{Chain: "FORWARD", Args: []string{"-i", "eth0"}} + got := RuleArgs("-A", bare) + want := []string{"-A", "FORWARD", "-i", "eth0"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got %v, want %v", got, want) + } +} + +func contains(xs []string, target string) bool { + for _, x := range xs { + if x == target { + return true + } + } + return false +} diff --git a/internal/store/guest_session_test.go b/internal/store/guest_session_test.go new file mode 100644 index 0000000..eff1477 --- /dev/null +++ b/internal/store/guest_session_test.go @@ -0,0 +1,214 @@ +package store + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "testing" + "time" + + "banger/internal/model" +) + +func sampleGuestSession(id, vmID, name string) model.GuestSession { + now := fixedTime() + exit := 7 + return model.GuestSession{ + ID: id, + VMID: vmID, + Name: name, + Backend: "ssh", + AttachBackend: "vsock", + AttachMode: "rpc", + Command: "pi", + Args: []string{"--mode", "rpc"}, + CWD: "/root/repo", + Env: map[string]string{"FOO": "bar"}, + StdinMode: model.GuestSessionStdinMode("pipe"), + Status: model.GuestSessionStatus("exited"), + ExitCode: &exit, + GuestPID: 1234, + GuestStateDir: "/tmp/guest-" + id, + StdoutLogPath: "/tmp/" + id + ".stdout", + StderrLogPath: "/tmp/" + id + ".stderr", + Tags: map[string]string{"role": "planner"}, + LastError: "", + Attachable: true, + Reattachable: true, + LaunchStage: "started", + LaunchMessage: "ok", + LaunchRawLog: "boot log...", + CreatedAt: now, + StartedAt: now, + UpdatedAt: now, + EndedAt: now.Add(time.Minute), + } +} + +// openTestStoreWithVMs opens a fresh store seeded with the given VM IDs so +// guest_sessions FK constraints are satisfied. Each VM gets a minimal +// image it references. +func openTestStoreWithVMs(t *testing.T, vmIDs ...string) *Store { + t.Helper() + ctx := context.Background() + store := openTestStore(t) + + image := sampleImage("stub-image") + if err := store.UpsertImage(ctx, image); err != nil { + t.Fatalf("UpsertImage: %v", err) + } + for i, id := range vmIDs { + vm := sampleVM(id, image.ID, fmt.Sprintf("172.16.0.%d", i+2)) + vm.ID = id + if err := store.UpsertVM(ctx, vm); err != nil { + t.Fatalf("UpsertVM(%s): %v", id, err) + } + } + return store +} + +func TestGuestSessionUpsertAndGetByID(t *testing.T) { + t.Parallel() + ctx := context.Background() + store := openTestStoreWithVMs(t, "vm-1") + + session := sampleGuestSession("sess-1", "vm-1", "planner") + if err := store.UpsertGuestSession(ctx, session); err != nil { + t.Fatalf("UpsertGuestSession: %v", err) + } + + got, err := store.GetGuestSessionByID(ctx, "sess-1") + if err != nil { + t.Fatalf("GetGuestSessionByID: %v", err) + } + if !reflect.DeepEqual(got, session) { + t.Fatalf("round-trip mismatch:\n got %+v\n want %+v", got, session) + } +} + +func TestGuestSessionUpsertIsIdempotent(t *testing.T) { + t.Parallel() + ctx := context.Background() + store := openTestStoreWithVMs(t, "vm-1") + + session := sampleGuestSession("sess-1", "vm-1", "planner") + if err := store.UpsertGuestSession(ctx, session); err != nil { + t.Fatalf("UpsertGuestSession (first): %v", err) + } + + // Mutate + re-upsert → existing row updated. + session.Command = "pi --other" + session.Status = model.GuestSessionStatus("running") + session.ExitCode = nil + if err := store.UpsertGuestSession(ctx, session); err != nil { + t.Fatalf("UpsertGuestSession (second): %v", err) + } + + got, err := store.GetGuestSessionByID(ctx, "sess-1") + if err != nil { + t.Fatalf("GetGuestSessionByID: %v", err) + } + if got.Command != "pi --other" { + t.Errorf("command = %q, want 'pi --other'", got.Command) + } + if got.Status != model.GuestSessionStatus("running") { + t.Errorf("status = %q, want running", got.Status) + } + if got.ExitCode != nil { + t.Errorf("ExitCode = %v, want nil after clearing", got.ExitCode) + } +} + +func TestGetGuestSessionByIDOrName(t *testing.T) { + t.Parallel() + ctx := context.Background() + store := openTestStoreWithVMs(t, "vm-1") + + session := sampleGuestSession("sess-1", "vm-1", "planner") + if err := store.UpsertGuestSession(ctx, session); err != nil { + t.Fatalf("UpsertGuestSession: %v", err) + } + + byID, err := store.GetGuestSession(ctx, "vm-1", "sess-1") + if err != nil { + t.Fatalf("GetGuestSession by ID: %v", err) + } + if byID.ID != "sess-1" { + t.Errorf("by-ID: got %q, want sess-1", byID.ID) + } + + byName, err := store.GetGuestSession(ctx, "vm-1", "planner") + if err != nil { + t.Fatalf("GetGuestSession by name: %v", err) + } + if byName.Name != "planner" { + t.Errorf("by-name: got %q, want planner", byName.Name) + } + + // Scoped to the VM. + if _, err := store.GetGuestSession(ctx, "vm-unknown", "sess-1"); !errors.Is(err, sql.ErrNoRows) { + t.Errorf("wrong-vm lookup = %v, want sql.ErrNoRows", err) + } +} + +func TestListGuestSessionsByVMOrdersByCreatedAt(t *testing.T) { + t.Parallel() + ctx := context.Background() + store := openTestStoreWithVMs(t, "vm-1", "vm-2") + + base := fixedTime() + first := sampleGuestSession("sess-early", "vm-1", "first") + first.CreatedAt = base + second := sampleGuestSession("sess-late", "vm-1", "second") + second.CreatedAt = base.Add(time.Hour) + other := sampleGuestSession("sess-other", "vm-2", "other") + + for _, s := range []model.GuestSession{second, first, other} { + if err := store.UpsertGuestSession(ctx, s); err != nil { + t.Fatalf("UpsertGuestSession: %v", err) + } + } + + sessions, err := store.ListGuestSessionsByVM(ctx, "vm-1") + if err != nil { + t.Fatalf("ListGuestSessionsByVM: %v", err) + } + if len(sessions) != 2 { + t.Fatalf("len = %d, want 2 (vm-1 only)", len(sessions)) + } + if sessions[0].ID != "sess-early" || sessions[1].ID != "sess-late" { + t.Fatalf("order: got %q, %q; want sess-early, sess-late", sessions[0].ID, sessions[1].ID) + } + + empty, err := store.ListGuestSessionsByVM(ctx, "vm-unknown") + if err != nil { + t.Fatalf("ListGuestSessionsByVM (unknown vm): %v", err) + } + if len(empty) != 0 { + t.Fatalf("unknown vm sessions = %+v, want empty", empty) + } +} + +func TestDeleteGuestSession(t *testing.T) { + t.Parallel() + ctx := context.Background() + store := openTestStoreWithVMs(t, "vm-1") + + session := sampleGuestSession("sess-1", "vm-1", "planner") + if err := store.UpsertGuestSession(ctx, session); err != nil { + t.Fatalf("UpsertGuestSession: %v", err) + } + if err := store.DeleteGuestSession(ctx, "sess-1"); err != nil { + t.Fatalf("DeleteGuestSession: %v", err) + } + if _, err := store.GetGuestSessionByID(ctx, "sess-1"); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("after delete err = %v, want sql.ErrNoRows", err) + } + + // Deleting something that doesn't exist is a no-op (matches SQL DELETE semantics). + if err := store.DeleteGuestSession(ctx, "sess-nope"); err != nil { + t.Fatalf("DeleteGuestSession on missing row: %v", err) + } +}