// Package guest runs commands over SSH: it waits for an SSH server to accept
// a key-based root login, runs bash scripts, and streams local directories to
// a remote command as tar archives.
package guest

import (
	"archive/tar"
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
)

// dialTimeout bounds both the TCP connect and the SSH handshake of a single
// Dial attempt.
const dialTimeout = 10 * time.Second

// errArchiveNotConsumed is reported by StreamTar when the remote command
// finished (or was interrupted) before reading the whole archive.
var errArchiveNotConsumed = errors.New("remote command stopped reading the archive")

// Client is an SSH connection opened by Dial.
type Client struct {
	client *ssh.Client
}

// WaitForSSH repeatedly calls Dial, pausing interval between attempts, until
// one succeeds or ctx is done. A non-positive interval means one second.
// The successful connection is closed immediately.
//
// When ctx ends first, the returned error wraps ctx.Err() (so errors.Is works)
// and includes the most recent dial failure for diagnosis.
func WaitForSSH(ctx context.Context, address, privateKeyPath string, interval time.Duration) error {
	if interval <= 0 {
		interval = time.Second
	}
	for {
		client, err := Dial(ctx, address, privateKeyPath)
		if err == nil {
			_ = client.Close()
			return nil
		}
		select {
		case <-ctx.Done():
			return fmt.Errorf("ssh on %s not ready: %w (last error: %v)", address, ctx.Err(), err)
		case <-time.After(interval):
		}
	}
}

// Dial connects to address over TCP and logs in as root with the private key
// read from privateKeyPath. The server's host key is NOT verified
// (ssh.InsecureIgnoreHostKey).
//
// Both the TCP connect and the SSH handshake are bounded by dialTimeout and
// abort when ctx is done.
func Dial(ctx context.Context, address, privateKeyPath string) (*Client, error) {
	signer, err := privateKeySigner(privateKeyPath)
	if err != nil {
		return nil, err
	}
	config := &ssh.ClientConfig{
		User:            "root",
		Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		// ClientConfig.Timeout is only honoured by ssh.Dial, not by
		// ssh.NewClientConn; the handshake is bounded in handshake() instead.
	}
	dialer := &net.Dialer{Timeout: dialTimeout}
	conn, err := dialer.DialContext(ctx, "tcp", address)
	if err != nil {
		return nil, err
	}
	client, err := handshake(ctx, conn, address, config)
	if err != nil {
		_ = conn.Close()
		return nil, err
	}
	return &Client{client: client}, nil
}

// handshake runs the SSH client handshake over conn. It gives up after
// dialTimeout, at ctx's deadline if that is sooner, or as soon as ctx is
// cancelled. On success the I/O deadline is cleared so the established
// connection does not inherit it.
func handshake(ctx context.Context, conn net.Conn, address string, config *ssh.ClientConfig) (*ssh.Client, error) {
	deadline := time.Now().Add(dialTimeout)
	if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
		deadline = ctxDeadline
	}
	if err := conn.SetDeadline(deadline); err != nil {
		return nil, err
	}

	// Watch for cancellation while the handshake runs; closing conn unblocks
	// whatever read/write the handshake is stuck in.
	finished := make(chan struct{})
	watcherExited := make(chan struct{})
	go func() {
		defer close(watcherExited)
		select {
		case <-ctx.Done():
			_ = conn.Close()
		case <-finished:
		}
	}()
	sshConn, chans, reqs, err := ssh.NewClientConn(conn, address, config)
	close(finished)
	<-watcherExited

	if ctxErr := ctx.Err(); ctxErr != nil {
		if err == nil {
			_ = sshConn.Close()
		}
		return nil, ctxErr
	}
	if err != nil {
		return nil, err
	}
	if err := conn.SetDeadline(time.Time{}); err != nil {
		_ = sshConn.Close()
		return nil, err
	}
	return ssh.NewClient(sshConn, chans, reqs), nil
}

// Close closes the underlying SSH connection. It is a no-op on a nil Client
// or one without a connection.
func (c *Client) Close() error {
	if c == nil || c.client == nil {
		return nil
	}
	return c.client.Close()
}

// RunScript feeds script on stdin to `bash -se` on the remote side (read
// commands from stdin, stop at the first failing command). Remote stdout and
// stderr are written to logWriter when it is non-nil.
func (c *Client) RunScript(ctx context.Context, script string, logWriter io.Writer) error {
	return c.runSession(ctx, "bash -se", strings.NewReader(script), logWriter)
}

