banger/internal/guest/ssh.go
Thales Maciel 2ebc6f99c6
Add repo-backed vm run command
Create a CLI-only banger vm run [path] flow that resolves the enclosing git repository, creates a VM, imports a guest checkout, and launches opencode attach automatically from the host.

Build the guest checkout by bundling git history plus the resolved base and head commits, cloning that bundle in the guest, and overlaying tracked plus untracked non-ignored files over SSH so local working-tree changes carry over. Support guest-only branch creation with --branch and --from, reject bare repos and submodules, and add selective tar helpers plus CLI seams to keep the workflow testable.

Validate with `go test ./...`, `make build`, `banger vm run --help`, and the expected "--from requires --branch" error path.
2026-03-21 23:34:20 -03:00

279 lines
6.4 KiB
Go

package guest
import (
"archive/tar"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"os"
"path"
"path/filepath"
"sort"
"strings"
"time"
"golang.org/x/crypto/ssh"
)
// Client wraps an authenticated SSH connection to a guest, as returned by
// Dial. Close and the session helpers tolerate a nil *Client or one whose
// connection field is nil.
type Client struct {
// client is the underlying SSH connection shared by every session this
// Client opens.
client *ssh.Client
}
// WaitForSSH repeatedly calls Dial against address until a connection
// authenticated with the key at privateKeyPath succeeds, waiting interval
// between failed attempts (one second when interval is not positive). The
// probe connection is closed immediately; callers dial their own.
//
// If ctx ends first, the returned error wraps ctx.Err(), so errors.Is still
// matches context.DeadlineExceeded / context.Canceled. It also reports the
// most recent dial error, which is usually the real reason the guest was
// unreachable (wrong key, refused connection, handshake failure, ...).
func WaitForSSH(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
	if interval <= 0 {
		interval = time.Second
	}
	for {
		client, err := Dial(ctx, address, privateKeyPath)
		if err == nil {
			_ = client.Close() // probe connection only; its close error is irrelevant
			return nil
		}
		select {
		case <-ctx.Done():
			return fmt.Errorf("wait for ssh on %s: %w (last dial error: %v)", address, ctx.Err(), err)
		case <-time.After(interval):
		}
	}
}
// Dial connects to address over TCP and authenticates as root with the
// private key stored at privateKeyPath, returning a ready Client.
//
// Timeouts:
//   - The TCP connect is bounded by dialTimeout and by ctx.
//   - The SSH handshake is bounded by dialTimeout, or by ctx's deadline when
//     that is earlier. A ctx cancellation without a deadline is not observed
//     during the handshake; a stalled handshake then ends when dialTimeout
//     expires.
func Dial(ctx context.Context, address, privateKeyPath string) (*Client, error) {
	const dialTimeout = 10 * time.Second
	signer, err := privateKeySigner(privateKeyPath)
	if err != nil {
		return nil, err
	}
	// NOTE(review): host keys are not verified. This is presumably acceptable
	// because guests are freshly created local VMs whose keys change each
	// time; confirm.
	config := &ssh.ClientConfig{
		User:            "root",
		Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         dialTimeout,
	}
	dialer := &net.Dialer{Timeout: dialTimeout}
	conn, err := dialer.DialContext(ctx, "tcp", address)
	if err != nil {
		return nil, err
	}
	// ClientConfig.Timeout is only honoured by ssh.Dial. NewClientConn runs the
	// handshake on conn with no limit of its own, so a guest that accepts TCP
	// but never completes the handshake would otherwise block here forever
	// (and WaitForSSH with it).
	deadline := time.Now().Add(dialTimeout)
	if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
		deadline = ctxDeadline
	}
	if err := conn.SetDeadline(deadline); err != nil {
		_ = conn.Close()
		return nil, err
	}
	sshConn, chans, reqs, err := ssh.NewClientConn(conn, address, config)
	if err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("ssh handshake with %s: %w", address, err)
	}
	client := ssh.NewClient(sshConn, chans, reqs)
	// Lift the handshake deadline so established sessions are not cut off.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		_ = client.Close()
		return nil, err
	}
	return &Client{client: client}, nil
}
// Close tears down the underlying SSH connection. Calling it on a nil
// *Client, or on one without a connection, is a no-op that reports nil.
func (c *Client) Close() error {
	if c != nil && c.client != nil {
		return c.client.Close()
	}
	return nil
}
// RunScript feeds script on stdin to `bash -se` on the guest. The -s flag
// makes bash read commands from stdin, and -e makes it stop at the first
// failing command. Combined stdout/stderr goes to logWriter when it is non-nil.
func (c *Client) RunScript(ctx context.Context, script string, logWriter io.Writer) error {
	const shell = "bash -se"
	stdin := strings.NewReader(script)
	return c.runSession(ctx, shell, stdin, logWriter)
}
// UploadFile writes data to remotePath on the guest using `install -D`, which
// creates missing parent directories. Only the permission bits of mode
// (mode.Perm()) are applied; type and special bits are ignored.
func (c *Client) UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error {
	perm := mode.Perm()
	target := shellQuote(remotePath)
	cmd := fmt.Sprintf("install -D -m %04o /dev/stdin %s", perm, target)
	payload := bytes.NewReader(data)
	return c.runSession(ctx, cmd, payload, logWriter)
}
// StreamTar archives sourceDir with writeTarArchive and pipes the archive into
// the stdin of remoteCommand on the guest. Errors from the remote session and
// from building the archive are both reported (joined).
func (c *Client) StreamTar(ctx context.Context, sourceDir, remoteCommand string, logWriter io.Writer) error {
	return c.streamTarArchive(ctx, remoteCommand, logWriter, func(w io.Writer) error {
		return writeTarArchive(w, sourceDir)
	})
}

