Pure-Go test additions (no SSH server fixture) covering AuthorizedPublicKeyFingerprint, shellQuote escaping, writeTarEntriesArchive error paths (`..`, `.`, missing, duplicate, and blank entries) and symlink handling, StreamSession/Client nil-receiver safety, and WaitForSSH context cancellation. internal/guest coverage: 17.8% -> 47.6%; total: 52.1% -> 52.6%. The remaining uncovered paths require a real in-process SSH server and are skipped.
293 lines
7.4 KiB
Go
293 lines
7.4 KiB
Go
package guest
|
|
|
|
import (
|
|
"archive/tar"
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// writeTestKey creates a throwaway 1024-bit RSA private key for tests.
// The key is PEM-encoded as PKCS#1 ("RSA PRIVATE KEY") and written with
// mode 0600 to a file named id_rsa in a per-test temporary directory.
// It returns the file's path and aborts the calling test on any failure.
func writeTestKey(t *testing.T) string {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
		t.Fatalf("GenerateKey: %v", err)
	}

	block := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	}
	encoded := pem.EncodeToMemory(block)

	path := filepath.Join(t.TempDir(), "id_rsa")
	err = os.WriteFile(path, encoded, 0o600)
	if err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	return path
}
|
|
|
|
func TestAuthorizedPublicKeyFingerprint(t *testing.T) {
|
|
t.Parallel()
|
|
keyPath := writeTestKey(t)
|
|
|
|
fp, err := AuthorizedPublicKeyFingerprint(keyPath)
|
|
if err != nil {
|
|
t.Fatalf("AuthorizedPublicKeyFingerprint: %v", err)
|
|
}
|
|
if !regexp.MustCompile(`^[0-9a-f]{64}$`).MatchString(fp) {
|
|
t.Fatalf("fingerprint = %q, want 64 hex chars", fp)
|
|
}
|
|
|
|
fp2, err := AuthorizedPublicKeyFingerprint(keyPath)
|
|
if err != nil {
|
|
t.Fatalf("AuthorizedPublicKeyFingerprint (second): %v", err)
|
|
}
|
|
if fp != fp2 {
|
|
t.Fatalf("fingerprint not deterministic: %q vs %q", fp, fp2)
|
|
}
|
|
}
|
|
|
|
func TestAuthorizedPublicKeyFingerprintMissingFile(t *testing.T) {
|
|
t.Parallel()
|
|
_, err := AuthorizedPublicKeyFingerprint(filepath.Join(t.TempDir(), "nope"))
|
|
if err == nil {
|
|
t.Fatal("expected error for missing key file")
|
|
}
|
|
}
|
|
|
|
func TestAuthorizedPublicKeyBadPEM(t *testing.T) {
|
|
t.Parallel()
|
|
keyPath := filepath.Join(t.TempDir(), "bad")
|
|
if err := os.WriteFile(keyPath, []byte("not a private key"), 0o600); err != nil {
|
|
t.Fatalf("WriteFile: %v", err)
|
|
}
|
|
if _, err := AuthorizedPublicKey(keyPath); err == nil {
|
|
t.Fatal("expected ParsePrivateKey error")
|
|
}
|
|
}
|
|
|
|
func TestShellQuote(t *testing.T) {
|
|
t.Parallel()
|
|
cases := []struct {
|
|
in, want string
|
|
}{
|
|
{"", "''"},
|
|
{"simple", "'simple'"},
|
|
{"with space", "'with space'"},
|
|
{"it's", `'it'"'"'s'`},
|
|
{"a'b'c", `'a'"'"'b'"'"'c'`},
|
|
{"/path/to/file", "'/path/to/file'"},
|
|
}
|
|
for _, tc := range cases {
|
|
got := shellQuote(tc.in)
|
|
if got != tc.want {
|
|
t.Errorf("shellQuote(%q) = %q, want %q", tc.in, got, tc.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWriteTarEntriesArchiveRejectsEscape(t *testing.T) {
|
|
t.Parallel()
|
|
dir := t.TempDir()
|
|
var buf bytes.Buffer
|
|
err := writeTarEntriesArchive(&buf, dir, []string{"../escape"})
|
|
if err == nil {
|
|
t.Fatal("expected error for escaping entry")
|
|
}
|
|
if !strings.Contains(err.Error(), "escapes source dir") {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestWriteTarEntriesArchiveRejectsDot(t *testing.T) {
|
|
t.Parallel()
|
|
dir := t.TempDir()
|
|
var buf bytes.Buffer
|
|
for _, bad := range []string{".", ".."} {
|
|
if err := writeTarEntriesArchive(&buf, dir, []string{bad}); err == nil {
|
|
t.Errorf("expected error for entry %q", bad)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWriteTarEntriesArchiveDedupsAndSkipsBlank(t *testing.T) {
|
|
t.Parallel()
|
|
sourceDir := filepath.Join(t.TempDir(), "repo")
|
|
if err := os.MkdirAll(sourceDir, 0o755); err != nil {
|
|
t.Fatalf("MkdirAll: %v", err)
|
|
}
|
|
if err := os.WriteFile(filepath.Join(sourceDir, "a.txt"), []byte("A"), 0o644); err != nil {
|
|
t.Fatalf("WriteFile: %v", err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := writeTarEntriesArchive(&buf, sourceDir, []string{"a.txt", "a.txt", "", " "}); err != nil {
|
|
t.Fatalf("writeTarEntriesArchive: %v", err)
|
|
}
|
|
|
|
tr := tar.NewReader(&buf)
|
|
var names []string
|
|
for {
|
|
h, err := tr.Next()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("tar.Next: %v", err)
|
|
}
|
|
names = append(names, h.Name)
|
|
}
|
|
if len(names) != 1 || names[0] != "repo/a.txt" {
|
|
t.Fatalf("names = %v, want [repo/a.txt]", names)
|
|
}
|
|
}
|
|
|
|
func TestWriteTarEntriesArchiveSymlink(t *testing.T) {
|
|
t.Parallel()
|
|
sourceDir := filepath.Join(t.TempDir(), "repo")
|
|
if err := os.MkdirAll(sourceDir, 0o755); err != nil {
|
|
t.Fatalf("MkdirAll: %v", err)
|
|
}
|
|
if err := os.WriteFile(filepath.Join(sourceDir, "target.txt"), []byte("T"), 0o644); err != nil {
|
|
t.Fatalf("WriteFile: %v", err)
|
|
}
|
|
linkPath := filepath.Join(sourceDir, "link")
|
|
if err := os.Symlink("target.txt", linkPath); err != nil {
|
|
t.Skipf("symlink unsupported: %v", err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := writeTarEntriesArchive(&buf, sourceDir, []string{"link"}); err != nil {
|
|
t.Fatalf("writeTarEntriesArchive: %v", err)
|
|
}
|
|
|
|
tr := tar.NewReader(&buf)
|
|
h, err := tr.Next()
|
|
if err != nil {
|
|
t.Fatalf("tar.Next: %v", err)
|
|
}
|
|
if h.Typeflag != tar.TypeSymlink {
|
|
t.Fatalf("typeflag = %v, want TypeSymlink", h.Typeflag)
|
|
}
|
|
if h.Linkname != "target.txt" {
|
|
t.Fatalf("linkname = %q, want target.txt", h.Linkname)
|
|
}
|
|
}
|
|
|
|
func TestWriteTarEntriesArchiveMissingPath(t *testing.T) {
|
|
t.Parallel()
|
|
sourceDir := t.TempDir()
|
|
var buf bytes.Buffer
|
|
err := writeTarEntriesArchive(&buf, sourceDir, []string{"missing.txt"})
|
|
if err == nil {
|
|
t.Fatal("expected error for missing entry")
|
|
}
|
|
}
|
|
|
|
func TestStreamSessionNilSafe(t *testing.T) {
|
|
t.Parallel()
|
|
var s *StreamSession
|
|
if s.Stdin() != nil || s.Stdout() != nil || s.Stderr() != nil {
|
|
t.Fatal("nil StreamSession getters should return nil")
|
|
}
|
|
if err := s.Wait(); err != nil {
|
|
t.Fatalf("nil Wait error: %v", err)
|
|
}
|
|
if err := s.Close(); err != nil {
|
|
t.Fatalf("nil Close error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestClientNilClose(t *testing.T) {
|
|
t.Parallel()
|
|
var c *Client
|
|
if err := c.Close(); err != nil {
|
|
t.Fatalf("nil Close error: %v", err)
|
|
}
|
|
c2 := &Client{}
|
|
if err := c2.Close(); err != nil {
|
|
t.Fatalf("empty Close error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestClientRunScriptOutputNotConnected(t *testing.T) {
|
|
t.Parallel()
|
|
var c *Client
|
|
if _, err := c.RunScriptOutput(context.Background(), "true"); err == nil {
|
|
t.Fatal("expected not-connected error")
|
|
}
|
|
c2 := &Client{}
|
|
if _, err := c2.RunScriptOutput(context.Background(), "true"); err == nil {
|
|
t.Fatal("expected not-connected error")
|
|
}
|
|
}
|
|
|
|
func TestClientStartCommandNotConnected(t *testing.T) {
|
|
t.Parallel()
|
|
var c *Client
|
|
if _, err := c.StartCommand(context.Background(), "true"); err == nil {
|
|
t.Fatal("expected not-connected error")
|
|
}
|
|
}
|
|
|
|
func TestClientRunScriptNotConnected(t *testing.T) {
|
|
t.Parallel()
|
|
var c *Client
|
|
if err := c.RunScript(context.Background(), "true", io.Discard); err == nil {
|
|
t.Fatal("expected not-connected error")
|
|
}
|
|
}
|
|
|
|
// freeAddr returns a loopback TCP address of the form "127.0.0.1:<port>".
// It binds an ephemeral port, records the address, then closes the
// listener, so no server is listening there when it returns.
//
// A later dial to the address is expected to fail with "connection
// refused". This is best-effort, not guaranteed: after the listener closes,
// another process or a parallel test may bind the same port before the
// caller dials it. The window is small, which is acceptable for tests that
// only need an address with no SSH server behind it. Do not rely on it
// where a reused port would produce a misleading pass.
func freeAddr(t *testing.T) string {
	t.Helper()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("net.Listen: %v", err)
	}
	addr := ln.Addr().String()
	if err := ln.Close(); err != nil {
		t.Fatalf("Close listener: %v", err)
	}
	return addr
}
|
|
|
|
func TestWaitForSSHContextCancel(t *testing.T) {
|
|
t.Parallel()
|
|
keyPath := writeTestKey(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
start := time.Now()
|
|
err := WaitForSSH(ctx, freeAddr(t), keyPath, 10*time.Millisecond)
|
|
if !errors.Is(err, context.DeadlineExceeded) {
|
|
t.Fatalf("err = %v, want context.DeadlineExceeded", err)
|
|
}
|
|
if elapsed := time.Since(start); elapsed > 2*time.Second {
|
|
t.Fatalf("took too long: %v", elapsed)
|
|
}
|
|
}
|
|
|
|
func TestDialReturnsErrorForBadKey(t *testing.T) {
|
|
t.Parallel()
|
|
keyPath := filepath.Join(t.TempDir(), "bogus")
|
|
if err := os.WriteFile(keyPath, []byte("nope"), 0o600); err != nil {
|
|
t.Fatalf("WriteFile: %v", err)
|
|
}
|
|
_, err := Dial(context.Background(), freeAddr(t), keyPath)
|
|
if err == nil {
|
|
t.Fatal("expected error for bad key")
|
|
}
|
|
}
|