From a3cc29652355bc7836c48c2789b3656344171c94 Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Sat, 18 Apr 2026 17:47:24 -0300 Subject: [PATCH] guest: tests for fingerprint, shellQuote, tar-entries edge cases, nil receivers Pure-Go additions (no SSH server fixture): AuthorizedPublicKeyFingerprint, shellQuote escaping, writeTarEntriesArchive error paths (.., ., missing, duplicates, blank entries) and symlink handling, StreamSession/Client nil-receiver safety, WaitForSSH context cancellation. internal/guest coverage 17.8% -> 47.6%. Total 52.1% -> 52.6%. The remaining uncovered paths need a real in-process SSH server; skip. --- internal/guest/ssh_more_test.go | 293 ++++++++++++++++++++++++++++++++ 1 file changed, 293 insertions(+) create mode 100644 internal/guest/ssh_more_test.go diff --git a/internal/guest/ssh_more_test.go b/internal/guest/ssh_more_test.go new file mode 100644 index 0000000..271605e --- /dev/null +++ b/internal/guest/ssh_more_test.go @@ -0,0 +1,293 @@ +package guest + +import ( + "archive/tar" + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "io" + "net" + "os" + "path/filepath" + "regexp" + "strings" + "testing" + "time" +) + +func writeTestKey(t *testing.T) string { + t.Helper() + privateKey, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + keyPath := filepath.Join(t.TempDir(), "id_rsa") + if err := os.WriteFile(keyPath, privateKeyPEM, 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + return keyPath +} + +func TestAuthorizedPublicKeyFingerprint(t *testing.T) { + t.Parallel() + keyPath := writeTestKey(t) + + fp, err := AuthorizedPublicKeyFingerprint(keyPath) + if err != nil { + t.Fatalf("AuthorizedPublicKeyFingerprint: %v", err) + } + if !regexp.MustCompile(`^[0-9a-f]{64}$`).MatchString(fp) { + t.Fatalf("fingerprint = %q, want 64 hex chars", fp) + } + + fp2, err := AuthorizedPublicKeyFingerprint(keyPath) + if err != nil { + t.Fatalf("AuthorizedPublicKeyFingerprint (second): %v", err) + } + if fp != fp2 { + t.Fatalf("fingerprint not deterministic: %q vs %q", fp, fp2) + } +} + +func TestAuthorizedPublicKeyFingerprintMissingFile(t *testing.T) { + t.Parallel() + _, err := AuthorizedPublicKeyFingerprint(filepath.Join(t.TempDir(), "nope")) + if err == nil { + t.Fatal("expected error for missing key file") + } +} + +func TestAuthorizedPublicKeyBadPEM(t *testing.T) { + t.Parallel() + keyPath := filepath.Join(t.TempDir(), "bad") + if err := os.WriteFile(keyPath, []byte("not a private key"), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + if _, err := AuthorizedPublicKey(keyPath); err == nil { + t.Fatal("expected ParsePrivateKey error") + } +} + +func TestShellQuote(t *testing.T) { + t.Parallel() + cases := []struct { + in, want string + }{ + {"", "''"}, + {"simple", "'simple'"}, + {"with space", "'with space'"}, + {"it's", `'it'"'"'s'`}, + {"a'b'c", `'a'"'"'b'"'"'c'`}, + {"/path/to/file", "'/path/to/file'"}, + } + for _, tc := range cases { + got := shellQuote(tc.in) + if got != tc.want { + t.Errorf("shellQuote(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestWriteTarEntriesArchiveRejectsEscape(t *testing.T) { + t.Parallel() + dir := t.TempDir() + var buf bytes.Buffer + err := writeTarEntriesArchive(&buf, dir, []string{"../escape"}) + if err == nil { + t.Fatal("expected error for escaping entry") + } + if !strings.Contains(err.Error(), "escapes source dir") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestWriteTarEntriesArchiveRejectsDot(t *testing.T) { + t.Parallel() + dir := t.TempDir() + var buf bytes.Buffer + for _, bad := range []string{".", ".."} { + if err := writeTarEntriesArchive(&buf, dir, []string{bad}); err == nil { + t.Errorf("expected error for entry %q", bad) + } + } +} + +func TestWriteTarEntriesArchiveDedupsAndSkipsBlank(t *testing.T) { + t.Parallel() + sourceDir := filepath.Join(t.TempDir(), "repo") + if err := os.MkdirAll(sourceDir, 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(filepath.Join(sourceDir, "a.txt"), []byte("A"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + var buf bytes.Buffer + if err := writeTarEntriesArchive(&buf, sourceDir, []string{"a.txt", "a.txt", "", " "}); err != nil { + t.Fatalf("writeTarEntriesArchive: %v", err) + } + + tr := tar.NewReader(&buf) + var names []string + for { + h, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("tar.Next: %v", err) + } + names = append(names, h.Name) + } + if len(names) != 1 || names[0] != "repo/a.txt" { + t.Fatalf("names = %v, want [repo/a.txt]", names) + } +} + +func TestWriteTarEntriesArchiveSymlink(t *testing.T) { + t.Parallel() + sourceDir := filepath.Join(t.TempDir(), "repo") + if err := os.MkdirAll(sourceDir, 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(filepath.Join(sourceDir, "target.txt"), []byte("T"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + linkPath := filepath.Join(sourceDir, "link") + if err := os.Symlink("target.txt", linkPath); err != nil { + t.Skipf("symlink unsupported: %v", err) + } + + var buf bytes.Buffer + if err := writeTarEntriesArchive(&buf, sourceDir, []string{"link"}); err != nil { + t.Fatalf("writeTarEntriesArchive: %v", err) + } + + tr := tar.NewReader(&buf) + h, err := tr.Next() + if err != nil { + t.Fatalf("tar.Next: %v", err) + } + if h.Typeflag != tar.TypeSymlink { + t.Fatalf("typeflag = %v, want TypeSymlink", h.Typeflag) + } + if h.Linkname != "target.txt" { + t.Fatalf("linkname = %q, want target.txt", h.Linkname) + } +} + +func TestWriteTarEntriesArchiveMissingPath(t *testing.T) { + t.Parallel() + sourceDir := t.TempDir() + var buf bytes.Buffer + err := writeTarEntriesArchive(&buf, sourceDir, []string{"missing.txt"}) + if err == nil { + t.Fatal("expected error for missing entry") + } +} + +func TestStreamSessionNilSafe(t *testing.T) { + t.Parallel() + var s *StreamSession + if s.Stdin() != nil || s.Stdout() != nil || s.Stderr() != nil { + t.Fatal("nil StreamSession getters should return nil") + } + if err := s.Wait(); err != nil { + t.Fatalf("nil Wait error: %v", err) + } + if err := s.Close(); err != nil { + t.Fatalf("nil Close error: %v", err) + } +} + +func TestClientNilClose(t *testing.T) { + t.Parallel() + var c *Client + if err := c.Close(); err != nil { + t.Fatalf("nil Close error: %v", err) + } + c2 := &Client{} + if err := c2.Close(); err != nil { + t.Fatalf("empty Close error: %v", err) + } +} + +func TestClientRunScriptOutputNotConnected(t *testing.T) { + t.Parallel() + var c *Client + if _, err := c.RunScriptOutput(context.Background(), "true"); err == nil { + t.Fatal("expected not-connected error") + } + c2 := &Client{} + if _, err := c2.RunScriptOutput(context.Background(), "true"); err == nil { + t.Fatal("expected not-connected error") + } +} + +func TestClientStartCommandNotConnected(t *testing.T) { + t.Parallel() + var c *Client + if _, err := c.StartCommand(context.Background(), "true"); err == nil { + t.Fatal("expected not-connected error") + } +} + +func TestClientRunScriptNotConnected(t *testing.T) { + t.Parallel() + var c *Client + if err := c.RunScript(context.Background(), "true", io.Discard); err == nil { + t.Fatal("expected not-connected error") + } +} + +// freeAddr grabs a loopback port by listening briefly, then closing. Next +// Dial to it deterministically fails with "connection refused" — no real +// server on the far end, no flakiness from random ports being taken. +func freeAddr(t *testing.T) string { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen: %v", err) + } + addr := ln.Addr().String() + if err := ln.Close(); err != nil { + t.Fatalf("Close listener: %v", err) + } + return addr +} + +func TestWaitForSSHContextCancel(t *testing.T) { + t.Parallel() + keyPath := writeTestKey(t) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + err := WaitForSSH(ctx, freeAddr(t), keyPath, 10*time.Millisecond) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("err = %v, want context.DeadlineExceeded", err) + } + if elapsed := time.Since(start); elapsed > 2*time.Second { + t.Fatalf("took too long: %v", elapsed) + } +} + +func TestDialReturnsErrorForBadKey(t *testing.T) { + t.Parallel() + keyPath := filepath.Join(t.TempDir(), "bogus") + if err := os.WriteFile(keyPath, []byte("nope"), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + _, err := Dial(context.Background(), freeAddr(t), keyPath) + if err == nil { + t.Fatal("expected error for bad key") + } +}