// StreamTarEntries is like StreamTar, but the archive holds only the given
// entries (paths relative to sourceDir), as written by writeTarEntriesArchive.
func (c *Client) StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error {
	return c.streamTarArchive(ctx, remoteCommand, logWriter, func(w io.Writer) error {
		return writeTarEntriesArchive(w, sourceDir, entries)
	})
}

// streamTarArchive runs remoteCommand with the read side of an in-memory pipe
// as its stdin while write produces the archive into the write side on a
// separate goroutine. It returns the session error and the archive error
// joined.
func (c *Client) streamTarArchive(ctx context.Context, remoteCommand string, logWriter io.Writer, write func(io.Writer) error) error {
	reader, writer := io.Pipe()
	writeErr := make(chan error, 1)
	go func() {
		err := write(writer)
		_ = writer.Close()
		writeErr <- err
	}()
	runErr := c.runSession(ctx, remoteCommand, reader, logWriter)
	// The session can stop reading stdin before the archive is complete:
	// session setup may fail, the remote command may exit early, or ctx may be
	// cancelled. Closing the read side makes any pending or later Write fail
	// with io.ErrClosedPipe. Without it the writer goroutine would block
	// forever and the receive below would deadlock.
	_ = reader.Close()
	tarErr := <-writeErr
	if runErr != nil && errors.Is(tarErr, io.ErrClosedPipe) {
		// Only a symptom of the session failure that is already being reported.
		tarErr = nil
	}
	return errors.Join(runErr, tarErr)
}
// runSession runs command in a new SSH session on c with stdin attached.
// Both stdout and stderr are copied to logWriter when it is non-nil;
// otherwise output is discarded.
//
// Cancelling ctx while the command runs closes the whole underlying SSH
// client, not just this session, so later calls on c will fail.
// NOTE(review): presumably chosen because signalling the remote process is
// not reliable; confirm. When the run fails after ctx is done, the returned
// error wraps ctx.Err(), so errors.Is(err, context.Canceled) and similar
// checks work.
func (c *Client) runSession(ctx context.Context, command string, stdin io.Reader, logWriter io.Writer) error {
	if c == nil || c.client == nil {
		return errors.New("ssh client is not connected")
	}
	// Fail fast, without tearing down the shared connection, if ctx is
	// already done.
	if err := ctx.Err(); err != nil {
		return err
	}
	session, err := c.client.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()
	session.Stdin = stdin
	if logWriter != nil {
		session.Stdout = logWriter
		session.Stderr = logWriter
	}
	done := make(chan struct{})
	defer close(done) // stops the watcher goroutine once Run has returned
	go func() {
		select {
		case <-ctx.Done():
			_ = c.client.Close()
		case <-done:
		}
	}()
	if err := session.Run(command); err != nil {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return fmt.Errorf("%w: %w", ctxErr, err)
		}
		return err
	}
	return nil
}
// privateKeySigner reads the private key file at path and parses it with
// ssh.ParsePrivateKey. No passphrase is supplied, so encrypted keys fail to
// parse. Parse errors are wrapped with the file path, because the ssh
// package's errors do not say which key was being read.
func privateKeySigner(path string) (ssh.Signer, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	signer, err := ssh.ParsePrivateKey(data)
	if err != nil {
		return nil, fmt.Errorf("parse private key %s: %w", path, err)
	}
	return signer, nil
}
func AuthorizedPublicKey(path string) ([]byte, error) {
signer, err := privateKeySigner(path)
if err != nil {
return nil, err
}
return ssh.MarshalAuthorizedKey(signer.PublicKey()), nil
}
// AuthorizedPublicKeyFingerprint returns the lowercase hex SHA-256 digest of
// the authorized_keys line for the key at path. Leading and trailing
// whitespace, including the trailing newline, is stripped before hashing.
// Because the text line is hashed, this value differs from the fingerprint
// printed by `ssh-keygen -l`.
func AuthorizedPublicKeyFingerprint(path string) (string, error) {
	line, err := AuthorizedPublicKey(path)
	if err != nil {
		return "", err
	}
	digest := sha256.Sum256(bytes.TrimSpace(line))
	return hex.EncodeToString(digest[:]), nil
}
// shellQuote wraps value in single quotes for a POSIX shell. Each embedded
// single quote becomes '"'"' (close the quote, emit a double-quoted quote,
// reopen), so the result is always one literal shell word.
func shellQuote(value string) string {
	var b strings.Builder
	b.Grow(len(value) + 2)
	b.WriteByte('\'')
	for i := 0; i < len(value); i++ {
		if value[i] == '\'' {
			b.WriteString(`'"'"'`)
			continue
		}
		b.WriteByte(value[i])
	}
	b.WriteByte('\'')
	return b.String()
}
// writeTarArchive writes sourceDir recursively to dst as a tar stream. Entries
// are named under filepath.Base(sourceDir) with forward slashes, so the
// archive unpacks into a single top-level directory. Symlinks are stored as
// links (never followed), and only regular files carry content.
//
// The end-of-archive footer is written, and its error returned, only when
// every entry succeeded. On failure the stream is left unterminated, so a
// consumer such as `tar -x` sees a truncated archive instead of a
// well-formed partial one.
func writeTarArchive(dst io.Writer, sourceDir string) error {
	tw := tar.NewWriter(dst)
	sourceDir = filepath.Clean(sourceDir)
	rootName := filepath.Base(sourceDir)
	err := filepath.Walk(sourceDir, func(current string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		name := rootName
		if current != sourceDir {
			relPath, err := filepath.Rel(sourceDir, current)
			if err != nil {
				return err
			}
			// Tar names are slash-separated regardless of the host OS.
			name = path.Join(rootName, filepath.ToSlash(relPath))
		}
		return writeTarPath(tw, current, name, info)
	})
	if err != nil {
		return err
	}
	return tw.Close()
}

