// Package guest provides an SSH client for running commands in, and copying
// files to, a guest machine.
package guest

import (
	"archive/tar"
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"path"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"
)

// dialTimeout bounds both the TCP connect and the SSH handshake in Dial.
const dialTimeout = 10 * time.Second

// Client is a connected SSH client authenticated as root.
type Client struct {
	client *ssh.Client
}

// WaitForSSH repeatedly dials address, once per interval (default one
// second), until an SSH connection succeeds or ctx is done. A successful
// connection is closed immediately. The private key is loaded once up front,
// so an unreadable or invalid key fails fast instead of being retried until
// ctx expires. When ctx ends first, the returned error wraps ctx.Err() (use
// errors.Is) and includes the last dial error.
func WaitForSSH(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
	if interval <= 0 {
		interval = time.Second
	}
	signer, err := privateKeySigner(privateKeyPath)
	if err != nil {
		return err
	}
	for {
		client, err := dialSigner(ctx, address, signer)
		if err == nil {
			_ = client.Close()
			return nil
		}
		timer := time.NewTimer(interval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return fmt.Errorf("%w (last dial error: %v)", ctx.Err(), err)
		case <-timer.C:
		}
	}
}

// Dial opens an SSH connection to address (host:port) as root, authenticating
// with the private key at privateKeyPath.
//
// NOTE(review): host key verification is disabled (InsecureIgnoreHostKey);
// presumably acceptable because the guest is freshly provisioned — confirm.
func Dial(ctx context.Context, address, privateKeyPath string) (*Client, error) {
	signer, err := privateKeySigner(privateKeyPath)
	if err != nil {
		return nil, err
	}
	return dialSigner(ctx, address, signer)
}

// dialSigner connects to address and performs the SSH handshake with signer.
// The handshake is bounded by dialTimeout or ctx's deadline, whichever comes
// first, because ssh.NewClientConn itself honors neither.
func dialSigner(ctx context.Context, address string, signer ssh.Signer) (*Client, error) {
	config := &ssh.ClientConfig{
		User:            "root",
		Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         dialTimeout,
	}
	dialer := &net.Dialer{Timeout: dialTimeout}
	conn, err := dialer.DialContext(ctx, "tcp", address)
	if err != nil {
		return nil, err
	}

	// A listener that accepts TCP but never speaks SSH (e.g. a forwarded port
	// while the guest's sshd is still starting) would otherwise block the
	// handshake forever.
	deadline := time.Now().Add(dialTimeout)
	if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
		deadline = ctxDeadline
	}
	if err := conn.SetDeadline(deadline); err != nil {
		_ = conn.Close()
		return nil, err
	}
	sshConn, chans, reqs, err := ssh.NewClientConn(conn, address, config)
	if err != nil {
		_ = conn.Close()
		return nil, err
	}
	// Clear the handshake deadline so long-running sessions are not cut off.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		_ = sshConn.Close()
		return nil, err
	}
	return &Client{client: ssh.NewClient(sshConn, chans, reqs)}, nil
}

// Close closes the underlying SSH connection. It is safe to call on a nil
// or unconnected Client.
func (c *Client) Close() error {
	if c == nil || c.client == nil {
		return nil
	}
	return c.client.Close()
}

// RunScript feeds script to `bash -se` on the remote host. The -e flag makes
// bash exit on the first failing command. Remote stdout and stderr go to
// logWriter when it is non-nil.
func (c *Client) RunScript(ctx context.Context, script string, logWriter io.Writer) error {
	return c.runSession(ctx, "bash -se", strings.NewReader(script), logWriter)
}

// UploadFile writes data to remotePath with the permission bits of mode.
// It uses `install -D`, which also creates missing parent directories.
func (c *Client) UploadFile(ctx context.Context, remotePath string, mode os.FileMode, data []byte, logWriter io.Writer) error {
	command := fmt.Sprintf("install -D -m %04o /dev/stdin %s", mode.Perm(), shellQuote(remotePath))
	return c.runSession(ctx, command, bytes.NewReader(data), logWriter)
}

