package daemon import ( "context" "errors" "fmt" "io" "net" "os" "path/filepath" "time" "banger/internal/api" "banger/internal/guest" "banger/internal/model" "banger/internal/sessionstream" "banger/internal/system" ) func (d *Daemon) BeginGuestSessionAttach(ctx context.Context, params api.GuestSessionAttachBeginParams) (api.GuestSessionAttachBeginResult, error) { vm, err := d.FindVM(ctx, params.VMIDOrName) if err != nil { return api.GuestSessionAttachBeginResult{}, err } session, err := d.findGuestSession(ctx, vm.ID, params.SessionIDOrName) if err != nil { return api.GuestSessionAttachBeginResult{}, err } session, _ = d.refreshGuestSession(ctx, vm, session) if !session.Attachable { return api.GuestSessionAttachBeginResult{}, errors.New("session is not attachable") } controller := &guestSessionController{} if !d.claimGuestSessionController(session.ID, controller) { return api.GuestSessionAttachBeginResult{}, errors.New("session already has an active attach") } attachID, err := model.NewID() if err != nil { d.clearGuestSessionController(session.ID) return api.GuestSessionAttachBeginResult{}, err } socketPath := filepath.Join(d.layout.RuntimeDir, "guest-session-attach-"+attachID[:12]+".sock") _ = os.Remove(socketPath) listener, err := net.Listen("unix", socketPath) if err != nil { d.clearGuestSessionController(session.ID) return api.GuestSessionAttachBeginResult{}, err } if err := os.Chmod(socketPath, 0o600); err != nil { _ = listener.Close() _ = os.Remove(socketPath) d.clearGuestSessionController(session.ID) return api.GuestSessionAttachBeginResult{}, err } go d.serveGuestSessionAttach(session, controller, attachID, socketPath, listener) return api.GuestSessionAttachBeginResult{ Session: session, AttachID: attachID, TransportKind: guestSessionTransportUnixSocket, TransportTarget: socketPath, SocketPath: socketPath, StreamFormat: sessionstream.FormatV1, }, nil } func (d *Daemon) forwardGuestSessionOutput(_ string, controller *guestSessionController, channel byte, reader io.Reader) { buffer := make([]byte, 32*1024) for { n, err := reader.Read(buffer) if n > 0 { controller.writeFrame(channel, buffer[:n]) } if err != nil { if !errors.Is(err, io.EOF) { controller.writeControl(sessionstream.ControlMessage{Type: "error", Error: err.Error()}) } return } } } func (d *Daemon) waitForGuestSessionExit(id string, controller *guestSessionController, session model.GuestSession) { err := controller.stream.Wait() updated := session updated.Attachable = false now := model.Now() updated.UpdatedAt = now updated.EndedAt = now if exitCode, ok := guestSessionExitCode(err); ok { updated.ExitCode = &exitCode if exitCode == 0 { updated.Status = model.GuestSessionStatusExited } else { updated.Status = model.GuestSessionStatusFailed } } if err != nil && updated.LastError == "" { updated.LastError = err.Error() } if vm, getErr := d.store.GetVMByID(context.Background(), updated.VMID); getErr == nil { if refreshed, refreshErr := d.refreshGuestSession(context.Background(), vm, updated); refreshErr == nil { updated = refreshed } } _ = d.store.UpsertGuestSession(context.Background(), updated) controller.writeControl(sessionstream.ControlMessage{Type: "exit", ExitCode: updated.ExitCode}) _ = controller.close() d.clearGuestSessionController(id) } func (d *Daemon) serveGuestSessionAttach(session model.GuestSession, controller *guestSessionController, _ string, socketPath string, listener net.Listener) { defer func() { _ = listener.Close() _ = os.Remove(socketPath) _ = controller.close() d.clearGuestSessionController(session.ID) }() conn, err := listener.Accept() if err != nil { return } defer conn.Close() if err := controller.setAttach(conn); err != nil { _ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()}) return } defer controller.clearAttach(conn) if err := d.attachGuestSessionBridge(session, controller); err != nil { _ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()}) return } for { channel, payload, err := sessionstream.ReadFrame(conn) if err != nil { return } switch channel { case sessionstream.ChannelStdin: if controller.stdin == nil { continue } if _, err := controller.stdin.Write(payload); err != nil { _ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()}) return } case sessionstream.ChannelControl: message, err := sessionstream.ReadControl(payload) if err != nil { _ = sessionstream.WriteControl(conn, sessionstream.ControlMessage{Type: "error", Error: err.Error()}) return } if message.Type == "eof" && controller.stdin != nil { _ = controller.stdin.Close() } } } } func (d *Daemon) attachGuestSessionBridge(session model.GuestSession, controller *guestSessionController) error { vm, err := d.store.GetVMByID(context.Background(), session.VMID) if err != nil { return err } if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) { return fmt.Errorf("vm %q is not running", vm.Name) } address := net.JoinHostPort(vm.Runtime.GuestIP, "22") stdinStream, err := d.openGuestSessionAttachStream(address, guestSessionAttachInputCommand(session.ID)) if err != nil { return fmt.Errorf("open guest session stdin stream: %w", err) } stdoutStream, err := d.openGuestSessionAttachStream(address, guestSessionAttachTailCommand(session.StdoutLogPath)) if err != nil { _ = stdinStream.Close() return fmt.Errorf("open guest session stdout stream: %w", err) } stderrStream, err := d.openGuestSessionAttachStream(address, guestSessionAttachTailCommand(session.StderrLogPath)) if err != nil { _ = stdinStream.Close() _ = stdoutStream.Close() return fmt.Errorf("open guest session stderr stream: %w", err) } controller.streams = append(controller.streams, stdinStream, stdoutStream, stderrStream) controller.stdin = stdinStream.Stdin() go d.forwardGuestSessionOutput(session.ID, controller, sessionstream.ChannelStdout, stdoutStream.Stdout()) go d.forwardGuestSessionOutput(session.ID, controller, sessionstream.ChannelStderr, stderrStream.Stdout()) go d.watchGuestSessionAttach(session.ID, controller, session) return nil } func (d *Daemon) openGuestSessionAttachStream(address, command string) (*guest.StreamSession, error) { client, err := guest.Dial(context.Background(), address, d.config.SSHKeyPath) if err != nil { return nil, err } stream, err := client.StartCommand(context.Background(), command) if err != nil { _ = client.Close() return nil, err } return stream, nil } func (d *Daemon) watchGuestSessionAttach(id string, controller *guestSessionController, session model.GuestSession) { ticker := time.NewTicker(250 * time.Millisecond) defer ticker.Stop() for range ticker.C { vm, err := d.store.GetVMByID(context.Background(), session.VMID) if err != nil { controller.writeControl(sessionstream.ControlMessage{Type: "error", Error: err.Error()}) _ = controller.close() return } refreshed, err := d.refreshGuestSession(context.Background(), vm, session) if err == nil { session = refreshed } if session.Status == model.GuestSessionStatusExited || session.Status == model.GuestSessionStatusFailed { controller.writeControl(sessionstream.ControlMessage{Type: "exit", ExitCode: session.ExitCode}) _ = controller.close() return } } }