// StreamTar runs remoteCommand and pipes a tar archive of sourceDir (see
// writeTarArchive) to its stdin. Errors from the remote command and from
// building the archive are both reported, joined with errors.Join.
func (c *Client) StreamTar(ctx context.Context, sourceDir, remoteCommand string, logWriter io.Writer) error {
	reader, writer := io.Pipe()
	writeErr := make(chan error, 1)
	go func() {
		err := writeTarArchive(writer, sourceDir)
		_ = writer.Close()
		writeErr <- err
	}()

	runErr := c.runSession(ctx, remoteCommand, reader, logWriter)
	// If the session ended before the archive was fully read (command failed
	// or exited early, session never started, ctx cancelled), nothing reads
	// the pipe any more. Closing the reader makes pending and future writes
	// fail, so the archive goroutine exits instead of blocking forever.
	// When the archive was fully consumed this is a no-op for the writer.
	_ = reader.CloseWithError(errArchiveNotConsumed)
	tarErr := <-writeErr
	return errors.Join(runErr, tarErr)
}

// runSession runs command in a new SSH session with stdin attached, copying
// remote stdout and stderr to logWriter when it is non-nil.
//
// Cancelling ctx closes the whole SSH connection (the only way to interrupt
// the running command from this side), which leaves the Client unusable. In
// that case the returned error matches ctx.Err() via errors.Is.
func (c *Client) runSession(ctx context.Context, command string, stdin io.Reader, logWriter io.Writer) error {
	if c == nil || c.client == nil {
		return errors.New("ssh client is not connected")
	}
	if err := ctx.Err(); err != nil {
		return err
	}
	session, err := c.client.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()

	session.Stdin = stdin
	if logWriter != nil {
		// The ssh package copies stdout and stderr on separate goroutines;
		// serialize writes so a shared, non-thread-safe writer is not raced.
		shared := &syncWriter{w: logWriter}
		session.Stdout = shared
		session.Stderr = shared
	}

	done := make(chan struct{})
	watcherExited := make(chan struct{})
	go func() {
		defer close(watcherExited)
		select {
		case <-ctx.Done():
			_ = c.client.Close()
		case <-done:
		}
	}()

	err = session.Run(command)
	close(done)
	<-watcherExited
	if err != nil {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return errors.Join(ctxErr, err)
		}
	}
	return err
}

// syncWriter serializes Write calls to w so that several goroutines can share
// one writer.
type syncWriter struct {
	mu sync.Mutex
	w  io.Writer
}

func (s *syncWriter) Write(p []byte) (int, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.w.Write(p)
}

// privateKeySigner reads and parses the private key file at path.
func privateKeySigner(path string) (ssh.Signer, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		// The *PathError from os.ReadFile already names the file.
		return nil, fmt.Errorf("reading ssh private key: %w", err)
	}
	signer, err := ssh.ParsePrivateKey(data)
	if err != nil {
		return nil, fmt.Errorf("parsing ssh private key %s: %w", path, err)
	}
	return signer, nil
}

// writeTarArchive writes sourceDir and everything beneath it to dst as an
// uncompressed tar stream. Entry names start with filepath.Base(sourceDir) and
// use forward slashes; symlinks are archived as links (filepath.Walk does not
// follow them).
//
// On any error the end-of-archive trailer is deliberately NOT written, so the
// receiver sees a truncated archive instead of a complete-looking partial one.
func writeTarArchive(dst io.Writer, sourceDir string) error {
	tw := tar.NewWriter(dst)
	sourceDir = filepath.Clean(sourceDir)
	rootName := filepath.Base(sourceDir)

	walkErr := filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		name := rootName
		if path != sourceDir {
			relPath, err := filepath.Rel(sourceDir, path)
			if err != nil {
				return err
			}
			name = filepath.Join(rootName, relPath)
		}
		return addTarEntry(tw, path, filepath.ToSlash(name), info)
	})
	if walkErr != nil {
		return walkErr
	}
	// Close flushes padding and writes the trailer; its error means the
	// archive is incomplete and must be reported.
	return tw.Close()
}

// addTarEntry writes a header for path under name and, for regular files,
// the file's contents.
func addTarEntry(tw *tar.Writer, path, name string, info os.FileInfo) error {
	linkTarget := ""
	if info.Mode()&os.ModeSymlink != 0 {
		target, err := os.Readlink(path)
		if err != nil {
			return err
		}
		linkTarget = target
	}
	header, err := tar.FileInfoHeader(info, linkTarget)
	if err != nil {
		return err
	}
	header.Name = name
	if info.IsDir() {
		// Match tar.FileInfoHeader's own convention for directory entries.
		header.Name += "/"
	}
	if err := tw.WriteHeader(header); err != nil {
		return err
	}
	if !info.Mode().IsRegular() {
		return nil
	}
	file, err := os.Open(path)
	if err != nil {
		return err
	}
	defer file.Close()
	_, err = io.Copy(tw, file)
	return err
}