// StreamTar streams a tar archive of sourceDir (entries rooted at its base
// name) to the stdin of remoteCommand, e.g. a `tar -x` invocation.
func (c *Client) StreamTar(ctx context.Context, sourceDir, remoteCommand string, logWriter io.Writer) error {
	return c.streamArchive(ctx, remoteCommand, logWriter, func(w io.Writer) error {
		return writeTarArchive(w, sourceDir)
	})
}

// StreamTarEntries streams a tar archive containing only the listed entries
// (paths relative to sourceDir, rooted at its base name) to the stdin of
// remoteCommand. Directory entries are archived without their contents.
func (c *Client) StreamTarEntries(ctx context.Context, sourceDir string, entries []string, remoteCommand string, logWriter io.Writer) error {
	return c.streamArchive(ctx, remoteCommand, logWriter, func(w io.Writer) error {
		return writeTarEntriesArchive(w, sourceDir, entries)
	})
}

// streamArchive runs remoteCommand with stdin fed by produce, which runs
// concurrently and writes into a pipe. It returns the session error joined
// with the producer error.
func (c *Client) streamArchive(ctx context.Context, remoteCommand string, logWriter io.Writer, produce func(io.Writer) error) error {
	reader, writer := io.Pipe()
	produceErr := make(chan error, 1)
	go func() {
		err := produce(writer)
		// On failure the reading side sees err rather than a clean EOF, so a
		// truncated archive is not mistaken for a complete one.
		_ = writer.CloseWithError(err)
		produceErr <- err
	}()

	runErr := c.runSession(ctx, remoteCommand, reader, logWriter)
	// The session may return without draining stdin (session setup failure,
	// the remote command exiting early, cancellation). Closing the read side
	// makes any pending or future Write fail with io.ErrClosedPipe. Without
	// this, the producer would block forever and the receive below would
	// deadlock.
	_ = reader.Close()
	tarErr := <-produceErr
	if runErr != nil && errors.Is(tarErr, io.ErrClosedPipe) {
		// This is only a symptom of runErr; do not report it twice.
		tarErr = nil
	}
	return errors.Join(runErr, tarErr)
}

// runSession runs command in a new session on c, with stdin as its input.
// Both stdout and stderr go to logWriter when it is non-nil.
//
// If ctx is cancelled while the command runs, the whole SSH connection is
// closed, because a running session cannot otherwise be interrupted. c is
// unusable afterwards, and the returned error wraps ctx.Err().
func (c *Client) runSession(ctx context.Context, command string, stdin io.Reader, logWriter io.Writer) error {
	if c == nil || c.client == nil {
		return errors.New("ssh client is not connected")
	}
	if err := ctx.Err(); err != nil {
		return err
	}
	session, err := c.client.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()
	session.Stdin = stdin
	if logWriter != nil {
		session.Stdout = logWriter
		session.Stderr = logWriter
	}

	finished := make(chan struct{})
	defer close(finished)
	go func() {
		select {
		case <-ctx.Done():
			_ = c.client.Close()
		case <-finished:
		}
	}()

	if err := session.Run(command); err != nil {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return errors.Join(ctxErr, err)
		}
		return err
	}
	return nil
}

// privateKeySigner reads and parses the private key file at path.
func privateKeySigner(path string) (ssh.Signer, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	return ssh.ParsePrivateKey(data)
}

// AuthorizedPublicKey returns the public half of the private key at path in
// authorized_keys line format, as produced by ssh.MarshalAuthorizedKey.
func AuthorizedPublicKey(path string) ([]byte, error) {
	signer, err := privateKeySigner(path)
	if err != nil {
		return nil, err
	}
	return ssh.MarshalAuthorizedKey(signer.PublicKey()), nil
}

// AuthorizedPublicKeyFingerprint returns the hex-encoded SHA-256 of the
// whitespace-trimmed authorized_keys line for the key at path. It hashes the
// text line rather than the wire-format key, so it differs from
// ssh.FingerprintSHA256.
func AuthorizedPublicKeyFingerprint(path string) (string, error) {
	key, err := AuthorizedPublicKey(path)
	if err != nil {
		return "", err
	}
	sum := sha256.Sum256([]byte(strings.TrimSpace(string(key))))
	return hex.EncodeToString(sum[:]), nil
}

