banger/internal/guest/ssh.go
Thales Maciel 37c4c091ec
Add guest sessions and agent VM defaults
Add daemon-backed workspace and guest-session primitives so host
orchestrators can prepare /root/repo, launch long-lived guest commands,
and attach to pipe-mode sessions over the local stdio mux bridge.

Persist richer session metadata and launch diagnostics, preflight guest
cwd/command requirements, make pipe-mode attach rehydratable from guest
state after daemon restart, and allow submodules when workspace prepare
runs in full_copy mode.

At the same time, stop vm run from auto-attaching opencode, make it
print next-step commands instead, and make glibc guest images more
agent-ready by installing node, opencode, claude, and pi while syncing
opencode/claude/pi auth files into work disks on VM start.

Validation:
- GOCACHE=/tmp/banger-gocache go test ./...
- make build
- banger vm workspace prepare --help
- banger vm session --help
- banger vm session start --help
- banger vm session attach --help
2026-04-12 23:48:42 -03:00

400 lines
8.5 KiB
Go

package guest
import (
"archive/tar"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"os"
"path"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
// Client wraps an authenticated SSH connection to a guest. Its methods
// tolerate a nil *Client or a nil underlying connection.
type Client struct {
	// client is the live SSH connection; nil means "not connected".
	client *ssh.Client
}
// StreamSession is a remote command started by Client.StartCommand whose
// standard streams are exposed to the caller. It owns its Client: Close tears
// down both the session and the client connection.
type StreamSession struct {
	// client is closed together with the session.
	client *Client
	// session is the running remote command.
	session *ssh.Session
	// stdin, stdout and stderr are the remote command's standard streams.
	stdin  io.WriteCloser
	stdout io.Reader
	stderr io.Reader
	// waitCh receives the result of session.Wait exactly once and is then closed.
	waitCh chan error
	// closeOnce makes Close idempotent.
	closeOnce sync.Once
}
// WaitForSSH repeatedly dials address with the key at privateKeyPath until an
// SSH connection succeeds or ctx is done. After each failed attempt it waits
// interval before retrying (one second when interval <= 0). The probe
// connection is closed immediately on success.
//
// When ctx ends first, the returned error wraps ctx.Err() (errors.Is against
// context.Canceled / context.DeadlineExceeded keeps working). When a dial
// error was observed before that, its text is appended so a timeout can be
// diagnosed.
func WaitForSSH(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
	if interval <= 0 {
		interval = time.Second
	}
	var lastErr error
	for {
		client, err := Dial(ctx, address, privateKeyPath)
		if err == nil {
			_ = client.Close() // only probing reachability; nothing to report
			return nil
		}
		// Failures caused by ctx itself say nothing about the guest; keep the
		// most recent "real" error instead.
		if ctx.Err() == nil {
			lastErr = err
		}
		retry := time.NewTimer(interval)
		select {
		case <-ctx.Done():
			retry.Stop()
			if lastErr == nil {
				return ctx.Err()
			}
			return fmt.Errorf("%w (last ssh error: %v)", ctx.Err(), lastErr)
		case <-retry.C:
		}
	}
}
// Dial opens an SSH connection to address (host:port) as root, authenticating
// with the private key at privateKeyPath.
//
// Host keys are not verified (ssh.InsecureIgnoreHostKey). NOTE(review):
// presumably acceptable because targets are local guest VMs; confirm before
// reusing this against other hosts.
//
// Both the TCP connect and the SSH handshake are bounded by a 10s timeout and
// by ctx. ssh.ClientConfig.Timeout is only honoured by ssh.Dial, not by
// ssh.NewClientConn. Without an explicit deadline on the raw connection, a
// peer that accepts TCP but never completes the handshake would block the
// caller indefinitely.
func Dial(ctx context.Context, address, privateKeyPath string) (*Client, error) {
	const connectTimeout = 10 * time.Second
	signer, err := privateKeySigner(privateKeyPath)
	if err != nil {
		return nil, err
	}
	config := &ssh.ClientConfig{
		User:            "root",
		Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         connectTimeout,
	}
	dialer := &net.Dialer{Timeout: connectTimeout}
	conn, err := dialer.DialContext(ctx, "tcp", address)
	if err != nil {
		return nil, err
	}
	// Bound the handshake ourselves (see doc comment).
	if err := conn.SetDeadline(time.Now().Add(connectTimeout)); err != nil {
		_ = conn.Close()
		return nil, err
	}
	// Abort an in-flight handshake promptly if ctx is cancelled.
	stopWatch := make(chan struct{})
	watchDone := make(chan struct{})
	go func() {
		defer close(watchDone)
		select {
		case <-ctx.Done():
			// Expire the deadline so blocked handshake I/O fails immediately.
			_ = conn.SetDeadline(time.Now())
		case <-stopWatch:
		}
	}()
	sshConn, chans, reqs, err := ssh.NewClientConn(conn, address, config)
	close(stopWatch)
	<-watchDone // the watcher must not touch conn's deadline after this point
	if err != nil {
		_ = conn.Close()
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
		return nil, err
	}
	// Lift the handshake deadline so the long-lived connection does not time out.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		_ = sshConn.Close()
		return nil, err
	}
	client := ssh.NewClient(sshConn, chans, reqs)
	// If ctx was cancelled around the end of the handshake, the watcher may have
	// poisoned the connection; do not hand out a possibly broken client.
	if ctxErr := ctx.Err(); ctxErr != nil {
		_ = client.Close()
		return nil, ctxErr
	}
	return &Client{client: client}, nil
}
// Close shuts down the underlying SSH connection. Calling it on a nil *Client
// or on one without a connection is a no-op that returns nil.
func (c *Client) Close() error {
	if c != nil && c.client != nil {
		return c.client.Close()
	}
	return nil
}
// RunScript pipes script into `bash -se` on the guest: bash reads the script
// from stdin and stops at the first failing command. Remote stdout and stderr
// go to logWriter when it is non-nil.
func (c *Client) RunScript(ctx context.Context, script string, logWriter io.Writer) error {
	const shell = "bash -se"
	body := strings.NewReader(script)
	return c.runSession(ctx, shell, body, logWriter)
}
// UploadFile writes data to remotePath on the guest using install(1). -D
// creates missing parent directories, and only mode's permission bits are
// applied. The path is shell-quoted before being embedded in the command.
func (c *Client) UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error {
	perm := mode.Perm()
	target := shellQuote(remotePath)
	cmd := fmt.Sprintf("install -D -m %04o /dev/stdin %s", perm, target)
	payload := bytes.NewReader(data)
	return c.runSession(ctx, cmd, payload, logWriter)
}
// StreamTar builds a tar archive of sourceDir (entries prefixed with its base
// name) and streams it into remoteCommand's stdin on the guest. Presumably the
// remote command extracts it (e.g. `tar -x ...`); confirm at call sites.
// Remote output goes to logWriter when non-nil. Session and archive errors are
// joined.
func (c *Client) StreamTar(ctx context.Context, sourceDir, remoteCommand string, logWriter io.Writer) error {
	return c.streamArchive(ctx, remoteCommand, logWriter, func(w io.Writer) error {
		return writeTarArchive(w, sourceDir)
	})
}