// writeTarEntriesArchive writes only the given entries (paths relative to
// sourceDir) to dst as a tar stream. Names are rooted at
// filepath.Base(sourceDir) like writeTarArchive. Entries are normalized by
// normalizeTarEntries; directories are written as a header only, without
// their contents. As with writeTarArchive, the footer is written only on
// success, and nothing at all is written when the entry list is rejected.
func writeTarEntriesArchive(dst io.Writer, sourceDir string, entries []string) error {
	sourceDir = filepath.Clean(sourceDir)
	rootName := filepath.Base(sourceDir)
	cleaned, err := normalizeTarEntries(entries)
	if err != nil {
		return err
	}
	tw := tar.NewWriter(dst)
	for _, entry := range cleaned {
		fullPath := filepath.Join(sourceDir, entry)
		info, err := os.Lstat(fullPath)
		if err != nil {
			return err
		}
		name := path.Join(rootName, filepath.ToSlash(entry))
		if err := writeTarPath(tw, fullPath, name, info); err != nil {
			return err
		}
	}
	return tw.Close()
}

// normalizeTarEntries trims, cleans, de-duplicates and sorts entries. Blank
// entries are skipped. ".", "..", and anything that climbs above the source
// directory are rejected.
// NOTE(review): TrimSpace also strips whitespace that is genuinely part of a
// file name; kept for compatibility with callers that may pass raw lines.
func normalizeTarEntries(entries []string) ([]string, error) {
	unique := make([]string, 0, len(entries))
	seen := make(map[string]struct{}, len(entries))
	for _, entry := range entries {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		entry = filepath.Clean(entry)
		if entry == "." || entry == ".." || strings.HasPrefix(entry, ".."+string(filepath.Separator)) {
			return nil, fmt.Errorf("tar entry %q escapes source dir", entry)
		}
		if _, dup := seen[entry]; dup {
			continue
		}
		seen[entry] = struct{}{}
		unique = append(unique, entry)
	}
	sort.Strings(unique)
	return unique, nil
}

// writeTarPath appends the filesystem object at fullPath, described by its
// Lstat info, to tw under name. Symlink targets are recorded verbatim, and
// only regular files have their contents copied.
func writeTarPath(tw *tar.Writer, fullPath, name string, info os.FileInfo) error {
	linkTarget := ""
	if info.Mode()&os.ModeSymlink != 0 {
		target, err := os.Readlink(fullPath)
		if err != nil {
			return err
		}
		linkTarget = target
	}
	header, err := tar.FileInfoHeader(info, linkTarget)
	if err != nil {
		return err
	}
	header.Name = name
	if err := tw.WriteHeader(header); err != nil {
		return err
	}
	if !info.Mode().IsRegular() {
		return nil
	}
	file, err := os.Open(fullPath)
	if err != nil {
		return err
	}
	if _, err := io.Copy(tw, file); err != nil {
		_ = file.Close()
		return err
	}
	return file.Close()
}