// shellQuote wraps value in single quotes for a POSIX shell. Embedded single
// quotes are written as '"'"'.
func shellQuote(value string) string {
	return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
}

// writeTarArchive writes sourceDir and everything beneath it to dst as a tar
// stream, using slash-separated names rooted at filepath.Base(sourceDir).
// Symlinks are archived as links, not followed. The archive trailer is
// written only on success, so a failed walk leaves a visibly truncated
// archive.
func writeTarArchive(dst io.Writer, sourceDir string) error {
	tw := tar.NewWriter(dst)
	sourceDir = filepath.Clean(sourceDir)
	rootName := filepath.Base(sourceDir)
	walkErr := filepath.Walk(sourceDir, func(filePath string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		name := rootName
		if filePath != sourceDir {
			relPath, err := filepath.Rel(sourceDir, filePath)
			if err != nil {
				return err
			}
			// Tar names always use '/', regardless of the host OS.
			name = path.Join(rootName, filepath.ToSlash(relPath))
		}
		return writeTarEntry(tw, filePath, name, info)
	})
	if walkErr != nil {
		return walkErr
	}
	return tw.Close()
}

// writeTarEntriesArchive writes only the given entries (paths relative to
// sourceDir) to dst as a tar stream, named under filepath.Base(sourceDir).
// Entries are trimmed, cleaned, de-duplicated and sorted. An entry that is
// "." or points outside sourceDir is rejected before anything is written.
func writeTarEntriesArchive(dst io.Writer, sourceDir string, entries []string) error {
	sourceDir = filepath.Clean(sourceDir)
	rootName := filepath.Base(sourceDir)
	normalized, err := normalizeTarEntries(entries)
	if err != nil {
		return err
	}
	tw := tar.NewWriter(dst)
	for _, entry := range normalized {
		fullPath := filepath.Join(sourceDir, entry)
		info, err := os.Lstat(fullPath)
		if err != nil {
			return err
		}
		if err := writeTarEntry(tw, fullPath, path.Join(rootName, filepath.ToSlash(entry)), info); err != nil {
			return err
		}
	}
	return tw.Close()
}

// normalizeTarEntries trims, cleans, de-duplicates and sorts relative entry
// paths. Blank entries are skipped; ".", "..", and paths starting with "../"
// are rejected.
func normalizeTarEntries(entries []string) ([]string, error) {
	unique := make([]string, 0, len(entries))
	seen := make(map[string]struct{}, len(entries))
	for _, entry := range entries {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		entry = filepath.Clean(entry)
		if entry == "." || entry == ".." || strings.HasPrefix(entry, ".."+string(filepath.Separator)) {
			return nil, fmt.Errorf("tar entry %q escapes source dir", entry)
		}
		if _, dup := seen[entry]; dup {
			continue
		}
		seen[entry] = struct{}{}
		unique = append(unique, entry)
	}
	sort.Strings(unique)
	return unique, nil
}

// writeTarEntry writes a header for fullPath under the archive name. info is
// expected to come from Lstat, so symlinks are recorded with their targets.
// For regular files, the contents follow the header.
func writeTarEntry(tw *tar.Writer, fullPath, name string, info os.FileInfo) error {
	linkTarget := ""
	if info.Mode()&os.ModeSymlink != 0 {
		target, err := os.Readlink(fullPath)
		if err != nil {
			return err
		}
		linkTarget = target
	}
	header, err := tar.FileInfoHeader(info, linkTarget)
	if err != nil {
		return err
	}
	header.Name = name
	if err := tw.WriteHeader(header); err != nil {
		return err
	}
	if !info.Mode().IsRegular() {
		return nil
	}
	file, err := os.Open(fullPath)
	if err != nil {
		return err
	}
	if _, err := io.Copy(tw, file); err != nil {
		_ = file.Close()
		return err
	}
	return file.Close()
}