// StreamTarEntries is like StreamTar but archives only the listed entries,
// given relative to sourceDir.
func (c *Client) StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error {
	return c.streamArchive(ctx, remoteCommand, logWriter, func(w io.Writer) error {
		return writeTarEntriesArchive(w, sourceDir, entries)
	})
}

// streamArchive runs remoteCommand with stdin fed by writeArchive through an
// in-memory pipe and returns the joined session and archive errors.
func (c *Client) streamArchive(ctx context.Context, remoteCommand string, logWriter io.Writer, writeArchive func(io.Writer) error) error {
	reader, writer := io.Pipe()
	writeErr := make(chan error, 1)
	go func() {
		writeErr <- writeArchive(writer)
		_ = writer.Close()
	}()
	runErr := c.runSession(ctx, remoteCommand, reader, logWriter)
	// If the session ended without draining the pipe (session creation failed,
	// ctx was cancelled, or the remote command exited early), the producer is
	// stuck in Write and <-writeErr would block forever. Closing the read side
	// makes any pending or later Write fail with errStopped. After a complete
	// read the producer has no writes left, so this is harmless.
	errStopped := errors.New("ssh session ended before the archive was fully sent")
	_ = reader.CloseWithError(errStopped)
	tarErr := <-writeErr
	if runErr != nil && errors.Is(tarErr, errStopped) {
		// The session error already explains why the archive was not consumed.
		tarErr = nil
	}
	return errors.Join(runErr, tarErr)
}
// StartCommand launches command in a new session on c and returns immediately
// with a StreamSession exposing its stdin, stdout and stderr.
//
// If ctx is cancelled before the command finishes, both the session and the
// whole client connection are closed. The returned StreamSession owns c:
// closing it also closes c.
func (c *Client) StartCommand(ctx context.Context, command string) (*StreamSession, error) {
	if c == nil || c.client == nil {
		return nil, fmt.Errorf("ssh client is not connected")
	}
	session, err := c.client.NewSession()
	if err != nil {
		return nil, err
	}
	// abort releases the half-built session and reports cause.
	abort := func(cause error) (*StreamSession, error) {
		_ = session.Close()
		return nil, cause
	}

	stdin, err := session.StdinPipe()
	if err != nil {
		return abort(err)
	}
	stdout, err := session.StdoutPipe()
	if err != nil {
		return abort(err)
	}
	stderr, err := session.StderrPipe()
	if err != nil {
		return abort(err)
	}

	// finished is closed once the command is over (or failed to start), which
	// retires the cancellation watcher below.
	finished := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
			_ = session.Close()
			_ = c.client.Close()
		case <-finished:
		}
	}()

	if err := session.Start(command); err != nil {
		close(finished)
		return abort(err)
	}

	result := make(chan error, 1)
	go func() {
		waitErr := session.Wait()
		close(finished)
		result <- waitErr
		close(result)
	}()

	return &StreamSession{
		client:  c,
		session: session,
		stdin:   stdin,
		stdout:  stdout,
		stderr:  stderr,
		waitCh:  result,
	}, nil
}
// Stdin returns the writer connected to the remote command's standard input,
// or nil when s is nil.
func (s *StreamSession) Stdin() io.WriteCloser {
	var w io.WriteCloser
	if s != nil {
		w = s.stdin
	}
	return w
}
// Stdout returns the reader for the remote command's standard output, or nil
// when s is nil.
func (s *StreamSession) Stdout() io.Reader {
	var r io.Reader
	if s != nil {
		r = s.stdout
	}
	return r
}
// Stderr returns the reader for the remote command's standard error, or nil
// when s is nil.
func (s *StreamSession) Stderr() io.Reader {
	var r io.Reader
	if s != nil {
		r = s.stderr
	}
	return r
}
// Wait blocks until the remote command exits and returns its result. The
// result is delivered only once: after it has been consumed the channel is
// closed, so later calls return nil. A nil or unstarted session returns nil.
func (s *StreamSession) Wait() error {
	if s == nil || s.waitCh == nil {
		return nil
	}
	if err, ok := <-s.waitCh; ok {
		return err
	}
	return nil
}
// Close closes the session and then its owning Client, joining their errors.
// Only the first call does any work; subsequent calls return nil.
func (s *StreamSession) Close() error {
	if s == nil {
		return nil
	}
	var result error
	s.closeOnce.Do(func() {
		var sessionErr, clientErr error
		if s.session != nil {
			sessionErr = s.session.Close()
		}
		if s.client != nil {
			clientErr = s.client.Close()
		}
		result = errors.Join(sessionErr, clientErr)
	})
	return result
}
// runSession runs command in a fresh SSH session on c and waits for it. stdin
// (may be nil) feeds the command. Both stdout and stderr are written to
// logWriter when it is non-nil; otherwise the ssh package's default applies.
//
// If ctx is cancelled while the command runs, the whole SSH client is closed,
// not just this session, so Run unblocks even if the remote end is
// unresponsive. c is unusable afterwards. The returned error then wraps both
// ctx.Err() and the session error, so errors.Is(err, context.Canceled) and
// errors.As for the original error both work.
func (c *Client) runSession(ctx context.Context, command string, stdin io.Reader, logWriter io.Writer) error {
	if c == nil || c.client == nil {
		return fmt.Errorf("ssh client is not connected")
	}
	session, err := c.client.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()
	session.Stdin = stdin
	if logWriter != nil {
		session.Stdout = logWriter
		session.Stderr = logWriter
	}
	finished := make(chan struct{})
	defer close(finished) // retires the watcher once Run has returned
	go func() {
		select {
		case <-ctx.Done():
			_ = c.client.Close()
		case <-finished:
		}
	}()
	if err := session.Run(command); err != nil {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return fmt.Errorf("%w: %w", ctxErr, err)
		}
		return err
	}
	return nil
}
// privateKeySigner reads the private key file at path and parses it into an
// ssh.Signer. Passphrase-protected keys are not supported: the parse error is
// returned, wrapped with the path (errors.As still reaches the ssh error type).
func privateKeySigner(path string) (ssh.Signer, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err // *fs.PathError already names the file
	}
	signer, err := ssh.ParsePrivateKey(data)
	if err != nil {
		return nil, fmt.Errorf("parse ssh private key %s: %w", path, err)
	}
	return signer, nil
}
// AuthorizedPublicKey derives the public key from the private key file at
// path and returns it in authorized_keys line format, as produced by
// ssh.MarshalAuthorizedKey.
func AuthorizedPublicKey(path string) ([]byte, error) {
	signer, err := privateKeySigner(path)
	if err != nil {
		return nil, err
	}
	pub := signer.PublicKey()
	return ssh.MarshalAuthorizedKey(pub), nil
}
// AuthorizedPublicKeyFingerprint returns the hex-encoded SHA-256 digest of the
// authorized_keys line for the key at path, with surrounding whitespace
// trimmed. Note this hashes the text line; it is not the `SHA256:...`
// fingerprint printed by ssh-keygen.
func AuthorizedPublicKeyFingerprint(path string) (string, error) {
	line, err := AuthorizedPublicKey(path)
	if err != nil {
		return "", err
	}
	digest := sha256.Sum256(bytes.TrimSpace(line))
	return hex.EncodeToString(digest[:]), nil
}
// shellQuote wraps value in single quotes for POSIX shells. Each embedded
// single quote is emitted as '"'"': close the quote, add a double-quoted
// quote, and reopen.
func shellQuote(value string) string {
	const escapedQuote = `'"'"'`
	pieces := strings.Split(value, "'")
	return "'" + strings.Join(pieces, escapedQuote) + "'"
}
// writeTarArchive writes a tar stream of the tree rooted at sourceDir to dst.
// Every entry name is prefixed with sourceDir's base name (the root itself is
// stored under that bare name) and uses forward slashes. Symlinks are recorded
// as links rather than followed, and only regular files carry content.
//
// The end-of-archive trailer is written on every return path. If the walk
// succeeded, an error from writing the trailer is returned: in that case the
// receiver got a truncated archive.
func writeTarArchive(dst io.Writer, sourceDir string) (retErr error) {
	tw := tar.NewWriter(dst)
	defer func() {
		// Close flushes file padding and writes the two terminating zero blocks.
		if closeErr := tw.Close(); retErr == nil {
			retErr = closeErr
		}
	}()
	sourceDir = filepath.Clean(sourceDir)
	rootName := filepath.Base(sourceDir)
	return filepath.Walk(sourceDir, func(current string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		name := rootName
		if current != sourceDir {
			rel, err := filepath.Rel(sourceDir, current)
			if err != nil {
				return err
			}
			name = filepath.Join(rootName, rel)
		}
		linkTarget := ""
		if info.Mode()&os.ModeSymlink != 0 {
			target, err := os.Readlink(current)
			if err != nil {
				return err
			}
			linkTarget = target
		}
		header, err := tar.FileInfoHeader(info, linkTarget)
		if err != nil {
			return err
		}
		// Tar names are slash-separated regardless of host OS.
		header.Name = filepath.ToSlash(name)
		if err := tw.WriteHeader(header); err != nil {
			return err
		}
		if !info.Mode().IsRegular() {
			return nil
		}
		file, err := os.Open(current)
		if err != nil {
			return err
		}
		defer file.Close() // read-only; runs when this callback returns
		_, err = io.Copy(tw, file)
		return err
	})
}
// writeTarEntriesArchive writes a tar stream to dst containing only the
// listed entries, given relative to sourceDir. Each entry is stored as
// "<base name of sourceDir>/<entry>" with forward slashes.
//
// Entries are trimmed, cleaned, de-duplicated and written in sorted order, and
// empty entries are skipped. ".", ".." and anything starting with "../" are
// rejected before anything else is written. A listed directory contributes
// only its own header, not its contents. Symlinks are stored as links, and
// only regular files carry content.
//
// The end-of-archive trailer is written on every return path. If nothing else
// failed, an error from writing the trailer is returned: in that case the
// receiver got a truncated archive.
func writeTarEntriesArchive(dst io.Writer, sourceDir string, entries []string) (retErr error) {
	tw := tar.NewWriter(dst)
	defer func() {
		// Close flushes file padding and writes the two terminating zero blocks.
		if closeErr := tw.Close(); retErr == nil {
			retErr = closeErr
		}
	}()
	sourceDir = filepath.Clean(sourceDir)
	rootName := filepath.Base(sourceDir)

	uniqueEntries := make([]string, 0, len(entries))
	seen := make(map[string]struct{}, len(entries))
	for _, entry := range entries {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		entry = filepath.Clean(entry)
		if entry == "." || entry == ".." || strings.HasPrefix(entry, ".."+string(filepath.Separator)) {
			return fmt.Errorf("tar entry %q escapes source dir", entry)
		}
		if _, ok := seen[entry]; ok {
			continue
		}
		seen[entry] = struct{}{}
		uniqueEntries = append(uniqueEntries, entry)
	}
	sort.Strings(uniqueEntries)

	for _, entry := range uniqueEntries {
		fullPath := filepath.Join(sourceDir, entry)
		info, err := os.Lstat(fullPath)
		if err != nil {
			return err
		}
		linkTarget := ""
		if info.Mode()&os.ModeSymlink != 0 {
			linkTarget, err = os.Readlink(fullPath)
			if err != nil {
				return err
			}
		}
		header, err := tar.FileInfoHeader(info, linkTarget)
		if err != nil {
			return err
		}
		header.Name = path.Join(rootName, filepath.ToSlash(entry))
		if err := tw.WriteHeader(header); err != nil {
			return err
		}
		if !info.Mode().IsRegular() {
			continue
		}
		file, err := os.Open(fullPath)
		if err != nil {
			return err
		}
		if _, err := io.Copy(tw, file); err != nil {
			_ = file.Close()
			return err
		}
		if err := file.Close(); err != nil {
			return err
		}
	}
	return nil
}