diff --git a/internal/daemon/guest_sessions.go b/internal/daemon/guest_sessions.go index 6dd3938..0477e40 100644 --- a/internal/daemon/guest_sessions.go +++ b/internal/daemon/guest_sessions.go @@ -1,7 +1,6 @@ package daemon import ( - "bufio" "bytes" "context" "errors" @@ -10,27 +9,13 @@ import ( "net" "os" "path/filepath" - "sort" - "strconv" "strings" - "syscall" "time" + "banger/internal/daemon/session" "banger/internal/guest" "banger/internal/model" "banger/internal/system" - - "golang.org/x/crypto/ssh" -) - -const ( - guestSessionBackendSSH = "ssh" - guestSessionAttachBackendNone = "none" - guestSessionAttachBackendSSHBridge = "ssh_rehydratable" - guestSessionAttachModeExclusive = "exclusive" - guestSessionTransportUnixSocket = "unix_socket" - guestSessionStateRoot = "/root/.local/state/banger/sessions" - guestSessionLogTailLine = 200 ) var guestSessionHostCommandOutputFunc = func(ctx context.Context, name string, args ...string) ([]byte, error) { @@ -70,178 +55,94 @@ func (d *Daemon) dialGuest(ctx context.Context, address string) (guestSSHClient, return guest.Dial(ctx, address, d.config.SSHKeyPath) } -func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, session model.GuestSession) (model.GuestSession, error) { +func (d *Daemon) waitForGuestSessionReadyHook(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) { if d != nil && d.waitForGuestSessionReady != nil { - return d.waitForGuestSessionReady(ctx, vm, session) + return d.waitForGuestSessionReady(ctx, vm, s) } - return d.waitForGuestSessionReadyDefault(ctx, vm, session) + return d.waitForGuestSessionReadyDefault(ctx, vm, s) } -func (d *Daemon) waitForGuestSessionReadyDefault(ctx context.Context, vm model.VMRecord, session model.GuestSession) (model.GuestSession, error) { +func (d *Daemon) waitForGuestSessionReadyDefault(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) { for { - updated, err := d.refreshGuestSession(ctx, vm, session) + updated, err := d.refreshGuestSession(ctx, vm, s) if err == nil { - session = updated - if session.GuestPID != 0 || session.ExitCode != nil || session.Status == model.GuestSessionStatusRunning || session.Status == model.GuestSessionStatusFailed || session.Status == model.GuestSessionStatusExited { - return session, nil + s = updated + if s.GuestPID != 0 || s.ExitCode != nil || s.Status == model.GuestSessionStatusRunning || s.Status == model.GuestSessionStatusFailed || s.Status == model.GuestSessionStatusExited { + return s, nil } } select { case <-ctx.Done(): - return session, ctx.Err() + return s, ctx.Err() case <-time.After(100 * time.Millisecond): } } } -func (d *Daemon) refreshGuestSession(ctx context.Context, vm model.VMRecord, session model.GuestSession) (model.GuestSession, error) { - if session.Status != model.GuestSessionStatusStarting && session.Status != model.GuestSessionStatusRunning && session.Status != model.GuestSessionStatusStopping { - return session, nil +func (d *Daemon) refreshGuestSession(ctx context.Context, vm model.VMRecord, s model.GuestSession) (model.GuestSession, error) { + if s.Status != model.GuestSessionStatusStarting && s.Status != model.GuestSessionStatusRunning && s.Status != model.GuestSessionStatusStopping { + return s, nil } - snapshot, err := d.inspectGuestSessionState(ctx, vm, session) + snapshot, err := d.inspectGuestSessionState(ctx, vm, s) if err != nil { - return session, err + return s, err } - original := session - applyGuestSessionSnapshot(&session, snapshot, vm.State == model.VMStateRunning && system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath)) - if guestSessionStateChanged(original, session) { - session.UpdatedAt = model.Now() - if err := d.store.UpsertGuestSession(ctx, session); err != nil { - return session, err + original := s + session.ApplyStateSnapshot(&s, snapshot, vm.State == model.VMStateRunning && system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath)) + if session.StateChanged(original, s) { + s.UpdatedAt = model.Now() + if err := d.store.UpsertGuestSession(ctx, s); err != nil { + return s, err } } - return session, nil + return s, nil } -func applyGuestSessionSnapshot(session *model.GuestSession, snapshot guestSessionStateSnapshot, vmRunning bool) { - if session == nil { - return - } - if snapshot.GuestPID != 0 { - session.GuestPID = snapshot.GuestPID - } - if snapshot.LastError != "" { - session.LastError = snapshot.LastError - } - if snapshot.ExitCode != nil { - session.ExitCode = snapshot.ExitCode - session.Attachable = false - session.Reattachable = false - if session.StartedAt.IsZero() { - session.StartedAt = model.Now() - } - if session.EndedAt.IsZero() { - session.EndedAt = model.Now() - } - if *snapshot.ExitCode == 0 { - session.Status = model.GuestSessionStatusExited - } else { - session.Status = model.GuestSessionStatusFailed - } - return - } - if snapshot.Alive { - if session.StartedAt.IsZero() { - session.StartedAt = model.Now() - } - session.Status = model.GuestSessionStatusRunning - return - } - if !vmRunning && (session.Status == model.GuestSessionStatusStarting || session.Status == model.GuestSessionStatusRunning || session.Status == model.GuestSessionStatusStopping) { - session.Status = model.GuestSessionStatusFailed - session.Attachable = false - session.Reattachable = false - if session.LastError == "" { - session.LastError = "vm is not running" - } - if session.EndedAt.IsZero() { - session.EndedAt = model.Now() - } - return - } - if snapshot.Status == string(model.GuestSessionStatusRunning) { - if session.StartedAt.IsZero() { - session.StartedAt = model.Now() - } - session.Status = model.GuestSessionStatusRunning - } - if session.Status == model.GuestSessionStatusRunning && session.StdinMode == model.GuestSessionStdinPipe { - session.Attachable = true - session.Reattachable = true - if session.AttachBackend == "" { - session.AttachBackend = guestSessionAttachBackendSSHBridge - } - if session.AttachMode == "" { - session.AttachMode = guestSessionAttachModeExclusive - } - } -} - -func (d *Daemon) inspectGuestSessionState(ctx context.Context, vm model.VMRecord, session model.GuestSession) (guestSessionStateSnapshot, error) { +func (d *Daemon) inspectGuestSessionState(ctx context.Context, vm model.VMRecord, s model.GuestSession) (session.StateSnapshot, error) { if vm.State == model.VMStateRunning && system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) { client, err := guest.Dial(ctx, net.JoinHostPort(vm.Runtime.GuestIP, "22"), d.config.SSHKeyPath) if err != nil { - return guestSessionStateSnapshot{}, err + return session.StateSnapshot{}, err } defer client.Close() var output bytes.Buffer - if err := client.RunScript(ctx, guestSessionInspectScript(session.ID), &output); err != nil { - return guestSessionStateSnapshot{}, formatGuestSessionStepError("inspect guest session state", err, output.String()) + if err := client.RunScript(ctx, session.InspectScript(s.ID), &output); err != nil { + return session.StateSnapshot{}, session.FormatStepError("inspect guest session state", err, output.String()) } - return parseGuestSessionState(output.String()) + return session.ParseState(output.String()) } - return d.inspectGuestSessionStateFromWorkDisk(ctx, vm, session.ID) + return d.inspectGuestSessionStateFromWorkDisk(ctx, vm, s.ID) } -func (d *Daemon) inspectGuestSessionStateFromWorkDisk(ctx context.Context, vm model.VMRecord, sessionID string) (guestSessionStateSnapshot, error) { +func (d *Daemon) inspectGuestSessionStateFromWorkDisk(ctx context.Context, vm model.VMRecord, sessionID string) (session.StateSnapshot, error) { runner := d.runner if runner == nil { runner = system.NewRunner() } workMount, cleanup, err := system.MountTempDir(ctx, runner, vm.Runtime.WorkDiskPath, false) if err != nil { - return guestSessionStateSnapshot{}, err + return session.StateSnapshot{}, err } defer cleanup() - stateDir := filepath.Join(workMount, guestSessionRelativeStateDir(sessionID)) - return inspectGuestSessionStateFromDir(stateDir) -} - -func inspectGuestSessionStateFromDir(stateDir string) (guestSessionStateSnapshot, error) { - var snapshot guestSessionStateSnapshot - statusData, _ := os.ReadFile(filepath.Join(stateDir, "status")) - snapshot.Status = strings.TrimSpace(string(statusData)) - pidData, _ := os.ReadFile(filepath.Join(stateDir, "pid")) - if pidValue, err := strconv.Atoi(strings.TrimSpace(string(pidData))); err == nil { - snapshot.GuestPID = pidValue - } - exitData, _ := os.ReadFile(filepath.Join(stateDir, "exit_code")) - if exitValue, err := strconv.Atoi(strings.TrimSpace(string(exitData))); err == nil { - snapshot.ExitCode = &exitValue - } - errorData, _ := os.ReadFile(filepath.Join(stateDir, "error")) - snapshot.LastError = strings.TrimSpace(string(errorData)) - if snapshot.GuestPID != 0 { - snapshot.Alive = processAlive(snapshot.GuestPID) - } - return snapshot, nil + stateDir := filepath.Join(workMount, session.RelativeStateDir(sessionID)) + return session.InspectStateFromDir(stateDir) } func (d *Daemon) findGuestSession(ctx context.Context, vmID, idOrName string) (model.GuestSession, error) { if strings.TrimSpace(idOrName) == "" { return model.GuestSession{}, errors.New("session id or name is required") } - if session, err := d.store.GetGuestSession(ctx, vmID, idOrName); err == nil { - return session, nil + if s, err := d.store.GetGuestSession(ctx, vmID, idOrName); err == nil { + return s, nil } sessions, err := d.store.ListGuestSessionsByVM(ctx, vmID) if err != nil { return model.GuestSession{}, err } matches := make([]model.GuestSession, 0, 1) - for _, session := range sessions { - if strings.HasPrefix(session.ID, idOrName) || strings.HasPrefix(session.Name, idOrName) { - matches = append(matches, session) + for _, s := range sessions { + if strings.HasPrefix(s.ID, idOrName) || strings.HasPrefix(s.Name, idOrName) { + matches = append(matches, s) } } switch len(matches) { @@ -253,364 +154,3 @@ func (d *Daemon) findGuestSession(ctx context.Context, vmID, idOrName string) (m return model.GuestSession{}, fmt.Errorf("multiple sessions match %q", idOrName) } } - -func guestSessionScript(session model.GuestSession) string { - var script strings.Builder - script.WriteString("set -euo pipefail\n") - fmt.Fprintf(&script, "STATE_DIR=%s\n", guestShellQuote(session.GuestStateDir)) - fmt.Fprintf(&script, "STDOUT_LOG=%s\n", guestShellQuote(session.StdoutLogPath)) - fmt.Fprintf(&script, "STDERR_LOG=%s\n", guestShellQuote(session.StderrLogPath)) - fmt.Fprintf(&script, "PID_FILE=%s\n", guestShellQuote(guestSessionPIDPath(session.ID))) - fmt.Fprintf(&script, "MONITOR_PID_FILE=%s\n", guestShellQuote(guestSessionMonitorPIDPath(session.ID))) - fmt.Fprintf(&script, "EXIT_FILE=%s\n", guestShellQuote(guestSessionExitCodePath(session.ID))) - fmt.Fprintf(&script, "STATUS_FILE=%s\n", guestShellQuote(guestSessionStatusPath(session.ID))) - fmt.Fprintf(&script, "ERROR_FILE=%s\n", guestShellQuote(guestSessionErrorPath(session.ID))) - fmt.Fprintf(&script, "STDIN_PIPE=%s\n", guestShellQuote(guestSessionStdinPipePath(session.ID))) - fmt.Fprintf(&script, "STDIN_KEEPALIVE_PID_FILE=%s\n", guestShellQuote(guestSessionStdinKeepalivePIDPath(session.ID))) - fmt.Fprintf(&script, "SESSION_CWD=%s\n", guestShellQuote(defaultGuestSessionCWD(session.CWD))) - script.WriteString("mkdir -p \"$STATE_DIR\"\n") - script.WriteString(": >\"$STDOUT_LOG\"\n") - script.WriteString(": >\"$STDERR_LOG\"\n") - script.WriteString("rm -f \"$EXIT_FILE\" \"$ERROR_FILE\" \"$STDIN_KEEPALIVE_PID_FILE\"\n") - if session.StdinMode == model.GuestSessionStdinPipe { - script.WriteString("rm -f \"$STDIN_PIPE\"\n") - script.WriteString("mkfifo -m 600 \"$STDIN_PIPE\"\n") - } - script.WriteString("printf '%s\\n' \"${BASHPID:-$$}\" >\"$MONITOR_PID_FILE\"\n") - script.WriteString("printf 'starting\\n' >\"$STATUS_FILE\"\n") - script.WriteString("cd \"$SESSION_CWD\"\n") - script.WriteString("exec > >(tee -a \"$STDOUT_LOG\") 2> >(tee -a \"$STDERR_LOG\" >&2)\n") - for _, line := range guestSessionEnvLines(session.Env) { - script.WriteString(line) - script.WriteByte('\n') - } - script.WriteString("COMMAND=(") - for _, value := range append([]string{session.Command}, session.Args...) { - script.WriteByte(' ') - script.WriteString(guestShellQuote(value)) - } - script.WriteString(" )\n") - if session.StdinMode == model.GuestSessionStdinPipe { - script.WriteString("( while :; do sleep 3600; done ) >\"$STDIN_PIPE\" &\n") - script.WriteString("keepalive=$!\n") - script.WriteString("printf '%s\\n' \"$keepalive\" >\"$STDIN_KEEPALIVE_PID_FILE\"\n") - script.WriteString("\"${COMMAND[@]}\" <\"$STDIN_PIPE\" &\n") - } else { - script.WriteString("\"${COMMAND[@]}\" &\n") - } - script.WriteString("child=$!\n") - script.WriteString("printf '%s\\n' \"$child\" >\"$PID_FILE\"\n") - script.WriteString("printf 'running\\n' >\"$STATUS_FILE\"\n") - script.WriteString("wait \"$child\"\n") - script.WriteString("rc=$?\n") - if session.StdinMode == model.GuestSessionStdinPipe { - script.WriteString("if [ -f \"$STDIN_KEEPALIVE_PID_FILE\" ]; then kill \"$(cat \"$STDIN_KEEPALIVE_PID_FILE\")\" 2>/dev/null || true; fi\n") - } - script.WriteString("printf '%s\\n' \"$rc\" >\"$EXIT_FILE\"\n") - script.WriteString("if [ \"$rc\" -eq 0 ]; then printf 'exited\\n' >\"$STATUS_FILE\"; else printf 'failed\\n' >\"$STATUS_FILE\"; fi\n") - script.WriteString("exit \"$rc\"\n") - return script.String() -} - -func guestSessionInspectScript(sessionID string) string { - var script strings.Builder - script.WriteString("set -euo pipefail\n") - fmt.Fprintf(&script, "DIR=%s\n", guestShellQuote(guestSessionStateDir(sessionID))) - script.WriteString("status=''\n") - script.WriteString("pid=''\n") - script.WriteString("exit_code=''\n") - script.WriteString("last_error=''\n") - script.WriteString("alive=false\n") - script.WriteString("[ -f \"$DIR/status\" ] && status=\"$(cat \"$DIR/status\")\"\n") - script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n") - script.WriteString("[ -f \"$DIR/exit_code\" ] && exit_code=\"$(cat \"$DIR/exit_code\")\"\n") - script.WriteString("[ -f \"$DIR/error\" ] && last_error=\"$(cat \"$DIR/error\")\"\n") - script.WriteString("if [ -n \"$pid\" ] && kill -0 \"$pid\" 2>/dev/null; then alive=true; fi\n") - script.WriteString("printf 'status=%s\\n' \"$status\"\n") - script.WriteString("printf 'pid=%s\\n' \"$pid\"\n") - script.WriteString("printf 'exit=%s\\n' \"$exit_code\"\n") - script.WriteString("printf 'alive=%s\\n' \"$alive\"\n") - script.WriteString("printf 'error=%s\\n' \"$last_error\"\n") - return script.String() -} - -func guestSessionSignalScript(sessionID, signal string) string { - var script strings.Builder - script.WriteString("set -euo pipefail\n") - fmt.Fprintf(&script, "DIR=%s\n", guestShellQuote(guestSessionStateDir(sessionID))) - fmt.Fprintf(&script, "SIGNAL=%s\n", guestShellQuote(signal)) - script.WriteString("pid=''\n") - script.WriteString("monitor=''\n") - script.WriteString("keepalive=''\n") - script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n") - script.WriteString("[ -f \"$DIR/monitor_pid\" ] && monitor=\"$(cat \"$DIR/monitor_pid\")\"\n") - script.WriteString("[ -f \"$DIR/stdin_keepalive.pid\" ] && keepalive=\"$(cat \"$DIR/stdin_keepalive.pid\")\"\n") - script.WriteString("printf 'stopping\\n' >\"$DIR/status\"\n") - script.WriteString("if [ -n \"$pid\" ]; then kill -${SIGNAL} \"$pid\" 2>/dev/null || true; fi\n") - script.WriteString("if [ -n \"$monitor\" ]; then kill -${SIGNAL} \"$monitor\" 2>/dev/null || true; fi\n") - script.WriteString("if [ -n \"$keepalive\" ]; then kill -${SIGNAL} \"$keepalive\" 2>/dev/null || true; fi\n") - return script.String() -} - -func guestSessionStateDir(id string) string { - return filepath.ToSlash(filepath.Join(guestSessionStateRoot, id)) -} - -func guestSessionRelativeStateDir(id string) string { - return strings.TrimPrefix(guestSessionStateDir(id), "/root/") -} - -func guestSessionScriptPath(id string) string { - return filepath.ToSlash(filepath.Join(guestSessionStateDir(id), "run.sh")) -} - -func guestSessionPIDPath(id string) string { - return filepath.ToSlash(filepath.Join(guestSessionStateDir(id), "pid")) -} - -func guestSessionMonitorPIDPath(id string) string { - return filepath.ToSlash(filepath.Join(guestSessionStateDir(id), "monitor_pid")) -} - -func guestSessionExitCodePath(id string) string { - return filepath.ToSlash(filepath.Join(guestSessionStateDir(id), "exit_code")) -} - -func guestSessionStdinPipePath(id string) string { - return filepath.ToSlash(filepath.Join(guestSessionStateDir(id), "stdin.pipe")) -} - -func guestSessionStdinKeepalivePIDPath(id string) string { - return filepath.ToSlash(filepath.Join(guestSessionStateDir(id), "stdin_keepalive.pid")) -} - -func guestSessionStatusPath(id string) string { - return filepath.ToSlash(filepath.Join(guestSessionStateDir(id), "status")) -} - -func guestSessionErrorPath(id string) string { - return filepath.ToSlash(filepath.Join(guestSessionStateDir(id), "error")) -} - -func guestSessionStdoutLogPath(id string) string { - return filepath.ToSlash(filepath.Join(guestSessionStateDir(id), "stdout.log")) -} - -func guestSessionStderrLogPath(id string) string { - return filepath.ToSlash(filepath.Join(guestSessionStateDir(id), "stderr.log")) -} - -func defaultGuestSessionName(id, command, explicit string) string { - if trimmed := strings.TrimSpace(explicit); trimmed != "" { - return trimmed - } - base := filepath.Base(strings.TrimSpace(command)) - if base == "." || base == string(filepath.Separator) || base == "" { - base = "session" - } - return base + "-" + system.ShortID(id) -} - -func defaultGuestSessionCWD(value string) string { - if trimmed := strings.TrimSpace(value); trimmed != "" { - return trimmed - } - return "/root" -} - -func failGuestSessionLaunch(session model.GuestSession, stage, message, rawLog string) model.GuestSession { - now := model.Now() - session.Status = model.GuestSessionStatusFailed - session.LastError = strings.TrimSpace(message) - session.Attachable = false - session.Reattachable = false - session.LaunchStage = strings.TrimSpace(stage) - session.LaunchMessage = strings.TrimSpace(message) - session.LaunchRawLog = strings.TrimSpace(rawLog) - session.UpdatedAt = now - session.EndedAt = now - return session -} - -func normalizeGuestSessionRequiredCommands(command string, extras []string) []string { - ordered := make([]string, 0, len(extras)+1) - seen := map[string]struct{}{} - appendValue := func(value string) { - trimmed := strings.TrimSpace(value) - if trimmed == "" { - return - } - if _, ok := seen[trimmed]; ok { - return - } - seen[trimmed] = struct{}{} - ordered = append(ordered, trimmed) - } - appendValue(command) - for _, extra := range extras { - appendValue(extra) - } - return ordered -} - -func guestSessionCWDPreflightScript(cwd string) string { - var script strings.Builder - script.WriteString("set -euo pipefail\n") - fmt.Fprintf(&script, "DIR=%s\n", guestShellQuote(defaultGuestSessionCWD(cwd))) - script.WriteString("if [ ! -d \"$DIR\" ]; then echo \"missing cwd: $DIR\"; exit 1; fi\n") - return script.String() -} - -func guestSessionCommandPreflightScript(commands []string) string { - var script strings.Builder - script.WriteString("set -euo pipefail\n") - script.WriteString("check_command() {\n") - script.WriteString(" cmd=\"$1\"\n") - script.WriteString(" case \"$cmd\" in\n") - script.WriteString(" */*) [ -x \"$cmd\" ] || { echo \"missing command: $cmd\"; exit 1; } ;;\n") - script.WriteString(" *) command -v \"$cmd\" >/dev/null 2>&1 || { echo \"missing command: $cmd\"; exit 1; } ;;\n") - script.WriteString(" esac\n") - script.WriteString("}\n") - for _, command := range commands { - fmt.Fprintf(&script, "check_command %s\n", guestShellQuote(command)) - } - return script.String() -} - -func guestSessionAttachInputCommand(sessionID string) string { - path := guestSessionStdinPipePath(sessionID) - return "bash -lc " + guestShellQuote(fmt.Sprintf("set -euo pipefail\n[ -p %s ] || mkfifo -m 600 %s\nexec cat > %s\n", guestShellQuote(path), guestShellQuote(path), guestShellQuote(path))) -} - -func guestSessionAttachTailCommand(path string) string { - return "bash -lc " + guestShellQuote(fmt.Sprintf("set -euo pipefail\ntouch %s\nexec tail -n 0 -F %s 2>/dev/null\n", guestShellQuote(path), guestShellQuote(path))) -} - -func guestSessionEnvLines(values map[string]string) []string { - if len(values) == 0 { - return nil - } - keys := make([]string, 0, len(values)) - for key := range values { - keys = append(keys, key) - } - sort.Strings(keys) - lines := make([]string, 0, len(keys)) - for _, key := range keys { - lines = append(lines, "export "+key+"="+guestShellQuote(values[key])) - } - return lines -} - -func guestShellQuote(value string) string { - return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'" -} - -func parseGuestSessionState(raw string) (guestSessionStateSnapshot, error) { - var snapshot guestSessionStateSnapshot - scanner := bufio.NewScanner(strings.NewReader(raw)) - for scanner.Scan() { - line := scanner.Text() - key, value, ok := strings.Cut(line, "=") - if !ok { - continue - } - switch strings.TrimSpace(key) { - case "status": - snapshot.Status = strings.TrimSpace(value) - case "pid": - if pid, err := strconv.Atoi(strings.TrimSpace(value)); err == nil { - snapshot.GuestPID = pid - } - case "exit": - if exitCode, err := strconv.Atoi(strings.TrimSpace(value)); err == nil { - snapshot.ExitCode = &exitCode - } - case "alive": - snapshot.Alive = strings.TrimSpace(value) == "true" - case "error": - snapshot.LastError = strings.TrimSpace(value) - } - } - return snapshot, scanner.Err() -} - -func guestSessionExitCode(err error) (int, bool) { - if err == nil { - return 0, true - } - var exitErr *ssh.ExitError - if errors.As(err, &exitErr) { - return exitErr.ExitStatus(), true - } - return 0, false -} - -func cloneStringMap(values map[string]string) map[string]string { - if len(values) == 0 { - return nil - } - cloned := make(map[string]string, len(values)) - for key, value := range values { - cloned[key] = value - } - return cloned -} - -func tailFileContent(path string, lines int) (string, error) { - data, err := os.ReadFile(path) - if err != nil { - if os.IsNotExist(err) { - return "", nil - } - return "", err - } - if lines <= 0 { - return string(data), nil - } - parts := strings.Split(string(data), "\n") - if len(parts) <= lines { - return string(data), nil - } - return strings.Join(parts[len(parts)-lines-1:], "\n"), nil -} - -func processAlive(pid int) bool { - if pid <= 0 { - return false - } - return syscallKill(pid, syscall.Signal(0)) == nil -} - -var syscallKill = func(pid int, signal os.Signal) error { - proc, err := os.FindProcess(pid) - if err != nil { - return err - } - return proc.Signal(signal) -} - -func formatGuestSessionStepError(action string, err error, log string) error { - log = strings.TrimSpace(log) - if log == "" { - return fmt.Errorf("%s: %w", action, err) - } - return fmt.Errorf("%s: %w: %s", action, err, log) -} - -func guestSessionStateChanged(before, after model.GuestSession) bool { - if before.Status != after.Status || before.GuestPID != after.GuestPID || before.LastError != after.LastError || before.Attachable != after.Attachable || before.Reattachable != after.Reattachable || before.AttachBackend != after.AttachBackend || before.AttachMode != after.AttachMode || before.LaunchStage != after.LaunchStage || before.LaunchMessage != after.LaunchMessage || before.LaunchRawLog != after.LaunchRawLog { - return true - } - if before.StartedAt != after.StartedAt || before.EndedAt != after.EndedAt { - return true - } - switch { - case before.ExitCode == nil && after.ExitCode == nil: - return false - case before.ExitCode == nil || after.ExitCode == nil: - return true - default: - return *before.ExitCode != *after.ExitCode - } -} diff --git a/internal/daemon/guest_sessions_test.go b/internal/daemon/guest_sessions_test.go index fc75367..5ec5e1b 100644 --- a/internal/daemon/guest_sessions_test.go +++ b/internal/daemon/guest_sessions_test.go @@ -13,6 +13,7 @@ import ( "time" "banger/internal/api" + sess "banger/internal/daemon/session" "banger/internal/model" "banger/internal/store" ) @@ -135,7 +136,7 @@ func TestSendToGuestSession_HappyPath(t *testing.T) { t.Fatalf("RunScript call count = %d, want 1", len(fake.ranScripts)) } script := fake.ranScripts[0] - pipePath := guestSessionStdinPipePath(session.ID) + pipePath := sess.StdinPipePath(session.ID) if !strings.Contains(script, "cat ") { t.Fatalf("send script missing cat command: %q", script) } @@ -321,15 +322,15 @@ func testGuestSession(vmID string, stdinMode model.GuestSessionStdinMode, status ID: id, VMID: vmID, Name: vmID + "-sess", - Backend: guestSessionBackendSSH, + Backend: sess.BackendSSH, Command: "pi", Args: []string{"--mode", "rpc"}, CWD: "/root/repo", StdinMode: stdinMode, Status: status, - GuestStateDir: guestSessionStateDir(id), - StdoutLogPath: guestSessionStdoutLogPath(id), - StderrLogPath: guestSessionStderrLogPath(id), + GuestStateDir: sess.StateDir(id), + StdoutLogPath: sess.StdoutLogPath(id), + StderrLogPath: sess.StderrLogPath(id), Attachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning, Reattachable: stdinMode == model.GuestSessionStdinPipe && status == model.GuestSessionStatusRunning, CreatedAt: now, @@ -355,7 +356,7 @@ func startFakeFirecracker(t *testing.T, apiSock string) *exec.Cmd { func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) { t.Parallel() - cwdScript := guestSessionCWDPreflightScript("/root/repo") + cwdScript := sess.CWDPreflightScript("/root/repo") if strings.Contains(cwdScript, `\n`) { t.Fatalf("cwd preflight script still contains escaped newline literals: %q", cwdScript) } @@ -363,7 +364,7 @@ func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) { t.Fatalf("cwd preflight script should contain real newlines: %q", cwdScript) } - commandScript := guestSessionCommandPreflightScript([]string{"git", "pi"}) + commandScript := sess.CommandPreflightScript([]string{"git", "pi"}) if strings.Contains(commandScript, `\n`) { t.Fatalf("command preflight script still contains escaped newline literals: %q", commandScript) } @@ -371,12 +372,12 @@ func TestGuestSessionPreflightScriptsUseRealNewlines(t *testing.T) { t.Fatalf("command preflight script should contain real newlines: %q", commandScript) } - attachInput := guestSessionAttachInputCommand("session-id") + attachInput := sess.AttachInputCommand("session-id") if strings.Contains(attachInput, `\n`) { t.Fatalf("attach input command still contains escaped newline literals: %q", attachInput) } - attachTail := guestSessionAttachTailCommand("/tmp/stdout.log") + attachTail := sess.AttachTailCommand("/tmp/stdout.log") if strings.Contains(attachTail, `\n`) { t.Fatalf("attach tail command still contains escaped newline literals: %q", attachTail) } diff --git a/internal/daemon/session/session.go b/internal/daemon/session/session.go new file mode 100644 index 0000000..4407520 --- /dev/null +++ b/internal/daemon/session/session.go @@ -0,0 +1,509 @@ +// Package session contains the pure helpers of the guest-session subsystem: +// bash script generators, on-guest state path helpers, state snapshot +// parsing, and small utilities like ShellQuote and FormatStepError. +// +// The orchestrator methods (StartGuestSession, BeginGuestSessionAttach, +// etc.) stay on *daemon.Daemon and compose these helpers. +package session + +import ( + "bufio" + "errors" + "fmt" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "syscall" + + "banger/internal/model" + "banger/internal/system" + + "golang.org/x/crypto/ssh" +) + +// Constants shared between orchestration and helpers. +const ( + BackendSSH = "ssh" + AttachBackendNone = "none" + AttachBackendSSHBridge = "ssh_rehydratable" + AttachModeExclusive = "exclusive" + TransportUnixSocket = "unix_socket" + StateRoot = "/root/.local/state/banger/sessions" + LogTailLineDefault = 200 +) + +// StateSnapshot is the decoded per-session state as read from the guest. +type StateSnapshot struct { + Status string + GuestPID int + ExitCode *int + Alive bool + LastError string +} + +// -- Guest filesystem paths ------------------------------------------------- + +func StateDir(id string) string { + return filepath.ToSlash(filepath.Join(StateRoot, id)) +} + +func RelativeStateDir(id string) string { + return strings.TrimPrefix(StateDir(id), "/root/") +} + +func ScriptPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "run.sh")) } +func PIDPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "pid")) } +func MonitorPIDPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "monitor_pid")) } +func ExitCodePath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "exit_code")) } +func StdinPipePath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "stdin.pipe")) } +func StdinKeepalivePIDPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "stdin_keepalive.pid")) } +func StatusPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "status")) } +func ErrorPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "error")) } +func StdoutLogPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "stdout.log")) } +func StderrLogPath(id string) string { return filepath.ToSlash(filepath.Join(StateDir(id), "stderr.log")) } + +// -- Script generators ------------------------------------------------------ + +// Script returns the bash runner installed into the guest for session. It +// sets up state/log paths, optional stdin fifo, and wait-loop around the +// user command. +func Script(sess model.GuestSession) string { + var script strings.Builder + script.WriteString("set -euo pipefail\n") + fmt.Fprintf(&script, "STATE_DIR=%s\n", ShellQuote(sess.GuestStateDir)) + fmt.Fprintf(&script, "STDOUT_LOG=%s\n", ShellQuote(sess.StdoutLogPath)) + fmt.Fprintf(&script, "STDERR_LOG=%s\n", ShellQuote(sess.StderrLogPath)) + fmt.Fprintf(&script, "PID_FILE=%s\n", ShellQuote(PIDPath(sess.ID))) + fmt.Fprintf(&script, "MONITOR_PID_FILE=%s\n", ShellQuote(MonitorPIDPath(sess.ID))) + fmt.Fprintf(&script, "EXIT_FILE=%s\n", ShellQuote(ExitCodePath(sess.ID))) + fmt.Fprintf(&script, "STATUS_FILE=%s\n", ShellQuote(StatusPath(sess.ID))) + fmt.Fprintf(&script, "ERROR_FILE=%s\n", ShellQuote(ErrorPath(sess.ID))) + fmt.Fprintf(&script, "STDIN_PIPE=%s\n", ShellQuote(StdinPipePath(sess.ID))) + fmt.Fprintf(&script, "STDIN_KEEPALIVE_PID_FILE=%s\n", ShellQuote(StdinKeepalivePIDPath(sess.ID))) + fmt.Fprintf(&script, "SESSION_CWD=%s\n", ShellQuote(DefaultCWD(sess.CWD))) + script.WriteString("mkdir -p \"$STATE_DIR\"\n") + script.WriteString(": >\"$STDOUT_LOG\"\n") + script.WriteString(": >\"$STDERR_LOG\"\n") + script.WriteString("rm -f \"$EXIT_FILE\" \"$ERROR_FILE\" \"$STDIN_KEEPALIVE_PID_FILE\"\n") + if sess.StdinMode == model.GuestSessionStdinPipe { + script.WriteString("rm -f \"$STDIN_PIPE\"\n") + script.WriteString("mkfifo -m 600 \"$STDIN_PIPE\"\n") + } + script.WriteString("printf '%s\\n' \"${BASHPID:-$$}\" >\"$MONITOR_PID_FILE\"\n") + script.WriteString("printf 'starting\\n' >\"$STATUS_FILE\"\n") + script.WriteString("cd \"$SESSION_CWD\"\n") + script.WriteString("exec > >(tee -a \"$STDOUT_LOG\") 2> >(tee -a \"$STDERR_LOG\" >&2)\n") + for _, line := range EnvLines(sess.Env) { + script.WriteString(line) + script.WriteByte('\n') + } + script.WriteString("COMMAND=(") + for _, value := range append([]string{sess.Command}, sess.Args...) { + script.WriteByte(' ') + script.WriteString(ShellQuote(value)) + } + script.WriteString(" )\n") + if sess.StdinMode == model.GuestSessionStdinPipe { + script.WriteString("( while :; do sleep 3600; done ) >\"$STDIN_PIPE\" &\n") + script.WriteString("keepalive=$!\n") + script.WriteString("printf '%s\\n' \"$keepalive\" >\"$STDIN_KEEPALIVE_PID_FILE\"\n") + script.WriteString("\"${COMMAND[@]}\" <\"$STDIN_PIPE\" &\n") + } else { + script.WriteString("\"${COMMAND[@]}\" &\n") + } + script.WriteString("child=$!\n") + script.WriteString("printf '%s\\n' \"$child\" >\"$PID_FILE\"\n") + script.WriteString("printf 'running\\n' >\"$STATUS_FILE\"\n") + script.WriteString("wait \"$child\"\n") + script.WriteString("rc=$?\n") + if sess.StdinMode == model.GuestSessionStdinPipe { + script.WriteString("if [ -f \"$STDIN_KEEPALIVE_PID_FILE\" ]; then kill \"$(cat \"$STDIN_KEEPALIVE_PID_FILE\")\" 2>/dev/null || true; fi\n") + } + script.WriteString("printf '%s\\n' \"$rc\" >\"$EXIT_FILE\"\n") + script.WriteString("if [ \"$rc\" -eq 0 ]; then printf 'exited\\n' >\"$STATUS_FILE\"; else printf 'failed\\n' >\"$STATUS_FILE\"; fi\n") + script.WriteString("exit \"$rc\"\n") + return script.String() +} + +// InspectScript reads the on-guest state files for sessionID and prints a +// key=value block parseable by ParseState. +func InspectScript(sessionID string) string { + var script strings.Builder + script.WriteString("set -euo pipefail\n") + fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(StateDir(sessionID))) + script.WriteString("status=''\n") + script.WriteString("pid=''\n") + script.WriteString("exit_code=''\n") + script.WriteString("last_error=''\n") + script.WriteString("alive=false\n") + script.WriteString("[ -f \"$DIR/status\" ] && status=\"$(cat \"$DIR/status\")\"\n") + script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n") + script.WriteString("[ -f \"$DIR/exit_code\" ] && exit_code=\"$(cat \"$DIR/exit_code\")\"\n") + script.WriteString("[ -f \"$DIR/error\" ] && last_error=\"$(cat \"$DIR/error\")\"\n") + script.WriteString("if [ -n \"$pid\" ] && kill -0 \"$pid\" 2>/dev/null; then alive=true; fi\n") + script.WriteString("printf 'status=%s\\n' \"$status\"\n") + script.WriteString("printf 'pid=%s\\n' \"$pid\"\n") + script.WriteString("printf 'exit=%s\\n' \"$exit_code\"\n") + script.WriteString("printf 'alive=%s\\n' \"$alive\"\n") + script.WriteString("printf 'error=%s\\n' \"$last_error\"\n") + return script.String() +} + +// SignalScript sends signal to sessionID's runner and monitor processes. +func SignalScript(sessionID, signal string) string { + var script strings.Builder + script.WriteString("set -euo pipefail\n") + fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(StateDir(sessionID))) + fmt.Fprintf(&script, "SIGNAL=%s\n", ShellQuote(signal)) + script.WriteString("pid=''\n") + script.WriteString("monitor=''\n") + script.WriteString("keepalive=''\n") + script.WriteString("[ -f \"$DIR/pid\" ] && pid=\"$(cat \"$DIR/pid\")\"\n") + script.WriteString("[ -f \"$DIR/monitor_pid\" ] && monitor=\"$(cat \"$DIR/monitor_pid\")\"\n") + script.WriteString("[ -f \"$DIR/stdin_keepalive.pid\" ] && keepalive=\"$(cat \"$DIR/stdin_keepalive.pid\")\"\n") + script.WriteString("printf 'stopping\\n' >\"$DIR/status\"\n") + script.WriteString("if [ -n \"$pid\" ]; then kill -${SIGNAL} \"$pid\" 2>/dev/null || true; fi\n") + script.WriteString("if [ -n \"$monitor\" ]; then kill -${SIGNAL} \"$monitor\" 2>/dev/null || true; fi\n") + script.WriteString("if [ -n \"$keepalive\" ]; then kill -${SIGNAL} \"$keepalive\" 2>/dev/null || true; fi\n") + return script.String() +} + +// CWDPreflightScript verifies cwd exists on the guest. +func CWDPreflightScript(cwd string) string { + var script strings.Builder + script.WriteString("set -euo pipefail\n") + fmt.Fprintf(&script, "DIR=%s\n", ShellQuote(DefaultCWD(cwd))) + script.WriteString("if [ ! -d \"$DIR\" ]; then echo \"missing cwd: $DIR\"; exit 1; fi\n") + return script.String() +} + +// CommandPreflightScript verifies each command is resolvable on the guest. +func CommandPreflightScript(commands []string) string { + var script strings.Builder + script.WriteString("set -euo pipefail\n") + script.WriteString("check_command() {\n") + script.WriteString(" cmd=\"$1\"\n") + script.WriteString(" case \"$cmd\" in\n") + script.WriteString(" */*) [ -x \"$cmd\" ] || { echo \"missing command: $cmd\"; exit 1; } ;;\n") + script.WriteString(" *) command -v \"$cmd\" >/dev/null 2>&1 || { echo \"missing command: $cmd\"; exit 1; } ;;\n") + script.WriteString(" esac\n") + script.WriteString("}\n") + for _, command := range commands { + fmt.Fprintf(&script, "check_command %s\n", ShellQuote(command)) + } + return script.String() +} + +// AttachInputCommand returns the guest command that creates/opens the stdin +// fifo for sessionID and cats attach-side bytes into it. +func AttachInputCommand(sessionID string) string { + path := StdinPipePath(sessionID) + return "bash -lc " + ShellQuote(fmt.Sprintf("set -euo pipefail\n[ -p %s ] || mkfifo -m 600 %s\nexec cat > %s\n", ShellQuote(path), ShellQuote(path), ShellQuote(path))) +} + +// AttachTailCommand returns the guest command that tails a log file and +// streams new content back to the attach bridge. +func AttachTailCommand(path string) string { + return "bash -lc " + ShellQuote(fmt.Sprintf("set -euo pipefail\ntouch %s\nexec tail -n 0 -F %s 2>/dev/null\n", ShellQuote(path), ShellQuote(path))) +} + +// EnvLines returns deterministic `export KEY=value` lines for the session +// launcher, ordered by key. +func EnvLines(values map[string]string) []string { + if len(values) == 0 { + return nil + } + keys := make([]string, 0, len(values)) + for key := range values { + keys = append(keys, key) + } + sort.Strings(keys) + lines := make([]string, 0, len(keys)) + for _, key := range keys { + lines = append(lines, "export "+key+"="+ShellQuote(values[key])) + } + return lines +} + +// -- State snapshot helpers ------------------------------------------------- + +// ParseState decodes the key=value output produced by InspectScript. +func ParseState(raw string) (StateSnapshot, error) { + var snapshot StateSnapshot + scanner := bufio.NewScanner(strings.NewReader(raw)) + for scanner.Scan() { + line := scanner.Text() + key, value, ok := strings.Cut(line, "=") + if !ok { + continue + } + switch strings.TrimSpace(key) { + case "status": + snapshot.Status = strings.TrimSpace(value) + case "pid": + if pid, err := strconv.Atoi(strings.TrimSpace(value)); err == nil { + snapshot.GuestPID = pid + } + case "exit": + if exitCode, err := strconv.Atoi(strings.TrimSpace(value)); err == nil { + snapshot.ExitCode = &exitCode + } + case "alive": + snapshot.Alive = strings.TrimSpace(value) == "true" + case "error": + snapshot.LastError = strings.TrimSpace(value) + } + } + return snapshot, scanner.Err() +} + +// InspectStateFromDir reads the state files directly from stateDir (used +// when the guest is offline and we can mount the work disk from the host). +func InspectStateFromDir(stateDir string) (StateSnapshot, error) { + var snapshot StateSnapshot + statusData, _ := os.ReadFile(filepath.Join(stateDir, "status")) + snapshot.Status = strings.TrimSpace(string(statusData)) + pidData, _ := os.ReadFile(filepath.Join(stateDir, "pid")) + if pidValue, err := strconv.Atoi(strings.TrimSpace(string(pidData))); err == nil { + snapshot.GuestPID = pidValue + } + exitData, _ := os.ReadFile(filepath.Join(stateDir, "exit_code")) + if exitValue, err := strconv.Atoi(strings.TrimSpace(string(exitData))); err == nil { + snapshot.ExitCode = &exitValue + } + errorData, _ := os.ReadFile(filepath.Join(stateDir, "error")) + snapshot.LastError = strings.TrimSpace(string(errorData)) + if snapshot.GuestPID != 0 { + snapshot.Alive = ProcessAlive(snapshot.GuestPID) + } + return snapshot, nil +} + +// ApplyStateSnapshot mutates sess in place to reflect snapshot. vmRunning +// captures whether the VM is currently up so stale in-flight sessions can be +// failed when the VM is gone. +func ApplyStateSnapshot(sess *model.GuestSession, snapshot StateSnapshot, vmRunning bool) { + if sess == nil { + return + } + if snapshot.GuestPID != 0 { + sess.GuestPID = snapshot.GuestPID + } + if snapshot.LastError != "" { + sess.LastError = snapshot.LastError + } + if snapshot.ExitCode != nil { + sess.ExitCode = snapshot.ExitCode + sess.Attachable = false + sess.Reattachable = false + if sess.StartedAt.IsZero() { + sess.StartedAt = model.Now() + } + if sess.EndedAt.IsZero() { + sess.EndedAt = model.Now() + } + if *snapshot.ExitCode == 0 { + sess.Status = model.GuestSessionStatusExited + } else { + sess.Status = model.GuestSessionStatusFailed + } + return + } + if snapshot.Alive { + if sess.StartedAt.IsZero() { + sess.StartedAt = model.Now() + } + sess.Status = model.GuestSessionStatusRunning + return + } + if !vmRunning && (sess.Status == model.GuestSessionStatusStarting || sess.Status == model.GuestSessionStatusRunning || sess.Status == model.GuestSessionStatusStopping) { + sess.Status = model.GuestSessionStatusFailed + sess.Attachable = false + sess.Reattachable = false + if sess.LastError == "" { + sess.LastError = "vm is not running" + } + if sess.EndedAt.IsZero() { + sess.EndedAt = model.Now() + } + return + } + if snapshot.Status == string(model.GuestSessionStatusRunning) { + if sess.StartedAt.IsZero() { + sess.StartedAt = model.Now() + } + sess.Status = model.GuestSessionStatusRunning + } + if sess.Status == model.GuestSessionStatusRunning && sess.StdinMode == model.GuestSessionStdinPipe { + sess.Attachable = true + sess.Reattachable = true + if sess.AttachBackend == "" { + sess.AttachBackend = AttachBackendSSHBridge + } + if sess.AttachMode == "" { + sess.AttachMode = AttachModeExclusive + } + } +} + +// StateChanged reports whether any materially observable field differs +// between before and after, guiding whether to persist an update. +func StateChanged(before, after model.GuestSession) bool { + if before.Status != after.Status || before.GuestPID != after.GuestPID || before.LastError != after.LastError || before.Attachable != after.Attachable || before.Reattachable != after.Reattachable || before.AttachBackend != after.AttachBackend || before.AttachMode != after.AttachMode || before.LaunchStage != after.LaunchStage || before.LaunchMessage != after.LaunchMessage || before.LaunchRawLog != after.LaunchRawLog { + return true + } + if before.StartedAt != after.StartedAt || before.EndedAt != after.EndedAt { + return true + } + switch { + case before.ExitCode == nil && after.ExitCode == nil: + return false + case before.ExitCode == nil || after.ExitCode == nil: + return true + default: + return *before.ExitCode != *after.ExitCode + } +} + +// -- Launch helpers --------------------------------------------------------- + +// DefaultName returns a friendly session name: caller-provided if non-empty, +// otherwise `-`. +func DefaultName(id, command, explicit string) string { + if trimmed := strings.TrimSpace(explicit); trimmed != "" { + return trimmed + } + base := filepath.Base(strings.TrimSpace(command)) + if base == "." || base == string(filepath.Separator) || base == "" { + base = "session" + } + return base + "-" + system.ShortID(id) +} + +// DefaultCWD returns value if non-empty, else /root. +func DefaultCWD(value string) string { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + return "/root" +} + +// FailLaunch annotates sess as launch-failed with stage/message/raw log and +// returns it for persistence. +func FailLaunch(sess model.GuestSession, stage, message, rawLog string) model.GuestSession { + now := model.Now() + sess.Status = model.GuestSessionStatusFailed + sess.LastError = strings.TrimSpace(message) + sess.Attachable = false + sess.Reattachable = false + sess.LaunchStage = strings.TrimSpace(stage) + sess.LaunchMessage = strings.TrimSpace(message) + sess.LaunchRawLog = strings.TrimSpace(rawLog) + sess.UpdatedAt = now + sess.EndedAt = now + return sess +} + +// NormalizeRequiredCommands returns a de-duplicated, order-preserving list +// of required commands, with the session command first. +func NormalizeRequiredCommands(command string, extras []string) []string { + ordered := make([]string, 0, len(extras)+1) + seen := map[string]struct{}{} + appendValue := func(value string) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return + } + if _, ok := seen[trimmed]; ok { + return + } + seen[trimmed] = struct{}{} + ordered = append(ordered, trimmed) + } + appendValue(command) + for _, extra := range extras { + appendValue(extra) + } + return ordered +} + +// -- Small utilities -------------------------------------------------------- + +// ShellQuote returns value single-quoted for bash, escaping embedded quotes. +func ShellQuote(value string) string { + return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'" +} + +// ExitCode extracts the exit status from an ssh.ExitError, returning +// (0, true) for nil errors. +func ExitCode(err error) (int, bool) { + if err == nil { + return 0, true + } + var exitErr *ssh.ExitError + if errors.As(err, &exitErr) { + return exitErr.ExitStatus(), true + } + return 0, false +} + +// CloneStringMap returns a shallow copy of values, or nil if empty. +func CloneStringMap(values map[string]string) map[string]string { + if len(values) == 0 { + return nil + } + cloned := make(map[string]string, len(values)) + for key, value := range values { + cloned[key] = value + } + return cloned +} + +// TailFileContent returns the last N lines of a file, or "" if the file is +// missing. +func TailFileContent(path string, lines int) (string, error) { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return "", nil + } + return "", err + } + if lines <= 0 { + return string(data), nil + } + parts := strings.Split(string(data), "\n") + if len(parts) <= lines { + return string(data), nil + } + return strings.Join(parts[len(parts)-lines-1:], "\n"), nil +} + +// ProcessAlive returns true if the process with pid exists. The syscallKill +// override is exposed for tests that need to simulate alive/dead processes. +func ProcessAlive(pid int) bool { + if pid <= 0 { + return false + } + return syscallKill(pid, syscall.Signal(0)) == nil +} + +// syscallKill is a test seam: tests replace it to stub process-alive checks. +var syscallKill = func(pid int, signal os.Signal) error { + proc, err := os.FindProcess(pid) + if err != nil { + return err + } + return proc.Signal(signal) +} + +// FormatStepError wraps err with an action label and trimmed on-guest log. +func FormatStepError(action string, err error, log string) error { + log = strings.TrimSpace(log) + if log == "" { + return fmt.Errorf("%s: %w", action, err) + } + return fmt.Errorf("%s: %w: %s", action, err, log) +} diff --git a/internal/daemon/session_attach.go b/internal/daemon/session_attach.go index 5a3c4a0..f5301ee 100644 --- a/internal/daemon/session_attach.go +++ b/internal/daemon/session_attach.go @@ -11,6 +11,7 @@ import ( "time" "banger/internal/api" + sess "banger/internal/daemon/session" "banger/internal/guest" "banger/internal/model" "banger/internal/sessionstream" @@ -56,7 +57,7 @@ func (d *Daemon) BeginGuestSessionAttach(ctx context.Context, params api.GuestSe return api.GuestSessionAttachBeginResult{ Session: session, AttachID: attachID, - TransportKind: guestSessionTransportUnixSocket, + TransportKind: sess.TransportUnixSocket, TransportTarget: socketPath, SocketPath: socketPath, StreamFormat: sessionstream.FormatV1, @@ -86,7 +87,7 @@ func (d *Daemon) waitForGuestSessionExit(id string, controller *guestSessionCont now := model.Now() updated.UpdatedAt = now updated.EndedAt = now - if exitCode, ok := guestSessionExitCode(err); ok { + if exitCode, ok := sess.ExitCode(err); ok { updated.ExitCode = &exitCode if exitCode == 0 { updated.Status = model.GuestSessionStatusExited @@ -165,16 +166,16 @@ func (d *Daemon) attachGuestSessionBridge(session model.GuestSession, controller return fmt.Errorf("vm %q is not running", vm.Name) } address := net.JoinHostPort(vm.Runtime.GuestIP, "22") - stdinStream, err := d.openGuestSessionAttachStream(address, guestSessionAttachInputCommand(session.ID)) + stdinStream, err := d.openGuestSessionAttachStream(address, sess.AttachInputCommand(session.ID)) if err != nil { return fmt.Errorf("open guest session stdin stream: %w", err) } - stdoutStream, err := d.openGuestSessionAttachStream(address, guestSessionAttachTailCommand(session.StdoutLogPath)) + stdoutStream, err := d.openGuestSessionAttachStream(address, sess.AttachTailCommand(session.StdoutLogPath)) if err != nil { _ = stdinStream.Close() return fmt.Errorf("open guest session stdout stream: %w", err) } - stderrStream, err := d.openGuestSessionAttachStream(address, guestSessionAttachTailCommand(session.StderrLogPath)) + stderrStream, err := d.openGuestSessionAttachStream(address, sess.AttachTailCommand(session.StderrLogPath)) if err != nil { _ = stdinStream.Close() _ = stdoutStream.Close() diff --git a/internal/daemon/session_controller.go b/internal/daemon/session_controller.go index 8f45a36..1736f7b 100644 --- a/internal/daemon/session_controller.go +++ b/internal/daemon/session_controller.go @@ -98,14 +98,6 @@ func (c *guestSessionController) close() error { return err } -type guestSessionStateSnapshot struct { - Status string - GuestPID int - ExitCode *int - Alive bool - LastError string -} - // sessionRegistry owns the live guest-session controller map. Its lock is // independent of Daemon.mu so guest-session lookups do not contend with // unrelated daemon state. diff --git a/internal/daemon/session_lifecycle.go b/internal/daemon/session_lifecycle.go index 3ca56b4..b22d9e2 100644 --- a/internal/daemon/session_lifecycle.go +++ b/internal/daemon/session_lifecycle.go @@ -10,6 +10,7 @@ import ( "time" "banger/internal/api" + sess "banger/internal/daemon/session" "banger/internal/guest" "banger/internal/model" "banger/internal/system" @@ -50,34 +51,34 @@ func (d *Daemon) startGuestSessionLocked(ctx context.Context, vm model.VMRecord, session := model.GuestSession{ ID: id, VMID: vm.ID, - Name: defaultGuestSessionName(id, params.Command, params.Name), - Backend: guestSessionBackendSSH, + Name: sess.DefaultName(id, params.Command, params.Name), + Backend: sess.BackendSSH, Command: params.Command, Args: append([]string(nil), params.Args...), CWD: strings.TrimSpace(params.CWD), - Env: cloneStringMap(params.Env), + Env: sess.CloneStringMap(params.Env), StdinMode: stdinMode, Status: model.GuestSessionStatusStarting, - GuestStateDir: guestSessionStateDir(id), - StdoutLogPath: guestSessionStdoutLogPath(id), - StderrLogPath: guestSessionStderrLogPath(id), - Tags: cloneStringMap(params.Tags), + GuestStateDir: sess.StateDir(id), + StdoutLogPath: sess.StdoutLogPath(id), + StderrLogPath: sess.StderrLogPath(id), + Tags: sess.CloneStringMap(params.Tags), Attachable: stdinMode == model.GuestSessionStdinPipe, Reattachable: stdinMode == model.GuestSessionStdinPipe, CreatedAt: now, UpdatedAt: now, } if session.Attachable { - session.AttachBackend = guestSessionAttachBackendSSHBridge - session.AttachMode = guestSessionAttachModeExclusive + session.AttachBackend = sess.AttachBackendSSHBridge + session.AttachMode = sess.AttachModeExclusive } else { - session.AttachBackend = guestSessionAttachBackendNone + session.AttachBackend = sess.AttachBackendNone } if err := d.store.UpsertGuestSession(ctx, session); err != nil { return model.GuestSession{}, err } fail := func(stage, message, rawLog string) (model.GuestSession, error) { - session = failGuestSessionLaunch(session, stage, message, rawLog) + session = sess.FailLaunch(session, stage, message, rawLog) if err := d.store.UpsertGuestSession(ctx, session); err != nil { return model.GuestSession{}, err } @@ -93,20 +94,20 @@ func (d *Daemon) startGuestSessionLocked(ctx context.Context, vm model.VMRecord, } defer client.Close() var preflightLog bytes.Buffer - if err := client.RunScript(ctx, guestSessionCWDPreflightScript(session.CWD), &preflightLog); err != nil { - return fail("preflight_cwd", fmt.Sprintf("guest working directory is unavailable: %s", defaultGuestSessionCWD(session.CWD)), preflightLog.String()) + if err := client.RunScript(ctx, sess.CWDPreflightScript(session.CWD), &preflightLog); err != nil { + return fail("preflight_cwd", fmt.Sprintf("guest working directory is unavailable: %s", sess.DefaultCWD(session.CWD)), preflightLog.String()) } preflightLog.Reset() - requiredCommands := normalizeGuestSessionRequiredCommands(params.Command, params.RequiredCommands) - if err := client.RunScript(ctx, guestSessionCommandPreflightScript(requiredCommands), &preflightLog); err != nil { + requiredCommands := sess.NormalizeRequiredCommands(params.Command, params.RequiredCommands) + if err := client.RunScript(ctx, sess.CommandPreflightScript(requiredCommands), &preflightLog); err != nil { return fail("preflight_command", fmt.Sprintf("required guest command is unavailable: %s", strings.TrimSpace(preflightLog.String())), preflightLog.String()) } var uploadLog bytes.Buffer - if err := client.UploadFile(ctx, guestSessionScriptPath(id), 0o755, []byte(guestSessionScript(session)), &uploadLog); err != nil { + if err := client.UploadFile(ctx, sess.ScriptPath(id), 0o755, []byte(sess.Script(session)), &uploadLog); err != nil { return fail("upload_script", "upload guest session script failed", uploadLog.String()) } var launchLog bytes.Buffer - launchScript := fmt.Sprintf("set -euo pipefail\nnohup bash %s >/dev/null 2>&1 /dev/null 2>&1 > %s\nrm -f %s\n", - guestShellQuote(tmpPath), - guestShellQuote(guestSessionStdinPipePath(session.ID)), - guestShellQuote(tmpPath), + sess.ShellQuote(tmpPath), + sess.ShellQuote(sess.StdinPipePath(session.ID)), + sess.ShellQuote(tmpPath), ) var sendLog bytes.Buffer if err := client.RunScript(ctx, sendScript, &sendLog); err != nil { @@ -99,9 +100,9 @@ func (d *Daemon) readGuestSessionLog(ctx context.Context, vm model.VMRecord, ses path = session.StderrLogPath } var output bytes.Buffer - script := fmt.Sprintf("set -euo pipefail\nif [ -f %s ]; then tail -n %d %s; fi\n", guestShellQuote(path), tailLines, guestShellQuote(path)) + script := fmt.Sprintf("set -euo pipefail\nif [ -f %s ]; then tail -n %d %s; fi\n", sess.ShellQuote(path), tailLines, sess.ShellQuote(path)) if err := client.RunScript(ctx, script, &output); err != nil { - return "", formatGuestSessionStepError("read guest session log", err, output.String()) + return "", sess.FormatStepError("read guest session log", err, output.String()) } return output.String(), nil } @@ -114,6 +115,6 @@ func (d *Daemon) readGuestSessionLog(ctx context.Context, vm model.VMRecord, ses return "", err } defer cleanup() - logPath := filepath.Join(workMount, guestSessionRelativeStateDir(session.ID), stream+".log") - return tailFileContent(logPath, tailLines) + logPath := filepath.Join(workMount, sess.RelativeStateDir(session.ID), stream+".log") + return sess.TailFileContent(logPath, tailLines) } diff --git a/internal/daemon/workspace.go b/internal/daemon/workspace.go index 6510799..f19f963 100644 --- a/internal/daemon/workspace.go +++ b/internal/daemon/workspace.go @@ -14,6 +14,7 @@ import ( "time" "banger/internal/api" + sess "banger/internal/daemon/session" "banger/internal/model" "banger/internal/system" ) @@ -68,8 +69,8 @@ func (d *Daemon) ExportVMWorkspace(ctx context.Context, params api.WorkspaceExpo // past diffRef) and any additional uncommitted changes on top. patchScript := fmt.Sprintf( "set -euo pipefail\ncd %s\ngit add -A\ngit diff --cached %s --binary\n", - guestShellQuote(guestPath), - guestShellQuote(diffRef), + sess.ShellQuote(guestPath), + sess.ShellQuote(diffRef), ) patch, err := client.RunScriptOutput(ctx, patchScript) if err != nil { @@ -79,8 +80,8 @@ func (d *Daemon) ExportVMWorkspace(ctx context.Context, params api.WorkspaceExpo // Enumerate changed paths (index already staged; this is a cheap read). namesScript := fmt.Sprintf( "set -euo pipefail\ncd %s\ngit diff --cached %s --name-only\n", - guestShellQuote(guestPath), - guestShellQuote(diffRef), + sess.ShellQuote(guestPath), + sess.ShellQuote(diffRef), ) namesOut, _ := client.RunScriptOutput(ctx, namesScript) var changed []string @@ -153,9 +154,9 @@ func (d *Daemon) prepareVMWorkspaceLocked(ctx context.Context, vm model.VMRecord } if readOnly { var chmodLog bytes.Buffer - chmodScript := fmt.Sprintf("set -euo pipefail\nchmod -R a-w %s\n", guestShellQuote(guestPath)) + chmodScript := fmt.Sprintf("set -euo pipefail\nchmod -R a-w %s\n", sess.ShellQuote(guestPath)) if err := client.RunScript(ctx, chmodScript, &chmodLog); err != nil { - return model.WorkspacePrepareResult{}, formatGuestSessionStepError("set workspace readonly", err, chmodLog.String()) + return model.WorkspacePrepareResult{}, sess.FormatStepError("set workspace readonly", err, chmodLog.String()) } } return model.WorkspacePrepareResult{ @@ -246,13 +247,13 @@ func importWorkspaceRepoToGuest(ctx context.Context, client guestSSHClient, spec switch mode { case model.WorkspacePrepareModeFullCopy: var copyLog bytes.Buffer - command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", guestShellQuote(guestPath), guestShellQuote(guestPath), guestShellQuote(guestPath)) + command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath), sess.ShellQuote(guestPath), sess.ShellQuote(guestPath)) if err := client.StreamTar(ctx, spec.RepoRoot, command, ©Log); err != nil { - return formatGuestSessionStepError("copy full workspace", err, copyLog.String()) + return sess.FormatStepError("copy full workspace", err, copyLog.String()) } var finalizeLog bytes.Buffer if err := client.RunScript(ctx, workspaceFinalizeScript(spec, guestPath, mode), &finalizeLog); err != nil { - return formatGuestSessionStepError("finalize full workspace", err, finalizeLog.String()) + return sess.FormatStepError("finalize full workspace", err, finalizeLog.String()) } return nil case model.WorkspacePrepareModeMetadataOnly, model.WorkspacePrepareModeShallowOverlay: @@ -262,21 +263,21 @@ func importWorkspaceRepoToGuest(ctx context.Context, client guestSSHClient, spec } defer cleanup() var copyLog bytes.Buffer - command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", guestShellQuote(guestPath), guestShellQuote(guestPath), guestShellQuote(guestPath)) + command := fmt.Sprintf("rm -rf %s && mkdir -p %s && tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath), sess.ShellQuote(guestPath), sess.ShellQuote(guestPath)) if err := client.StreamTar(ctx, repoCopyDir, command, ©Log); err != nil { - return formatGuestSessionStepError("copy guest git metadata", err, copyLog.String()) + return sess.FormatStepError("copy guest git metadata", err, copyLog.String()) } var scriptLog bytes.Buffer if err := client.RunScript(ctx, workspaceFinalizeScript(spec, guestPath, mode), &scriptLog); err != nil { - return formatGuestSessionStepError("prepare guest checkout", err, scriptLog.String()) + return sess.FormatStepError("prepare guest checkout", err, scriptLog.String()) } if mode == model.WorkspacePrepareModeMetadataOnly { return nil } var overlayLog bytes.Buffer - command = fmt.Sprintf("tar -o -C %s --strip-components=1 -xf -", guestShellQuote(guestPath)) + command = fmt.Sprintf("tar -o -C %s --strip-components=1 -xf -", sess.ShellQuote(guestPath)) if err := client.StreamTarEntries(ctx, spec.RepoRoot, spec.OverlayPaths, command, &overlayLog); err != nil { - return formatGuestSessionStepError("overlay workspace working tree", err, overlayLog.String()) + return sess.FormatStepError("overlay workspace working tree", err, overlayLog.String()) } return nil default: @@ -287,22 +288,22 @@ func importWorkspaceRepoToGuest(ctx context.Context, client guestSSHClient, spec func workspaceFinalizeScript(spec workspaceRepoSpec, guestPath string, mode model.WorkspacePrepareMode) string { var script strings.Builder script.WriteString("set -euo pipefail\n") - fmt.Fprintf(&script, "DIR=%s\n", guestShellQuote(guestPath)) + fmt.Fprintf(&script, "DIR=%s\n", sess.ShellQuote(guestPath)) script.WriteString("git config --global --add safe.directory \"$DIR\"\n") if mode != model.WorkspacePrepareModeFullCopy { script.WriteString("find \"$DIR\" -mindepth 1 -maxdepth 1 ! -name .git -exec rm -rf {} +\n") } switch { case strings.TrimSpace(spec.BranchName) != "": - fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", guestShellQuote(spec.BranchName), guestShellQuote(spec.BaseCommit)) + fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", sess.ShellQuote(spec.BranchName), sess.ShellQuote(spec.BaseCommit)) case strings.TrimSpace(spec.CurrentBranch) != "": - fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", guestShellQuote(spec.CurrentBranch), guestShellQuote(spec.HeadCommit)) + fmt.Fprintf(&script, "git -C \"$DIR\" checkout -B %s %s\n", sess.ShellQuote(spec.CurrentBranch), sess.ShellQuote(spec.HeadCommit)) default: - fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", guestShellQuote(spec.HeadCommit)) + fmt.Fprintf(&script, "git -C \"$DIR\" checkout --detach %s\n", sess.ShellQuote(spec.HeadCommit)) } if strings.TrimSpace(spec.GitUserName) != "" && strings.TrimSpace(spec.GitUserEmail) != "" { - fmt.Fprintf(&script, "git -C \"$DIR\" config user.name %s\n", guestShellQuote(spec.GitUserName)) - fmt.Fprintf(&script, "git -C \"$DIR\" config user.email %s\n", guestShellQuote(spec.GitUserEmail)) + fmt.Fprintf(&script, "git -C \"$DIR\" config user.name %s\n", sess.ShellQuote(spec.GitUserName)) + fmt.Fprintf(&script, "git -C \"$DIR\" config user.email %s\n", sess.ShellQuote(spec.GitUserEmail)) } return script.String() }